From 597e3fbcc74df9defde652419057649bf7bbc3dc Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Thu, 8 Jan 2026 16:47:34 +0800 Subject: [PATCH 1/7] add sunrise backend --- .gitignore | 1 + CMakeLists.txt | 94 +- .../Conversion/TritonGPUToLLVM/Utility.h | 3 + .../TritonToTritonGPU/CMakeLists.txt | 3 +- .../triton/Dialect/Triton/IR/CMakeLists.txt | 12 +- .../Dialect/Triton/Transforms/CMakeLists.txt | 3 +- .../Dialect/TritonGPU/IR/CMakeLists.txt | 6 +- .../TritonGPU/Transforms/CMakeLists.txt | 3 +- lib/Analysis/CMakeLists.txt | 8 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 10 + .../TritonToTritonGPU/CMakeLists.txt | 2 + lib/Dialect/TritonGPU/IR/CMakeLists.txt | 2 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 8 + .../TritonGPU/IR/LinearLayoutConversions.cpp | 8 + .../TritonGPU/Transforms/CMakeLists.txt | 9 + lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 8 + .../Pipeliner/PipeliningUtility.cpp | 8 + lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 7 + .../Transforms/RemoveLayoutConversions.cpp | 8 + lib/Dialect/TritonGPU/Transforms/Utility.cpp | 8 + lib/Target/LLVMIR/CMakeLists.txt | 2 + setup_tools/setup_helper.py | 2 +- third_party/amd/CMakeLists.txt | 24 +- third_party/sunrise/CMakeLists.txt | 50 + third_party/sunrise/backend/__init__.py | 0 third_party/sunrise/backend/compiler.py | 353 ++ third_party/sunrise/backend/driver.c | 170 + third_party/sunrise/backend/driver.py | 523 +++ third_party/sunrise/backend/include/ptml.h | 925 +++++ third_party/sunrise/backend/include/tang.h | 2321 +++++++++++ .../backend/include/tang_compiler_api.h | 223 ++ .../backend/include/tang_rt/device_assert.h | 43 + .../include/tang_rt/device_functions.h | 8 + .../backend/include/tang_rt/driver_types.h | 1281 +++++++ .../sunrise/backend/include/tang_rt/fmt.hpp | 1097 ++++++ .../backend/include/tang_rt/host_defines.h | 101 + .../backend/include/tang_rt/vector_types.h | 35 + .../sunrise/backend/include/tang_rt/version.h | 30 + .../sunrise/backend/include/tang_runtime.h | 32 + .../backend/include/tang_runtime_api.h | 1871 +++++++++ .../sunrise/backend/include/tapti/tapti.h | 11 + .../backend/include/tapti/tapti_activity.h | 956 +++++ .../backend/include/tapti/tapti_callbacks.h | 109 + .../backend/include/tapti/tapti_driver_cbid.h | 203 + .../backend/include/tapti/tapti_result.h | 248 ++ .../include/tapti/tapti_runtime_cbid.h | 147 + .../backend/include/tapti/tapti_version.h | 39 + third_party/sunrise/backend/spec/__init__.py | 0 .../backend/spec/include/flagtree_spec.h | 12 + .../triton/Dialect/TritonGPU/CMakeLists.txt | 2 + .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 1447 +++++++ .../Dialect/TritonGPU/IR/sunrise_Dialect.h | 10 + .../IR/sunrise_LinearLayoutConversions.h | 10 + .../Pipeliner/sunrise_PipeliningUtility.h | 8 + .../TritonGPU/Transforms/sunrise_Coalesce.h | 6 + .../TritonGPU/Transforms/sunrise_Prefetch.h | 6 + .../sunrise_RemoveLayoutConversions.h | 8 + .../TritonGPU/Transforms/sunrise_Utility.h | 8 + .../sunrise/backend/spec/lib/CMakeLists.txt | 2 + .../spec/lib/Conversion/CMakeLists.txt | 2 + .../TritonGPUToLLVM/AllocateSharedMemory.cpp | 49 + .../TritonGPUToLLVM/AllocateWarpGroups.cpp | 200 + .../TritonGPUToLLVM/AssertOpToLLVM.cpp | 103 + .../Conversion/TritonGPUToLLVM/CMakeLists.txt | 39 + .../TritonGPUToLLVM/ControlFlowOpToLLVM.cpp | 162 + .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 463 +++ .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 38 + .../DotOpToLLVM/FMADotUtility.cpp | 170 + .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 665 ++++ .../TritonGPUToLLVM/FuncOpToLLVM.cpp | 213 ++ .../TritonGPUToLLVM/GatherOpToLLVM.cpp | 350 ++ .../GlobalScratchMemoryAllocation.cpp | 103 + .../TritonGPUToLLVM/HistogramOpToLLVM.cpp | 227 ++ .../TritonGPUToLLVM/MakeRangeOpToLLVM.cpp | 54 + .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 203 + .../TritonGPUToLLVM/PrintOpToLLVM.cpp | 244 ++ .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 391 ++ .../TritonGPUToLLVM/ReduceScanCommon.h | 163 + .../TritonGPUToLLVM/SPMDOpToLLVM.cpp | 38 + .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 573 +++ .../TritonGPUToLLVM/TypeConverter.cpp | 77 + .../Conversion/TritonGPUToLLVM/Utility.cpp | 1301 +++++++ .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 522 +++ .../TritonToTritonGPU/CMakeLists.txt | 16 + .../TritonToTritonGPU/RelayoutTritonGPU.cpp | 130 + .../TritonToTritonGPU/TritonGPUConversion.cpp | 184 + .../TritonToTritonGPUPass.cpp | 821 ++++ .../backend/spec/lib/Dialect/CMakeLists.txt | 2 + .../spec/lib/Dialect/Triton/CMakeLists.txt | 2 + .../spec/lib/Dialect/Triton/IR/CMakeLists.txt | 22 + .../lib/Dialect/Triton/IR/Canonicalize.td | 17 + .../spec/lib/Dialect/Triton/IR/Dialect.cpp | 79 + .../lib/Dialect/Triton/IR/OpInterfaces.cpp | 77 + .../spec/lib/Dialect/Triton/IR/Ops.cpp | 1368 +++++++ .../spec/lib/Dialect/Triton/IR/Traits.cpp | 217 ++ .../spec/lib/Dialect/Triton/IR/Types.cpp | 142 + .../spec/lib/Dialect/Triton/IR/Utility.cpp | 119 + .../Triton/Transforms/ArithTypeConversion.cpp | 50 + .../Dialect/Triton/Transforms/CMakeLists.txt | 27 + .../lib/Dialect/Triton/Transforms/Combine.cpp | 268 ++ .../lib/Dialect/Triton/Transforms/Combine.td | 47 + .../Transforms/FunctionTypeConversion.cpp | 86 + .../Triton/Transforms/LoopAwareCSE.cpp | 176 + .../Transforms/LoopInvariantCodeMotion.cpp | 83 + .../Dialect/Triton/Transforms/LoopPeeling.cpp | 68 + .../Dialect/Triton/Transforms/LoopUnroll.cpp | 62 + .../Triton/Transforms/ReorderBroadcast.cpp | 232 ++ .../RewriteTensorDescriptorToPointer.cpp | 508 +++ .../Transforms/RewriteTensorPointer.cpp | 614 +++ .../spec/lib/Dialect/TritonGPU/CMakeLists.txt | 2 + .../lib/Dialect/TritonGPU/IR/CMakeLists.txt | 14 + .../spec/lib/Dialect/TritonGPU/IR/Dialect.cpp | 3391 +++++++++++++++++ .../TritonGPU/IR/LinearLayoutConversions.cpp | 1958 ++++++++++ .../TritonGPU/Transforms/CMakeLists.txt | 13 + .../Dialect/TritonGPU/Transforms/Coalesce.cpp | 198 + .../Pipeliner/PipeliningUtility.cpp | 760 ++++ .../Dialect/TritonGPU/Transforms/Prefetch.cpp | 468 +++ .../Transforms/RemoveLayoutConversions.cpp | 1689 ++++++++ .../Dialect/TritonGPU/Transforms/Utility.cpp | 1590 ++++++++ third_party/sunrise/python/src/gluon_ir.cc | 458 +++ third_party/sunrise/python/src/interpreter.cc | 740 ++++ third_party/sunrise/python/src/ir.cc | 1888 +++++++++ third_party/sunrise/python/src/ir.h | 105 + third_party/sunrise/python/src/llvm.cc | 530 +++ third_party/sunrise/python/src/main.cc | 57 + third_party/sunrise/python/src/passes.cc | 117 + third_party/sunrise/python/src/passes.h | 38 + .../python/test_examples/01-vector-add.py | 135 + third_party/sunrise/python/triton/__init__.py | 74 + .../sunrise/python/triton/_filecheck.py | 87 + third_party/sunrise/python/triton/_utils.py | 124 + .../python/triton/compiler/__init__.py | 4 + .../python/triton/compiler/code_generator.py | 1507 ++++++++ .../python/triton/compiler/compiler.py | 526 +++ .../sunrise/python/triton/compiler/errors.py | 51 + .../python/triton/compiler/make_launcher.py | 0 third_party/sunrise/python/triton/errors.py | 5 + third_party/sunrise/python/triton/knobs.py | 473 +++ .../python/triton/language/__init__.py | 336 ++ .../sunrise/python/triton/language/core.py | 3325 ++++++++++++++++ .../python/triton/language/extra/__init__.py | 26 + .../sunrise/python/triton/language/extra/cuda | 1 + .../sunrise/python/triton/language/extra/hip | 1 + .../python/triton/language/extra/libdevice.py | 786 ++++ .../sunrise/python/triton/language/extra/tang | 1 + .../sunrise/python/triton/language/math.py | 249 ++ .../sunrise/python/triton/language/random.py | 218 ++ .../python/triton/language/semantic.py | 1886 +++++++++ .../python/triton/language/standard.py | 535 +++ .../sunrise/python/triton/runtime/__init__.py | 23 + .../python/triton/runtime/_allocation.py | 32 + .../python/triton/runtime/autotuner.py | 483 +++ .../sunrise/python/triton/runtime/build.py | 92 + .../sunrise/python/triton/runtime/cache.py | 266 ++ .../sunrise/python/triton/runtime/driver.py | 63 + .../sunrise/python/triton/runtime/errors.py | 36 + .../python/triton/runtime/interpreter.py | 1406 +++++++ .../sunrise/python/triton/runtime/jit.py | 949 +++++ third_party/sunrise/python/triton/testing.py | 543 +++ .../sunrise/python/triton/tools/__init__.py | 0 .../python/triton/tools/build_extern.py | 365 ++ .../sunrise/python/triton/tools/compile.py | 162 + .../sunrise/python/triton/tools/disasm.py | 143 + .../sunrise/python/triton/tools/extra/cuda | 1 + .../sunrise/python/triton/tools/link.py | 322 ++ .../sunrise/python/triton/tools/mxfp.py | 301 ++ .../python/triton/tools/tensor_descriptor.py | 36 + third_party/triton_shared | 1 + 168 files changed, 53868 insertions(+), 23 deletions(-) create mode 100644 third_party/sunrise/CMakeLists.txt create mode 100644 third_party/sunrise/backend/__init__.py create mode 100644 third_party/sunrise/backend/compiler.py create mode 100644 third_party/sunrise/backend/driver.c create mode 100644 third_party/sunrise/backend/driver.py create mode 100755 third_party/sunrise/backend/include/ptml.h create mode 100755 third_party/sunrise/backend/include/tang.h create mode 100755 third_party/sunrise/backend/include/tang_compiler_api.h create mode 100755 third_party/sunrise/backend/include/tang_rt/device_assert.h create mode 100755 third_party/sunrise/backend/include/tang_rt/device_functions.h create mode 100755 third_party/sunrise/backend/include/tang_rt/driver_types.h create mode 100755 third_party/sunrise/backend/include/tang_rt/fmt.hpp create mode 100755 third_party/sunrise/backend/include/tang_rt/host_defines.h create mode 100755 third_party/sunrise/backend/include/tang_rt/vector_types.h create mode 100755 third_party/sunrise/backend/include/tang_rt/version.h create mode 100755 third_party/sunrise/backend/include/tang_runtime.h create mode 100755 third_party/sunrise/backend/include/tang_runtime_api.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_activity.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_callbacks.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_driver_cbid.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_result.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_runtime_cbid.h create mode 100755 third_party/sunrise/backend/include/tapti/tapti_version.h create mode 100644 third_party/sunrise/backend/spec/__init__.py create mode 100644 third_party/sunrise/backend/spec/include/flagtree_spec.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_Dialect.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_LinearLayoutConversions.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/Pipeliner/sunrise_PipeliningUtility.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Coalesce.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Prefetch.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_RemoveLayoutConversions.h create mode 100644 third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Utility.h create mode 100644 third_party/sunrise/backend/spec/lib/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/Utility.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Canonicalize.td create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/OpInterfaces.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Ops.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Traits.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Types.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Utility.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.td create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopPeeling.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopUnroll.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/Dialect.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp create mode 100644 third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Utility.cpp create mode 100644 third_party/sunrise/python/src/gluon_ir.cc create mode 100644 third_party/sunrise/python/src/interpreter.cc create mode 100644 third_party/sunrise/python/src/ir.cc create mode 100644 third_party/sunrise/python/src/ir.h create mode 100644 third_party/sunrise/python/src/llvm.cc create mode 100644 third_party/sunrise/python/src/main.cc create mode 100644 third_party/sunrise/python/src/passes.cc create mode 100644 third_party/sunrise/python/src/passes.h create mode 100644 third_party/sunrise/python/test_examples/01-vector-add.py create mode 100644 third_party/sunrise/python/triton/__init__.py create mode 100644 third_party/sunrise/python/triton/_filecheck.py create mode 100644 third_party/sunrise/python/triton/_utils.py create mode 100644 third_party/sunrise/python/triton/compiler/__init__.py create mode 100644 third_party/sunrise/python/triton/compiler/code_generator.py create mode 100644 third_party/sunrise/python/triton/compiler/compiler.py create mode 100644 third_party/sunrise/python/triton/compiler/errors.py create mode 100644 third_party/sunrise/python/triton/compiler/make_launcher.py create mode 100644 third_party/sunrise/python/triton/errors.py create mode 100644 third_party/sunrise/python/triton/knobs.py create mode 100644 third_party/sunrise/python/triton/language/__init__.py create mode 100644 third_party/sunrise/python/triton/language/core.py create mode 100644 third_party/sunrise/python/triton/language/extra/__init__.py create mode 120000 third_party/sunrise/python/triton/language/extra/cuda create mode 120000 third_party/sunrise/python/triton/language/extra/hip create mode 100644 third_party/sunrise/python/triton/language/extra/libdevice.py create mode 120000 third_party/sunrise/python/triton/language/extra/tang create mode 100644 third_party/sunrise/python/triton/language/math.py create mode 100644 third_party/sunrise/python/triton/language/random.py create mode 100644 third_party/sunrise/python/triton/language/semantic.py create mode 100644 third_party/sunrise/python/triton/language/standard.py create mode 100644 third_party/sunrise/python/triton/runtime/__init__.py create mode 100644 third_party/sunrise/python/triton/runtime/_allocation.py create mode 100644 third_party/sunrise/python/triton/runtime/autotuner.py create mode 100644 third_party/sunrise/python/triton/runtime/build.py create mode 100644 third_party/sunrise/python/triton/runtime/cache.py create mode 100644 third_party/sunrise/python/triton/runtime/driver.py create mode 100644 third_party/sunrise/python/triton/runtime/errors.py create mode 100644 third_party/sunrise/python/triton/runtime/interpreter.py create mode 100644 third_party/sunrise/python/triton/runtime/jit.py create mode 100644 third_party/sunrise/python/triton/testing.py create mode 100644 third_party/sunrise/python/triton/tools/__init__.py create mode 100644 third_party/sunrise/python/triton/tools/build_extern.py create mode 100644 third_party/sunrise/python/triton/tools/compile.py create mode 100644 third_party/sunrise/python/triton/tools/disasm.py create mode 120000 third_party/sunrise/python/triton/tools/extra/cuda create mode 100644 third_party/sunrise/python/triton/tools/link.py create mode 100644 third_party/sunrise/python/triton/tools/mxfp.py create mode 100644 third_party/sunrise/python/triton/tools/tensor_descriptor.py create mode 160000 third_party/triton_shared diff --git a/.gitignore b/.gitignore index fe103b6ed..bb7ed373f 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,7 @@ ptxas # Third-party include third_party/nvidia/backend/include third_party/nvidia/backend/lib/cupti +third_party/sunrise/backend/lib # Docs docs/_build/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a51028ea..ff6d08133 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,12 +38,43 @@ elseif(FLAGTREE_BACKEND STREQUAL "aipu") elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") set(CMAKE_C_COMPILER clang) set(CMAKE_CXX_COMPILER clang++) +elseif(FLAGTREE_BACKEND STREQUAL "sunrise") + # remove_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) + # add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + find_package(Python3 3.10 REQUIRED COMPONENTS Development.Module Interpreter) + if(EDITABLE_MODE) + set (DEFAULT_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/sunrise") + else() + set (DEFAULT_PLUGIN_DIR "${Python3_SITELIB}/triton/_C") + endif() + add_definitions(-DDEFAULT_PLUGIN_DIR="${DEFAULT_PLUGIN_DIR}") endif() set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}") if(FLAGTREE_PLUGIN) add_definitions(-D__FLAGTREE_PLUGIN__) endif() +# FLAGTREE SPEC LIB GET FUNC +function(get_flagtree_backend_lib lib_name output_lib) + set(ret FlagTree_${FLAGTREE_BACKEND}_${lib_name}) + if(NOT TARGET ${ret}) + set(ret "") + endif() + set(${output_lib} ${ret} PARENT_SCOPE) +endfunction() + +# FLAGTREE SPEC TD FILE GET FUNC +function(set_flagtree_backend_td output_td td_filename) + set(ret ${td_filename}) + file(RELATIVE_PATH relative_path "${PROJECT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") + get_filename_component(BACKEND_SPEC_ROOT "${BACKEND_SPEC_INCLUDE_DIR}" DIRECTORY) + set(BACKEND_SPEC_TD ${BACKEND_SPEC_ROOT}/${relative_path}/${td_filename}) + if(EXISTS ${BACKEND_SPEC_TD}) + set(ret ${BACKEND_SPEC_TD}) + endif() + set(${output_td} ${ret} PARENT_SCOPE) +endfunction() + project(triton CXX C) include(CTest) @@ -119,12 +150,20 @@ if(TRITON_BUILD_UT) endif() # Compiler flags -set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include) +set(FLAGTREE_BACKEND_DIR ${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}) +## flagtree spec include dir +set(BACKEND_SPEC_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/backend/spec/include) +if(FLAGTREE_BACKEND AND EXISTS ${BACKEND_SPEC_INCLUDE_DIR}) + include_directories(${BACKEND_SPEC_INCLUDE_DIR}) +endif() +## flagtree third_party include dir +set(BACKEND_INCLUDE_DIR ${FLAGTREE_BACKEND_DIR}/include) if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}") include_directories(${BACKEND_INCLUDE_DIR}) else() include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) endif() + if(NOT MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") else() @@ -378,6 +417,53 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMRISCVCodeGen LLVMRISCVAsmParser ) + elseif(FLAGTREE_BACKEND STREQUAL "sunrise") + set(TRITON_LIBRARIES + ${triton_libs} + ${triton_plugins} + + # mlir + # MLIRAMDGPUDialect + # MLIRNVVMDialect + MLIRSTVMDialect # STVM + MLIRNVVMToLLVMIRTranslation + MLIRSTVMToLLVMIRTranslation + MLIRGPUToNVVMTransforms + MLIRGPUToSTVMTransforms + MLIRGPUToGPURuntimeTransforms + MLIRGPUTransforms + MLIRIR + MLIRControlFlowToLLVM + MLIRBytecodeWriter + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRMathToLLVM + # MLIRROCDLToLLVMIRTranslation + MLIRGPUDialect + MLIRSCFToControlFlow + MLIRIndexToLLVM + MLIRGPUToROCDLTransforms + MLIRUBToLLVM + + # LLVM + LLVMPasses + # LLVMNVPTXCodeGen + # LLVMAMDGPUCodeGen + # LLVMAMDGPUAsmParser + LLVMSTCUCodeGen + LLVMSTCUAsmParser + LLVMAArch64CodeGen + LLVMAArch64AsmParser + LLVMRISCVCodeGen + LLVMRISCVAsmParser + + Python3::Module + pybind11::headers + ) + endif() if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 @@ -424,7 +510,8 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() # Link triton with its dependencies - target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) + #target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) + target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) set_target_properties(triton PROPERTIES SUFFIX ".pyd") @@ -450,7 +537,8 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() if (UNIX AND NOT APPLE) - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") + #set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--export-dynamic") endif() if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 04ca702fc..cbba806f5 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -472,6 +472,9 @@ using ::mlir::triton::gpu::CTALayoutAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; +#ifdef FLAGTREE_SPEC_BackendMmaEncodingAttr +using FLAGTREE_SPEC_BackendMmaEncodingAttr; +#endif Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, ArrayRef strides); diff --git a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt index 99d90c4d7..c648250d9 100644 --- a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,3 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) +#set(LLVM_TARGET_DEFINITIONS Passes.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index fecd5adf6..669a88dab 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -1,6 +1,13 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) -set(LLVM_TARGET_DEFINITIONS TritonOps.td) +# set(LLVM_TARGET_DEFINITIONS TritonOps.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonOps.td) +# mlir_tablegen(Ops.h.inc -gen-op-decls) +# mlir_tablegen(Ops.cpp.inc -gen-op-defs) +# mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +# mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) @@ -12,7 +19,8 @@ mlir_tablegen(Dialect.h.inc -gen-dialect-decls) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) -set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +#set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonTypes.td) mlir_tablegen(Types.h.inc -gen-typedef-decls) mlir_tablegen(Types.cpp.inc -gen-typedef-defs) diff --git a/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/include/triton/Dialect/Triton/Transforms/CMakeLists.txt index 372a9ec11..996f71254 100644 --- a/include/triton/Dialect/Triton/Transforms/CMakeLists.txt +++ b/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -1,3 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) +#set(LLVM_TARGET_DEFINITIONS Passes.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) add_public_tablegen_target(TritonTransformsIncGen) diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index a211c7bc8..9d0880983 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,6 +1,7 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) -set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +# set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) mlir_tablegen(Ops.h.inc -gen-op-decls) @@ -11,7 +12,7 @@ add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonGPUTableGen) -set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls) @@ -21,6 +22,7 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(TritonGPUAttrDefsIncGen) set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td) +#set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td) mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(TritonGPUTypeInterfacesIncGen) diff --git a/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt index 6be94d1a8..ce67dca90 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,3 +1,4 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) +#set(LLVM_TARGET_DEFINITIONS Passes.td) +set_flagtree_backend_td(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 693d222f2..ae1c60067 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,3 +1,11 @@ +if (FLAGTREE_BACKEND) + set(NVGPUIR "") +else() + set(NVGPUIR "TritonNvidiaGPUIR") +endif() + +get_flagtree_backend_lib("TritonAnalysis" _EXTRA_LINK_LIBS) + add_triton_library(TritonAnalysis AxisInfo.cpp Allocation.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 2e2412025..081870fd3 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,3 +1,13 @@ +if (FLAGTREE_BACKEND) + set(NVGPUIR "") + set(NVGPUTransforms "") +else() + set(NVGPUIR "NVGPUIR") + set(NVGPUTransforms "TritonNvidiaGPUTransforms") +endif() + +get_flagtree_backend_lib("TritonGPUToLLVM" _EXTRA_LINK_LIBS) + add_triton_library(TritonGPUToLLVM DotOpToLLVM/FMA.cpp DotOpToLLVM/FMADotUtility.cpp diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index ed879c7dd..b8a2a1297 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,3 +1,5 @@ +get_flagtree_backend_lib("TritonToTritonGPU" _EXTRA_LINK_LIBS) + add_triton_library(TritonToTritonGPU RelayoutTritonGPU.cpp TritonGPUConversion.cpp diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index be1ee4ae8..855a7162d 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +get_flagtree_backend_lib("TritonGPUIR" _EXTRA_LINK_LIBS) + add_triton_library(TritonGPUIR Dialect.cpp LinearLayoutConversions.cpp diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d573daddd..37af4c6a4 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_cpp + #include "triton/Dialect/Triton/IR/Dialect.h" #include @@ -3215,3 +3221,5 @@ LinearLayout triton::gpu::inferReshapeLinearLayout(ArrayRef srcShape, auto dst = reshapeLayout(ctx, src, dstShape); return dst; } + +#endif// FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_cpp diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 0ac56a8a7..038a384fc 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_Triton_Dialect_TritonGPU_IR_sunrise_LinearLayoutConversion + #include #include "triton/Dialect/Triton/IR/Dialect.h" @@ -1820,3 +1826,5 @@ LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType, } } // namespace mlir::triton::gpu + +#endif//FLAGTREE_SPEC_Triton_Dialect_TritonGPU_IR_sunrise_LinearLayoutConversion diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 0fa0b324a..4d82119d6 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,3 +1,12 @@ +if(IS_COMPILE_TritonNvidiaGPU) + set(_TMA_LINK_CPP "Pipeliner/TMAStoresPipeline.cpp") + set(_NVIDIA_LINK_LIBS "TritonNvidiaGPUIR") +else() + set(_TMA_LINK_CPP "") + set(_NVIDIA_LINK_LIBS "") +endif() +get_flagtree_backend_lib("TritonGPUTransforms" _EXTRA_LINK_LIBS) + add_triton_library(TritonGPUTransforms AccelerateMatmul.cpp Canonicalize.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index c9545f043..deb2f5110 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_Triton_Dialect_TritonGPU_Transforms_Sunrise_Coalesce + #include #include @@ -193,3 +199,5 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { } // namespace gpu } // namespace triton } // namespace mlir + +#endif//FLAGTREE_SPEC_Triton_Dialect_TritonGPU_Transforms_Sunrise_Coalesce diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index eb4a4067b..f17679b14 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility + #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -743,3 +749,5 @@ scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp, } return forOp; } + +#endif//FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 248871ebd..62a30d646 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -25,6 +25,11 @@ // scf.yield %next_a, ..., %a_prefetch_next // } //===----------------------------------------------------------------------===// +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Prefetch #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" @@ -459,3 +464,5 @@ struct PrefetchPass : public impl::TritonGPUPrefetchBase { } // namespace gpu } // namespace triton } // namespace mlir + +#endif// FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Prefetch diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index a7ab9f8b4..acbb03375 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_RemoveLayoutConversion + #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" @@ -1675,3 +1681,5 @@ class TritonGPURemoveLayoutConversionsPass }; } // namespace mlir::triton::gpu + +#endif// #ifndef FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_RemoveLayoutConversion diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index e08236274..6e9f49366 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1,3 +1,9 @@ +#if __has_include("flagtree_spec.h") +#include "flagtree_spec.h" +#endif + +#ifndef FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Utility + #include "triton/Analysis/Utility.h" #include @@ -1584,3 +1590,5 @@ bool comesFromLoadOrBlockArg(Value v) { } } // namespace mlir::triton + +#endif//FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Utility diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index f2f9adf8f..c3a0010b8 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,3 +1,5 @@ +get_flagtree_backend_lib("TritonLLVMIR" _EXTRA_LINK_LIBS) + add_triton_library(TritonLLVMIR LLVMDIScope.cpp LLVMIRBreakPhiStruct.cpp diff --git a/setup_tools/setup_helper.py b/setup_tools/setup_helper.py index 71145a0fe..f79989c93 100644 --- a/setup_tools/setup_helper.py +++ b/setup_tools/setup_helper.py @@ -14,7 +14,7 @@ flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF") -device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend", "sunrise":"sunrise"} activated_module = utils.activate(flagtree_backend) downloader = utils.tools.DownloadManager() diff --git a/third_party/amd/CMakeLists.txt b/third_party/amd/CMakeLists.txt index b030dbbd1..a2f2f7c87 100644 --- a/third_party/amd/CMakeLists.txt +++ b/third_party/amd/CMakeLists.txt @@ -1,12 +1,12 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -add_subdirectory(include) -add_subdirectory(lib) -if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) - target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers) -endif() -if(TRITON_BUILD_UT) - add_subdirectory(unittest) -endif() -add_subdirectory(test) +# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +# include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +# add_subdirectory(include) +# add_subdirectory(lib) +# if(TRITON_BUILD_PYTHON_MODULE) +# add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) +# target_link_libraries(TritonAMD PRIVATE Python3::Module pybind11::headers) +# endif() +# if(TRITON_BUILD_UT) +# add_subdirectory(unittest) +# endif() +# add_subdirectory(test) diff --git a/third_party/sunrise/CMakeLists.txt b/third_party/sunrise/CMakeLists.txt new file mode 100644 index 000000000..d7f0e11e3 --- /dev/null +++ b/third_party/sunrise/CMakeLists.txt @@ -0,0 +1,50 @@ +add_compile_options("-Wno-deprecated-declarations") +add_compile_options("-Wno-error=deprecated-declarations") + +option(EDITABLE_MODE "Build in developer (editable) mode" OFF) +if(FLAGTREE_PLUGIN) + set(SUNRISE_PLUGIN_DIR "${Python3_SITELIB}/triton/_C") +elseif(EDITABLE_MODE) + set(SUNRISE_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +else() + set(SUNRISE_PLUGIN_DIR "${Python3_SITELIB}/triton/_C") +endif() + +if(TRITON_BUILD_PYTHON_MODULE) + if(FLAGTREE_PLUGIN) + add_subdirectory(plugin) + add_triton_plugin(TritonSunrise + SHARED_LIB sunriseTritonPlugin + ) + else() + if(EDITABLE_MODE) + find_library(sunriseTritonPluginLib + NAMES + sunriseTritonPlugin.so + PATHS + ${SUNRISE_PLUGIN_DIR} + REQUIRED + ) + add_triton_plugin(TritonSunrise + SHARED_LIB ${sunriseTritonPluginLib} + ) + else() + find_library(sunriseTritonPluginLib + NAMES + sunriseTritonPlugin.so + PATHS + ${SUNRISE_PLUGIN_DIR} + REQUIRED + ) + add_triton_plugin(TritonSunrise + SHARED_LIB ${sunriseTritonPluginLib} + ) + endif() + endif() +endif() + +add_subdirectory(backend/spec/lib) +add_subdirectory(${PROJECT_SOURCE_DIR}/include + ${PROJECT_BINARY_DIR}/include) +add_subdirectory(${PROJECT_SOURCE_DIR}/lib + ${PROJECT_BINARY_DIR}/lib) diff --git a/third_party/sunrise/backend/__init__.py b/third_party/sunrise/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/backend/compiler.py b/third_party/sunrise/backend/compiler.py new file mode 100644 index 000000000..51f451754 --- /dev/null +++ b/third_party/sunrise/backend/compiler.py @@ -0,0 +1,353 @@ +from triton.backends.compiler import BaseBackend, GPUTarget, Language +from triton._C.libtriton import ir, passes, llvm, sunrise +from triton import knobs +from dataclasses import dataclass +import functools +from typing import Any, Dict, Tuple +from types import ModuleType +import hashlib +import platform +import re +import tempfile +import os +import subprocess +from pathlib import Path + +def min_dot_size(target: GPUTarget): + return lambda lhsType, rhsType: (8, 8, 16) if lhsType.is_int8() else (8, 8, 4) + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + +@dataclass(frozen=True) +class SunriseOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + cluster_dims: tuple = (1, 1, 1) + enable_fp_fusion: bool = True + supported_fp8_dtypes: Tuple[str] = ("fp8e5", ) + deprecated_fp8_dot_operand_dtypes: Tuple[str] = () + default_dot_input_precision: str = "ieee" + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + backend_name: str = 'tang' + sanitize_overflow: bool = True + arch: str = None + + # 当前s2上没有响应的libdivice库,需要怎么编译出来?? + def __post_init__(self): + warp_size = 32 + object.__setattr__(self, 'warp_size', warp_size) + default_libdir = Path(__file__).parent / 'lib' + extern_libs ={} if self.extern_libs is None else dict(self.extern_libs) + for lib in ["ocml", "ockl"]: + extern_libs[lib] = str(default_libdir / f'{lib}.bc') + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + hash_dict = dict(self.__dict__) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class SunriseBackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'tang' + + def get_target_name(self, options) -> str: + return f"tang:{options.arch}" # tang:S2 + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.binary_ext = 'stcu' + + def parse_options(self, opts) -> Any: + args = {'arch': knobs.runtime.override_arch or self.target.arch} + if "enable_fp_fusion" not in opts: + args["enable_fp_fusion"] = knobs.language.default_fp_fusion + args["max_num_imprecise_acc_default"] = 0 # TODO + args.update({k: opts[k] for k in SunriseOptions.__dataclass_fields__.keys() \ + if k in opts and opts[k] is not None}) + return SunriseOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self, options): + codegen_fns = { + "min_dot_size": min_dot_size(self.target) + } + return codegen_fns + + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.tang import libdevice + return {"triton.language.extra.libdevice": libdevice} + + def load_dialects(self, ctx): + sunrise.load_dialects(ctx) + + def path_to_clang_offload_bundler(): + lld_env_path = knobs.sunrise.lld_path + if lld_env_path is not None: + lld = Path(lld_env_path) + if lld.is_file(): + return lld + arch = platform.machine() + lld = Path(f"/usr/local/tangrt/toolchains/llvm/prebuilt/linux-{arch}/bin/clang-offload-bundler") + if lld.is_file(): + return lld + raise Exception("clang-offload-bundler not found. Set 'TRITON_SUNRISE_LLD_PATH' to its path.") + + @staticmethod + def get_triple(): + triple = knobs.sunrise.triple + if triple is None or triple == '': + return "stcu-unknown-tang" + return triple + + @staticmethod + def get_flag(metadata, opt): + flag = knobs.sunrise.flag + if flag is None or flag == []: + flag = ['enable-predicate'] + if isinstance(flag, str): + flag = flag.split() + if metadata["num_warps"] > 16: + flag.append('thread-regfile-size=64') + for name, path in opt.extern_libs: + if name == "ockl": + flag.append('ocklPath='+path) + return flag + + @staticmethod + def get_optimization_level(llvm): + opt = knobs.sunrise.opt_level + if int(opt) == 0: + return llvm.OPTIMIZE_O0 + if int(opt) == 1: + return llvm.OPTIMIZE_O1 + if int(opt) == 2: + return llvm.OPTIMIZE_O2 + return llvm.OPTIMIZE_O3 + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + num_stages = opt.num_stages if opt.num_stages <= 3 else 3 + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"tang:{capability}", opt.num_warps, 32, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + # nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + sunrise.passes.ttgpuir.add_combine_optimize(pm) + if os.getenv('OFF_MMA', '0') == '1': + print('not run accelerate_matmul pass') + else: + sunrise.passes.ttgpuir.add_accelerate_matmul(pm, 1, 0) # 版本:1.0 + sunrise.passes.ttgpuir.add_mma_direct_store(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + # passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.common.add_cse(pm) + if os.getenv('DFT_PP', '0') == '1': + if os.getenv('OFF_ASYNC', '0') == '0': + passes.ttgpuir.add_assign_latencies(pm, num_stages) + passes.ttgpuir.add_schedule_loops(pm) + passes.ttgpuir.add_pipeline(pm, num_stages, False ) + if os.getenv('OFF_PREF', '0') == '0': + passes.ttir.add_loop_aware_cse(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_prefetch(pm) + else: + if os.getenv('OFF_ASYNC', '0') == '0': + passes.ttgpuir.add_assign_latencies(pm, num_stages) + passes.ttgpuir.add_schedule_loops(pm) + sunrise.passes.ttgpuir.add_pipeline(pm, num_stages, 1, 0) # 版本:1.0 + if os.getenv('OFF_PREF', '0') == '0': + passes.ttir.add_loop_aware_cse(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_prefetch(pm) + # passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.getenv('K_OUTER', '0') == '1': + print('not run split_dot pass because K_OUTER == 1') + else: + sunrise.passes.ttgpuir.add_split_dot(pm, 1, 0) + # if capability // 10 >= 9: + # nvidia.passes.ttnvgpuir.add_fence_insertion(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + def ttgir_opt(self, src, metadata, options, capability): + mod = src + pm = ir.pass_manager(mod.context) + pm.enable_debug() + + passes.ttgpuir.add_inliner(pm) + passes.common.add_sccp(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_canonicalizer(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + + pm.run(mod) + # metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() + return mod + + @staticmethod + def make_llir(src, metadata, options, capability): + mod = src + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + sunrise.passes.ttgpuir.add_to_llvmir(pm, capability) + sunrise.passes.ttgpuir.add_remove_repeated_fence(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_cf_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if not knobs.compilation.disable_line_info: + passes.llvmir.add_di_scope(pm) + pm.run(mod) + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + llvm.attach_datalayout(llvm_mod, 'stcu-unknown-tang', '', '') + # sunrise.set_nvvm_reflect_ftz(llvm_mod) # 属性设置,可以不需要 + if options.extern_libs: + for name, path in options.extern_libs: + if name == "ocml": + llvm.link_extern_libs(llvm_mod, [path]) + # if name == "ockl": + # llvm.link_override_lib(llvm_mod, path) + llvm.optimize_module(llvm_mod, SunriseBackend.get_optimization_level(llvm)) + + # Get some metadata + total_num_warps = src.get_int_attr("ttg.total-num-warps") + if total_num_warps is not None: + metadata["num_warps"] = total_num_warps + metadata["shared"] = src.get_int_attr("ttg.shared") + ret = str(llvm_mod) + ret = ret.replace("define void @", "define dso_local cc200 void @") + if knobs.sunrise.dump_stcu: + with open('sunrise.ll', 'w') as f: + f.write(ret) + del llvm_mod + del context + return ret + + @staticmethod + def make_stcu(src, metadata, opt, capability): + names = re.findall(r"define dso_local cc200 void @([a-zA-Z_][a-zA-Z0-9_]*)", src) + + assert len(names) == 1 + metadata["name"] = names[0] + proc = '' + + triple = SunriseBackend.get_triple() + flag = SunriseBackend.get_flag(metadata, opt) + if knobs.sunrise.dump_stcu: + asm_debug = llvm.translate_to_asm(src, triple, proc, '', flag, opt.enable_fp_fusion, + False) + with open('sunrise.asm', 'w') as f: + f.write(asm_debug) + + asm = llvm.translate_to_asm(src, triple, proc, '', flag, opt.enable_fp_fusion, True) + if knobs.sunrise.dump_stcu: + with open('sunrise.elf', 'wb') as f: + f.write(asm) + + bundler = SunriseBackend.path_to_clang_offload_bundler() + + major = 0 + try: + output = subprocess.check_output([bundler, "--version"], stderr=subprocess.STDOUT) + version_str = output.decode("utf-8").strip() + match = re.search(r"version\s+(\d+)\.(\d+)\.(\d+)", version_str) + if match: + major = int(match.group(1)) + else: + print("Cannot parse clang-offload-bundler version\n") + except Exception as e: + print("Error getting version:", e) + + arch = platform.machine() + + with tempfile.NamedTemporaryFile() as tmp_out: + with tempfile.NamedTemporaryFile() as tmp_in: + with open(tmp_in.name, 'wb') as fd_in: + fd_in.write(asm) + try: + cmd = f'{bundler} -type=o -targets=host-{arch}-unknown-linux,tang-stpu-unknown-tang -input=/dev/null -input={tmp_in.name} -output={tmp_out.name}' + subprocess.run(cmd, shell=True, check=True) + except subprocess.CalledProcessError as e: + print(" run error\n") + + with open(tmp_out.name, 'rb') as fd_out: + ret = fd_out.read() + return ret + + def add_stages(self, stages, options, language): + capability = 80 # options.arch + if language == Language.TRITON: + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) + elif language == Language.GLUON: + stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options, capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) + stages["stcu"] = lambda src, metadata: self.make_stcu(src, metadata, options, capability) + + @functools.lru_cache() + def hash(self): + version = subprocess.check_output([SunriseBackend.path_to_clang_offload_bundler(), "--version"], encoding='utf-8') + return f'{version}-{self.target.arch}' diff --git a/third_party/sunrise/backend/driver.c b/third_party/sunrise/backend/driver.c new file mode 100644 index 000000000..f04c0e9ce --- /dev/null +++ b/third_party/sunrise/backend/driver.c @@ -0,0 +1,170 @@ +#include "tang.h" +#include +#include +#define PY_SSIZE_T_CLEAN +#include + +// Raises a Python exception and returns false if code is not TANG_SUCCESS. +static bool gpuAssert(TAresult code, const char *file, int line) { + if (code == TANG_SUCCESS) + return true; + + const char *prefix = "Triton Error [TANG]: "; + const char *str; + taGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define TANG_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + TAdevice device; + taDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &max_shared_mem, TA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &max_num_regs, TA_DEV_ATTR_REGS_PER_BLOCK, device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &multiprocessor_count, TA_DEV_ATTR_MULTIPROCESSOR_COUNT, device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &warp_size, TA_DEV_ATTR_WARP_SIZE, device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &sm_clock_rate, TA_DEV_ATTR_CLOCK_RATE, device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &mem_clock_rate, TA_DEV_ATTR_MEMORY_CLOCK_RATE, device)); + TANG_CHECK_AND_RETURN_NULL(taDeviceGetAttribute( + &mem_bus_width, TA_DEV_ATTR_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + TAfunction fun; + TAmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + int32_t n_max_threads = 0; + // create driver handles + TAcontext pctx = 0; + TAdevice device_hd; + taDeviceGet(&device_hd, device); + + Py_BEGIN_ALLOW_THREADS; + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(taCtxGetCurrent(&pctx)); + if (!pctx) { + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + taDevicePrimaryCtxRetain(&pctx, device_hd)); + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(taCtxSetCurrent(pctx)); + } + + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(taModuleLoadData(&mod, data, (size_t)data_size)); + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + taModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + /* 不支持属性获取, 按照默认0处理 */ + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + // taFuncGetAttribute(&n_regs, TA_FUNC_ATTRIBUTE_NUM_REGS, fun)); + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + // taFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + // n_spills /= 4; + TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(taFuncGetAttribute( + &n_max_threads, TA_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun)); + // set dynamic shared memory if necessary + // int shared_optin; + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + // &shared_optin, TA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + // device)); + // if (shared > 49152 && shared_optin > 49152) { + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + // cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + // int shared_total, shared_static; + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + // &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + // device)); + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + // &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + // TANG_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + // cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + // shared_optin - shared_static)); + // } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills, n_max_threads); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into TANG driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "tang_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_tang_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/sunrise/backend/driver.py b/third_party/sunrise/backend/driver.py new file mode 100644 index 000000000..1008ed4a2 --- /dev/null +++ b/third_party/sunrise/backend/driver.py @@ -0,0 +1,523 @@ +import functools +import os +import platform +import subprocess +import re +from pathlib import Path +#from triton import knobs +from ..python.triton import knobs +from triton.runtime.build import compile_module_from_src +from triton.runtime import _allocation +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dirs = [os.path.join(dirname, "include")] +libdevice_dir = os.path.join(dirname, "lib") +libraries = ['tang', 'tangrt_shared'] +arch = platform.machine() + +@functools.lru_cache() +def libtang_dirs(): + if env_libtang_path := knobs.sunrise.libtang_path: + return [env_libtang_path] + + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libtang.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libtang.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if "libtang.so" in line] + dirs = [os.path.dirname(loc) for loc in locs] + env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + if env_ld_library_path and not dirs: + dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libtang.so"))] + if not dirs: + dirs = [f'/usr/local/tangrt/lib/linux-{arch}/stub/'] + msg = 'libtang.so cannot found!\n' + if locs: + msg += 'Possible files are located at %s.' % str(locs) + msg += 'Please create a symlink of libtang.so to any of the files.' + else: + msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"' + msg += ' (requires sudo) to refresh the linker cache.' + assert any(os.path.exists(os.path.join(path, 'libtang.so')) for path in dirs), msg + return dirs + + +@functools.lru_cache() +def library_dirs(): + return [libdevice_dir, *libtang_dirs(), f"/usr/local/tangrt/lib/linux-{arch}"] + + +# ------------------------ +# Utils +# ------------------------ + + +class SunriseUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(SunriseUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src( + src=Path(os.path.join(dirname, "driver.c")).read_text(), + name="tang_utils", + library_dirs=library_dirs(), + include_dirs=include_dirs, + libraries=libraries, + ) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "TAdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "double", + "bf16": "double", + "fp32": "double", + "f32": "double", + "fp64": "double", + }[ty] + +FLOAT_STORAGE_TYPE = { + "fp16": "uint16_t", + "bf16": "uint16_t", + "fp32": "uint32_t", + "f32": "uint32_t", + "fp64": "uint64_t", +} +FLOAT_PACK_FUNCTION = { + "fp16": "pack_fp16", + "bf16": "pack_bf16", + "fp32": "pack_fp32", + "f32": "pack_fp32", + "fp64": "pack_fp64", +} + +_BASE_ARGS_FORMAT = "iiiKKOOOO" + +def make_launcher(constants, signature, warp_size): + def _expand_signature(signature): + output = [] + # Expand tensor descriptor arguments into base pointer, shape, and + # strides + for sig in signature: + if isinstance(sig, str) and sig.startswith("tensordesc"): + ndim = sig.count(",") + 1 + dtype = re.match("tensordesc<([^[>]*)", sig).group() + + output.append("*" + dtype) + # Currently the host side tensor descriptors get passed in as a + # tensor desc, shape, and strides. We have no way to use these + # shape and strides when processing tensor descriptors which is + # why we provide our own decomposition above. Sadly this means + # we have to pass the shape and strides twice. + for _ in range(2 * ndim): + output.append("i64") + + for _ in range(ndim): + output.append("i32") + for _ in range(ndim): + output.append("i64") + else: + output.append(sig) + + return output + + def _flatten_signature(sig, output): + # Flatten tuples + if isinstance(sig, tuple): + for x in sig: + _flatten_signature(x, output) + else: + output.append(sig) + + def _extracted_type(ty): + if isinstance(ty, tuple): + val = ','.join(map(_extracted_type, ty)) + return f"[{val}]" + if ty[0] == '*': + return "PyObject*" + if ty == "constexpr": + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + if isinstance(ty, tuple): + val = ''.join(map(format_of, ty)) + return f"({val})" + if ty[0] == '*': + return "O" + if ty == "constexpr": + return "O" + return { + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "L", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty_to_cpp(ty)] + + expand_signature = _expand_signature(signature.values()) + signature = {i: s for i, s in enumerate(expand_signature)} + + args_format = ''.join([format_of(ty) for ty in signature.values()]) + format = _BASE_ARGS_FORMAT + args_format + + flat_signature = [] + for sig in signature.values(): + _flatten_signature(sig, flat_signature) + signature = {i: s for i, s in enumerate(flat_signature)} + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decl_list = [] + for i, ty in signature.items(): + if ty == "constexpr": + continue + if ty in FLOAT_STORAGE_TYPE: + arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}") + else: + arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}") + arg_decls = ', '.join(arg_decl_list) + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty in FLOAT_STORAGE_TYPE: + internal_args_list.append(f"_arg{i}_storage") + elif ty != "constexpr": + internal_args_list.append(f"_arg{i}") + + # generate glue code + newline = '\n ' + ptr_decls = [ + f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" + for i, ty in signature.items() + if ty[0] == "*" + ] + float_storage_decls = [ + f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});" + for i, ty in signature.items() + if ty in FLOAT_STORAGE_TYPE + ] + params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] + params.append("&global_scratch") + src = f""" +#include \"tang.h\" +#include \"tang_runtime.h\" +#include +#include +#include + +static inline void gpuAssert(TAresult code, const char *file, int line) +{{ + if (code != TANG_SUCCESS) + {{ + const char* prefix = "Triton Error [TANG]: "; + const char* str; + taGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define TANG_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, TAstream stream, TAfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + TAdeviceptr global_scratch = 0; + void *params[] = {{ {', '.join(params)} }}; + if (gridX*gridY*gridZ > 0) {{ + TANG_CHECK(taLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} +}} + +typedef struct _DevicePtrInfo {{ + TAdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + + // 暂时使用PyTorch接口的方案, 后续taPointerGetAttribute支持使用指针切换后,还是使用它 + // 获取 device 属性 + PyObject* device_obj = PyObject_GetAttrString(obj, "device"); + if (device_obj && device_obj != Py_None) {{ + // 获取 device.index + PyObject* index_obj = PyObject_GetAttrString(device_obj, "index"); + if (index_obj && PyLong_Check(index_obj)) {{ + int dev = PyLong_AsLong(index_obj); + // printf("[DEBUG] Switching to tensor device (device.index): %d\\n", dev); + tangSetDevice(dev); + }} + Py_XDECREF(index_obj); + }} + Py_XDECREF(device_obj); + + uint64_t dev_ptr; + int status = taPointerGetAttribute(&dev_ptr, TA_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == TANG_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} else if (status != TANG_SUCCESS) {{ + TANG_CHECK(status); // Catch any other TANG API errors + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static uint16_t pack_fp16(double f) {{ + uint16_t result; + // from https://github.com/python/pythoncapi-compat +#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) + _PyFloat_Pack2(f, (unsigned char*)&result, 1); +#else + PyFloat_Pack2(f, (unsigned char*)&result, 1); +#endif + return result; +}} + +static uint16_t pack_bf16(double f) {{ + float f32 = (float)f; + uint32_t u32 = *(uint32_t*)&f32; + return (uint16_t)(u32 >> 16); +}} + +static uint32_t pack_fp32(double f) {{ + float f32 = (float)f; + return *(uint32_t*)&f32; +}} + +static uint64_t pack_fp64(double f) {{ + return *(uint64_t*)&f; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + PyErr_SetString(PyExc_TypeError, "get input data error"); + return NULL; + }} + + {' '.join(float_storage_decls)} + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + Py_DECREF(ret); + }} + + // raise exception asap + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (TAstream)_stream, (TAfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + Py_DECREF(ret); + }} + + Py_RETURN_NONE; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + +def wrap_handle_tensor_descriptor(launcher): + from triton.tools.tensor_descriptor import TensorDescriptor + def inner(*args): + meta_args = args[:len(_BASE_ARGS_FORMAT)] + raw_kernel_args = args[len(_BASE_ARGS_FORMAT):] + final_args = [] + for arg in raw_kernel_args: + if isinstance(arg, TensorDescriptor): + # Currently the host side tensor descriptors get decomposed in + # the frontend to tensor desc, shape, and strides. We have no + # way to use these shape and strides when processing tensor + # descriptors which is why we provide our own decomposition + # above. Sadly this means we have to pass the shape and strides + # twice. + final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides]) + else: + final_args.append(arg) + return launcher(*meta_args, *final_args) + + return inner + +class SunriseLauncher(object): + + def __init__(self, src, metadata): + constants = src.constants if hasattr(src, "constants") else dict() + arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = {arg_idx(idx): value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} + src = make_launcher(constants, signature, metadata.warp_size) + mod = compile_module_from_src( + src=src, + name="__triton_launcher", + library_dirs=library_dirs(), + include_dirs=include_dirs, + libraries=libraries, + ) + has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values()) + self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch + + def __call__(self, *args): + self.launch(*args) + + +class SunriseDriver(GPUDriver): + + def __init__(self): + self.utils = SunriseUtils() # TODO: make static + self.launcher_cls = SunriseLauncher + from triton.backends.iluvatar import spec + self.spec = spec + super().__init__() + + def get_current_target(self): + capability = "S2" + warp_size = 32 + return GPUTarget("tang", capability, warp_size) + + def get_active_torch_device(self): + import torch + return torch.device("cuda", self.get_current_device()) + + def get_device_interface(self): + import torch + return torch.cuda + + @staticmethod + def is_active(): + if os.getenv('FLAGTREE_BACKEND', '') == 'sunrise': + return True + return False + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + import torch_dipu + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') + + def clear_cache(self, cache): + cache.zero_() \ No newline at end of file diff --git a/third_party/sunrise/backend/include/ptml.h b/third_party/sunrise/backend/include/ptml.h new file mode 100755 index 000000000..2b5d7b3e8 --- /dev/null +++ b/third_party/sunrise/backend/include/ptml.h @@ -0,0 +1,925 @@ +//////////////////////////////////////////////////////// +// @file ptml.h +// ptml api +// ptmlDevice_t represents the type of device index +//////////////////////////////////////////////////////// + +#ifndef _PT_ML_H_ +#define _PT_ML_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#include + +#ifndef TA_PT_NUM_MAX +#define TA_PT_NUM_MAX 128 +#endif //! TA_PT_NUM_MAX + +#if defined(_MSC_VER) +#define PTML_DEPRECATED __declspec(deprecated) +#define PTML_API_EXPORT __declspec(dllexport) +#define PTML_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define PTML_DEPRECATED __attribute__((deprecated)) +#define PTML_API_EXPORT __attribute__((visibility("default"))) +#define PTML_API_IMPORT __attribute__((visibility("default"))) +#else +#define PTML_DEPRECATED +#define PTML_API_EXPORT +#define PTML_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(ptml_EXPORTS) +#define PTML_API PTML_API_EXPORT +#else +#define PTML_API PTML_API_IMPORT +#endif //! For user + +#define PAGE_SIZE 4096 +#ifndef ALIGN +#define __ALIGN_KERNEL_MASK(x, mask) (((x) + (mask)) & ~(mask)) +#define __ALIGN_KERNEL(x, a) __ALIGN_KERNEL_MASK(x, (__typeof__(x))(a)-1) +#define ALIGN(x, a) __ALIGN_KERNEL((x), (a)) +#endif // ALIGN +#define PAGE_ALIGN(addr) ALIGN(addr, PAGE_SIZE) +#define CLK_NAME_MAX 16 +#define NUMBER_OF_CYCLES_IN_1_SEC 0x38400000 // 900M +#define am_interval_1s (1800000000) // 1800M stands for 1 second + +/*ioctl parm cmd*/ +#define CM3 0x10 +#define MLP 0x11 +#define LINUX 0x12 + +#define CMD_PMIC (0xa8) +#define MOD_TEMP_TYPE_NUM (1) +#define TEMP_IPID_CNT (8) +#define C2C_INFO_CNT (10) + +#define CMD_TEMPERATURE 0xab +#define CMD_HBM_TEMPERATURE 0xb2 +#define CMD_GET_GPIO_STATUS 0xb7 +#define CMD_DUMP_MEM 0xb5 +#define CMD_GET_CPLD_VERSION 0xbc +#define CMD_GET_MAX_POWER 0xbd +#define CMD_GET_EXCEPTION 0xb6 + +/*linux cmd*/ +#define LINUX_CMD_PTUTILI (0x11) +#define LINUX_CMD_PCIERELINK (0x13) +#define LINUX_CMD_HBMBWUTILI (0x1A) +#define LINUX_CMD_HBMUTILI (0x1B) +#define LINUX_CMD_C2CREVDB (0x14) +#define LINUX_CMD_C2CTRANSDB (0x15) +#define LINUX_CMD_PCIEREVDB (0x16) +#define LINUX_CMD_PCIETRANSDB (0x17) +#define LINUX_CMD_TUUTILI (0x18) +#define LINUX_CMD_THREADUTILI (0x19) + +/** + * @brief Return val for ptml API + */ +typedef enum ptmlReturn_enum { + PTML_SUCCESS = 0, //!< APT returns ok + PTML_ERROR_UNINITIALIZED, //!< ptmlInit is not called now + PTML_ERROR_INVALID_ARGUMENT, //!< invalid argument + PTML_ERROR_ALREADY_INITIALIZED, //!< ptmlInit is already called + PTML_ERROR_INSUFFICIENT_SIZE, //!< An input argument is not large enough + PTML_ERROR_IN_USE, //!< PT is in use + PTML_ERROR_DRIVER_NOT_LOADED, //!< driver is not loaded + PTML_ERROR_DEVICE_NOT_FOUND, //!< device is not found + PTML_ERROR_EVENT_TIMEOUT, //!< device is not found + PTML_ERROR_UNKNOWN, //!< An internal driver error occurred +} ptmlReturn_t; + + +typedef struct { + enum ptmlReturn_enum errorCode; + const char *errorMessage; +} ErrorDescriptor; + +typedef enum ptmlClockType { + //!< PTML_CLOCK_GRAPHICS = 0, + //!< PTML_CLOCK_SM, + PTML_CLOCK_PT = 0, + PTML_CLOCK_MEM, + //!< PTML_CLOCK_VIDEO, +} ptmlClockType_t; + +/** + * @brief Device Handle type + * + ********************************************/ +typedef int ptmlDevice_t; + +typedef struct ptMemory { + size_t total; //!< total memory + size_t used; //!< used memory + size_t free; //!< free memory +} ptMemory_t; + +#define PTML_DEVICE_PCI_BUS_ID_BUFFER_SIZE 32 + +typedef struct ptPciInfo { + int domain; //!< domain number + int bus; //!< bus number + int device; //!< dev && func number + int vendor; //!< vendor number + int pciSubSystemId; //!< subsystem Id + char busId[PTML_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; //!< "domain:bus:device:0" + unsigned int max_link_speed; //!< max link speed(MT/s) + unsigned int max_link_width; //!< max link width + unsigned int max_bandwidth; //!< max bandwidth(MB/s) + unsigned int curr_link_speed; //!< current link speed(MT/s) + unsigned int curr_link_width; //!< current link width + unsigned int curr_bandwidth; //!< current bandwidth(MB/s) +} ptPciInfo_t; + +typedef struct ptProcessInfo { + char name[256]; + char pidStr[16]; +}ptProcessInfo_t; + +struct barInfo { + int barIdx; //!< 0-5 + void* addr; //!< The mapped vritual address. + uint64_t paddr; + size_t size; //!< The size of the mapped vritaul address space. +}; + +struct fw_param { + int cmd; + int len; + int data[]; +}; +#define _S2_IOC_SMI_INFO _IOWR(_S2_IOC_MAGIC, 160, struct fw_param) + +struct st_rpmsg_cmd { + unsigned int cmd; + unsigned char data[]; +} __attribute__((packed)); + +struct st_rpmsg_i2c_cmd { + unsigned char rx_data_len; + unsigned char tx_data_len; + unsigned char slave_len; + unsigned char trx_flag; + unsigned char pyload[]; // slave addr + tx data +} __attribute__((packed)); + +struct st_rpmsg_i2c_response { + unsigned char error_code; + unsigned char rx_data_len; + unsigned char slave_len; + unsigned char trx_flag; + unsigned char pyload[]; // slave addr + rx data +} __attribute__((packed)); + +struct st_rpmsg_pmic_cmd { + unsigned char pmic_cmd; + unsigned int pmic_param; +} __attribute__((packed)); + +struct st_rpmsg_pmic_response { + unsigned char pmic_cmd; + unsigned char error_code; + unsigned char pmic_data[4]; +} __attribute__((packed)); + +/* + * RPmsg Clock Command IDs + */ +enum rpmsg_clk_cmd_id { + RPMSG_CLK_GET_STATE, + RPMSG_CLK_GET_NAME, + RPMSG_CLK_GET_RATE, + RPMSG_CLK_ENABLE, + RPMSG_CLK_DISABLE, + RPMSG_CLK_CMD_COUNT, +}; + +enum plat_clock_idx { + MOD_CLOCK_G0_0, + MOD_CLOCK_G0_1, + MOD_CLOCK_G1_0, + MOD_CLOCK_G2_0, + MOD_CLOCK_G3_0, + MOD_CLOCK_G4_0, + MOD_CLOCK_G5_0, + MOD_CLOCK_G6_0, + MOD_CLOCK_G7_0, + MOD_CLOCK_G8_0, + MOD_CLOCK_G9_0, + MOD_CLOCK_G10_0, + MOD_CLOCK_G11_0, + MOD_CLOCK_G12_0, + MOD_CLOCK_G13_0, + MOD_CLOCK_G13_1, + MOD_CLOCK_G13_2, + MOD_CLOCK_G14_0, + MOD_CLOCK_G14_1, + + MOD_CLOCK_L0_CLK2000CLK, + MOD_CLOCK_L0_CLK1000CLK, + MOD_CLOCK_L0_CLK500CLK, + MOD_CLOCK_L0_CLK250CLK, + MOD_CLOCK_L0_CLK125CLK, + MOD_CLOCK_L0_CLK62P5CLK, + MOD_CLOCK_L0_CLK31P25CLK, + MOD_CLOCK_L0_SMB_MELESCLK, + MOD_CLOCK_L0_MELS_REF_CLK, + MOD_CLOCK_L0_SMB_32KCLK, + MOD_CLOCK_L0_PLL0_CLK, + MOD_CLOCK_L1_CORE_CLK_L, + MOD_CLOCK_L1_NOC_CLK0, + MOD_CLOCK_L1_PLL1_CLK, + MOD_CLOCK_L2_CORE_CLK_H, + MOD_CLOCK_L2_NOC_CLK1, + MOD_CLOCK_L2_PLL2_CLK, + MOD_CLOCK_L3_APBCLK, + MOD_CLOCK_L3_PLL3_CLK, + MOD_CLOCK_L4_VIDEO_CLK, + MOD_CLOCK_L4_PLL4_CLK, + MOD_CLOCK_L5_DMA_CLK, + MOD_CLOCK_L5_TIGER_CLK, + MOD_CLOCK_L5_PLL5_CLK, + MOD_CLOCK_L6_AXI_CLOCK0, + MOD_CLOCK_L6_PLL6_CLK, + MOD_CLOCK_L7_AXI_CLOCK1, + MOD_CLOCK_L7_PLL7_CLK, + MOD_CLOCK_L8_JPEG_CLK, + MOD_CLOCK_L8_PLL8_CLK, + MOD_CLOCK_L9_PLLREFCLK0, + MOD_CLOCK_L9_DFICLK0, + MOD_CLOCK_L9_DFIHDRCLK0, + MOD_CLOCK_L9_PLL9_CLK, + MOD_CLOCK_L10_PLLREFCLK1, + MOD_CLOCK_L10_DFICLK1, + MOD_CLOCK_L10_DFIHDRCLK1, + MOD_CLOCK_L10_PLL10_CLK, + MOD_CLOCK_L11_PLLREFCLK2, + MOD_CLOCK_L11_DFICLK2, + MOD_CLOCK_L11_DFIHDRCLK2, + MOD_CLOCK_L11_PLL11_CLK, + MOD_CLOCK_L12_PLLREFCLK3, + MOD_CLOCK_L12_DFICLK3, + MOD_CLOCK_L12_DFIHDRCLK3, + MOD_CLOCK_L12_PLL12_CLK, + MOD_CLOCK_L13_ACLK0, + MOD_CLOCK_L13_PLL13_CLK, + MOD_CLOCK_L15_AUX_CLK0, + MOD_CLOCK_L16_ACLK1, + MOD_CLOCK_L16_PLL16_CLK, + MOD_CLOCK_L17_AUX_CLK1, + + MOD_CLOCK_IDX_COUNT, +}; + +enum c2c_port_index { + C2C0_0 = 0, + C2C0_1, + C2C1_0, + C2C1_1, + C2C2_0, + C2C2_1, + C2C3_0, + C2C3_1, + C2C4_0, + C2C4_1, + PCIE, +}; +/* + * struct c2h_clk_msg - Response payload for RPMSG_CLK_ATTRIBUTES_DUMP command + * @status: Command status + * @state: Clock state(on or off) + * @rate: Clock rate in Hz, + * rate[0] 32bit lsb clock rate + * rate[1] 32bit hsb clock rate + * @name: Clock name + */ +struct c2h_clk_msg { + int status; + uint32_t state; + uint32_t rate[2]; + char name[16]; +}; + +struct c2h_temp_msg { + int status; + int temp[TEMP_IPID_CNT]; +}; + +#define ALLPORTS 0x3ff +typedef struct ptPhyTopo { + unsigned char local_chipid; + unsigned char local_port; + unsigned char remote_chipid; + unsigned char remote_port; + unsigned char link_status; + unsigned char isBif; + unsigned char max_speed; + unsigned char cur_speed; + unsigned char max_bandwidth; + unsigned char cur_bandwidth; +} ptPhyTopo_t; + +/** + * @brief Init ptml module + * + * @return int + * @note must called before using any ptml API + ********************************************/ +ptmlReturn_t PTML_API ptmlInit(void); +ptmlReturn_t PTML_API __ptmlUinit(void); +void PTML_API ptmlUninit(void); + +/** + * @brief Get system driver version 0.1.0 + * + * @param length is driver version's length + * @note driver Version is Equivalent to dirver version, specific information + *in + * "/sys/module/pt/version" + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlSystemGetDriverVersion(char * version, + unsigned int length); + +/** + * @brief Get system tang version 0.1.0 + * + * @note tang Version is Equivalent to cuda version + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlSystemGetTangVersion(int *version); +ptmlReturn_t PTML_API ptmlSystemGetTangVersionForSmi(char * version, + unsigned int length); + +/** + * @brief Get dev base info: ptType and memInfo + * + * @param device device handles + * @param ptTypeOut pointer to ptTypeOut + * @param memInfo pointer to dev memInfo + * @note ptTypeOut is pt200, device info is total mem size + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetBaseInfo(ptmlDevice_t device, + char * ptTypeOut, + unsigned int *memInfo); + +/** + * @brief Get PT board count + * + * @param devCount pointer to devCount + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetBoardCount(unsigned int *devCount); + +/** + * @brief Get PT type :pt200 + * + * @param device device handles + * @param ptTypeOut pointer to pt type + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetBrand(ptmlDevice_t device, char *ptTypeOut); + +/** + * @brief Get device Capacity + * + * @param device device handles + * @param value pointer to Capacity + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetCapacity(ptmlDevice_t device, float *value); + +/** + * @brief Get PT count + * + * @param devCount pointer to devCcount + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetCount(int *devCount); + +/** + * @brief Get device c2c rev DB + * + * @param device device handles + * @param c2cIndex is port num 0-10 + * @param interval is from 1-60(s) + * @param revdb pointer is c2c revdb + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetC2CRevDB(ptmlDevice_t device, + int c2cIndex, + int interval, + unsigned int *revdb); + +/** + * @brief Get device c2c trans DB + * + * @param device device handles + * @param c2cIndex is port num 0-10 + * @param interval is from 1-60(s) + * @param transdb pointer is c2c transdb + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetC2CTransDB(ptmlDevice_t device, + int c2cIndex, + int interval, + unsigned int *transdb); + +/** + * @brief Get device fw version + * + * @param device device handles + * @param version pointer to fw version + * @param length is version's length + * @note fw is cm3 linux mlp mix version Internal use only Not provided + *externally + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetFirmwareVersion(ptmlDevice_t device, + char * version, + unsigned int length); + +/** + * @brief Get Device Handle by idx + * + * @param idx idx of the target PT + * @param device pointer to the handle of target PT + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetHanldeByIdx(unsigned int idx, + ptmlDevice_t *device); + +/** + * @brief Get Device Handle by PCI + * + * @param idx idx of the target PT + * @param device pointer to the handle of target PT + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetHanldeByPciBusId(const char * busId, + ptmlDevice_t *device); + +/** + * @brief Get device mem BW Utilization + * + * @param device device handles + * @param interval is from 1-60(s) + * @param utilization pointer is memBWUtilization 0-100% + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetMemBWUtilizationRates(ptmlDevice_t device, + int interval, + float *utilization); + +/** + * @brief ptmlDeviceGetMemClockFrequency + * + * @param device device handles + * @param clock pointer to memClockFreq + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetMemClockFrequency(ptmlDevice_t device, + unsigned int *clock); + +/** + * @brief Get device memory information + * + * @param device device handle + * @param memInfo pointer to memory information + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetMemoryInfo(ptmlDevice_t device, + ptMemory_t * memInfo); +ptmlReturn_t PTML_API ptmlDeviceGetMemoryUsedInfo(ptmlDevice_t device, + unsigned int *usedInfo); +ptmlReturn_t PTML_API ptmlDeviceGetMemoryFreeInfo(ptmlDevice_t device, + unsigned int *usedInfo); + +/** + * @brief Get device mem Temperature + * + * @param device device handles + * @param temp pointer to ptTemperature + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetMemTemperature(ptmlDevice_t device, + unsigned int *temp); + +/** + * @brief Get device mem Utilization + * + * @param device device handles + * @param interval is from 1-60(s) + * @param utilization pointer is memUtilization 0-100% + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t ptmlDeviceGetMemUtilizationRates(ptmlDevice_t device, + int interval, + float *utilization); + +/** + * @brief Get PT node path /dev/ptpux + * + * @param device device handle + * @param path pointer to devNodePath + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetNodePath(ptmlDevice_t device, char *path); + +/** + * @brief Get device pcie relink times + * + * @param device device handles + * @param poniter to pcie relink times + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPcieRelinkTime(ptmlDevice_t device, + int * count); + +/** + * @brief Get device pcie rev DB + * + * @param device device handles + * @param interval is from 1-60(s) + * @param revdb pointer is pcie revdb + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPcieRevDB(ptmlDevice_t device, + int interval, + unsigned int *revdb); + +/** + * @brief Get device pcie trans DB + * + * @param device device handles + * @param interval is from 1-60(s) + * @param transdb pointer is pcie transdb + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPcieTransDB(ptmlDevice_t device, + int interval, + unsigned int *transdb); + +/** + * @brief Get device pci information + * + * @param device device handle + * @param pciInfo pointer to pci information + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPciInfo(ptmlDevice_t device, + ptPciInfo_t *pciInfo); + +/** + * @brief Reboot device + * + * @param device device handles + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceRebootBootloader(ptmlDevice_t device); + +/** + * @brief Get PT status 0:invalid 1:valid + * + * @param device device handles + * @param status pointer to ptstatus + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetStatus(ptmlDevice_t device, int *status); + +/** + * @brief ptmlDeviceGetPtClockFrequency + * + * @param device device handles + * @param clock pointer to ptClockFreq + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPtClockFrequency(ptmlDevice_t device, + unsigned int *clock); + +/** + * @brief Get PTCTRL major and minor + * + * @param device device handles + * @param major pointer to ptpuctrl major + * @param minor pointer to ptpuctrl minor + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPtCtrlMajorAndMinor(int *major, + int *minor); + +/** + * @brief Get PT major and minor /dev/ptpux + * + * @param device device handles + * @param major pointer to devmajor + * @param minor pointer to devminor + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPtMajorAndMinor(ptmlDevice_t device, + int * major, + int * minor); + +/** + * @brief Get device pt Temperature + * + * @param device device handles + * @param temp pointer to ptTemperature + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPtTemperature(ptmlDevice_t device, + unsigned int *temp); + +/** + * @brief Get device pt Utilization + * + * @param device device handles + * @param interval is from 1-60(s) + * @param utilization pointer to ptUtilization 0-100% + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetPtUtilizationRates(ptmlDevice_t device, + int interval, + float *utilization); + +/** + * @brief Get compute capability + * + * @param device device handles + * @param major pointer to major + * @param minor pointer to minor + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetTangComputeCapability(ptmlDevice_t device, + int * major, + int * minor); + +/** + * @brief Get device thread Utilization + * + * @param device device handles + * @param interval is from 1-60(s) + * @param utilization pointer to threadUtilization 0-100% + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetThreadUtilizationRates(ptmlDevice_t device, + int interval, + float *utilization); + +/** + * @brief Get device subcore Utilization + * + * @param device device handles + * @param interval is from 1-60(s) + * @param utilization pointer to subcoreUtilization 0-100% + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetTUUtilizationRates(ptmlDevice_t device, + int interval, + float * utilization); + +/** + * @brief Get device uuid + * + * @param device device handles + * @param uuid pointer to device uuid + * @param length device uuid length + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetUUID(ptmlDevice_t device, + char * uuid, + unsigned int length); + +/** + * @brief ptmlDeviceSetCUFrequency + * + * @param device device handles + * @param CU freq + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceSetCUFrequency(ptmlDevice_t device, + unsigned int freq) ; + +/** + * @brief Get device Mem Temperature + * + * @param device device handles + * @param temp pointer to ptTemperature + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetMemTemperature(ptmlDevice_t device, + unsigned int *temp); + +/** + * @brief ptmlDeviceGetGPIOStatus + * + * @param device device handles + * @param GPIO status + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetGPIOStatus(ptmlDevice_t device, + unsigned int *status); +/** + * @brief ptmlDeviceDumpCM3Regs + * + * @param device device handles + * @param addr is CM3 reg addr + * @param len is dump regs' length + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceDumpCM3Regs(ptmlDevice_t device, + unsigned int addr, + unsigned int len, + unsigned int *value); +/** + * @brief ptmlDeviceSetDumpTempSwitch + * + * @param device device handles + * @param switch + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceSetDumpTempSwitch(ptmlDevice_t device, + unsigned int switchFlag); + +int PTML_API ptmlDeviceGetBarInfo(ptmlDevice_t device, + unsigned int *size); + +/** + * @brief ptmlDeviceGetCPLDVersion + * + * @param device device handles + * @param large version + * @param small version + * @return ptmlReturn_t + ********************************************/ +ptmlReturn_t PTML_API ptmlDeviceGetCPLDVersion(ptmlDevice_t device, + int *lver, int *sver); + +ptmlReturn_t PTML_API ptmlDeviceGetProcessInfo(ptmlDevice_t device, + int processNum, struct ptProcessInfo *procInfo); +ptmlReturn_t PTML_API ptmlDeviceGetMaxPower(ptmlDevice_t device, + int *maxPower); +ptmlReturn_t PTML_API ptmlDeviceGetException(ptmlDevice_t device, + unsigned int *num); +/** + * @brief ptmlPtlinkEnableAll + * + * @return ptmlReturn_t + */ +ptmlReturn_t PTML_API ptmlPtlinkEnableAll(void); + +/** + * @brief ptmlPtlinkDisableAll + * + * @return ptmlReturn_t + */ +ptmlReturn_t PTML_API ptmlPtlinkDisableAll(void); + +/** + * @brief ptmlPtlinkPortControl + * + * @device device id + * @port port number + * @ ops operation: en, disable ... + * @return ptmlReturn_t + */ +ptmlReturn_t PTML_API ptmlPtlinkPortControl(ptmlDevice_t device, + uint32_t port, + uint32_t ops); + +/** + * @brief ptmlPtlinkPhytopoDetect + * + * @device device id + * @size memory size + * @buffer user buffer + * @return ptmlReturn_t + */ +ptmlReturn_t PTML_API ptmlPtlinkPhytopoDetect(ptmlDevice_t device, + uint32_t size, + void * buffer); + +/** + * @brief ptmlEngineCollAssign + * + * @device device id + * @coll_type collective type + * @buffer user buffer + * @size memory size + * @return ptmlReturn_t + */ +ptmlReturn_t PTML_API ptmlEngineCollAssign(ptmlDevice_t device, + uint32_t coll_type, + void * buffer, + uint32_t size); +/** + * @brief ptmlPtlinkGetConnectRelation + * + * @device1 device id + * @device2 device id + * @status the relationship between devices + * @return ptmlReturn_t + * @note before use this api please use ptmlPtlinkEnableAll + */ +ptmlReturn_t PTML_API ptmlPtlinkGetConnectRelation(ptmlDevice_t device1, + ptmlDevice_t device2, + int *status) ; + +/** + * @brief ptmlGetPtlinkStatus + * + * @device device id + * @port port id + * @status the status of the port + * @return ptmlReturn_t + * @note before use this api please use ptmlPtlinkEnableAll + */ +ptmlReturn_t PTML_API ptmlGetPtlinkStatus(ptmlDevice_t device, + int port, + int *status); + +/** + * @brief ptmlGetPtlinkRemoteDevicePciInfo + * + * @device device id + * @port port id + * @pciInfo pciInfo of remote device + * @return ptmlReturn_t + * @note before use this api please use ptmlPtlinkEnableAll + */ +ptmlReturn_t PTML_API ptmlGetPtlinkRemoteDevicePciInfo(ptmlDevice_t device, + int port, + ptPciInfo_t *pciInfo); + +/** + * @brief ptmlGetErrorCodeToDescription + * + * @errCode errCode + * @return errorDescription + */ +PTML_API const char *ptmlGetErrorCodeToDescription(int errorCode); + +typedef enum ptmlEventType_enum { + PTML_EVENT_TYPE_PSTATE, + PTML_EVENT_TYPE_ALL, +} ptmlEventType_t; + +typedef enum ptmlDeviceStateChange { + PTMLDEVICE_GOOD_TO_BAD, + PTMLDEVICE_BAD_TO_GOOD, +} ptmlDeviceStateChange_t; + +typedef enum ptmlDeviceState { + PTMLDEVICE_BAD, + PTMLDEVICE_GOOD, +} ptmlDeviceState_t; + +typedef enum ptmlEventStrategy { + PTMLEVENT_UN_MONITOR, + PTMLEVENT_MONITOR, +} ptmlEventStrategy_t; + +typedef struct ptmlEventData { + ptmlDevice_t device; + unsigned long eventType; + unsigned long eventData; +} ptmlEventData_t; + +typedef struct ptmlEvent { + ptmlEventStrategy_t strategy; + ptmlDeviceState_t state; +} ptmlEvent_t; + +typedef struct ptmlEventSet { + ptmlEvent_t deviceEvent[TA_PT_NUM_MAX][PTML_EVENT_TYPE_ALL]; +} ptmlEventSet_t; + +ptmlReturn_t PTML_API ptmlEventSetCreate(ptmlEventSet_t **set); +ptmlReturn_t PTML_API ptmlEventSetFree(ptmlEventSet_t *set); +ptmlReturn_t PTML_API ptmlDeviceRegisterEvents(ptmlDevice_t device, + unsigned long eventTypes, + ptmlEventSet_t *set); + +ptmlReturn_t PTML_API ptmlEventSetWait_v2(ptmlEventSet_t * set, + ptmlEventData_t *data, + unsigned int timeoutms); + +typedef int* ptmlDeviceErrorCode; +PTML_API int ptmlGetDeviceLastError(ptmlDevice_t device, + ptmlDeviceErrorCode DeviceErrorCode); + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif //! _S2_ML_H_ diff --git a/third_party/sunrise/backend/include/tang.h b/third_party/sunrise/backend/include/tang.h new file mode 100755 index 000000000..603782319 --- /dev/null +++ b/third_party/sunrise/backend/include/tang.h @@ -0,0 +1,2321 @@ +//////////////////////////////////////////////////////// +// @file tang.h +// tang DRIVER INTERFACE +// @author linan +//////////////////////////////////////////////////////// + +#ifndef _TANG_H_ +#define _TANG_H_ +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define TA_VERSION_MAJOR 0 +#define TA_VERSION_MINOR 13 +#define TA_VERSION_PATCH 0 + +#define TA_VERSION \ + ((TA_VERSION_MAJOR * 1000) + (TA_VERSION_MINOR * 10) + TA_VERSION_PATCH) + +#if defined(_MSC_VER) +#define TA_DEPRECATED __declspec(deprecated) +#define TA_API_EXPORT __declspec(dllexport) +#define TA_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TA_DEPRECATED __attribute__((deprecated)) +#define TA_API_EXPORT __attribute__((visibility("default"))) +#define TA_API_IMPORT __attribute__((visibility("default"))) +#else +#define TA_DEPRECATED +#define TA_API_EXPORT +#define TA_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(tang_EXPORTS) +#define TA_API TA_API_EXPORT +#else +#define TA_API TA_API_IMPORT +#endif //! For user + +#define COMMAND_MAGIC 0x100d58ba +enum COMMAND_SA { + COMMAND_ASYNC = 0, + COMMAND_SYNC, +}; + +enum MODE_TYPE { + PT_MGR_TYPE = 1, + PT_COLL_TYPE, + OTHER_TYPE, +}; + +enum OPERATIONS { + OPS_LINK_EN = 1, + OPS_LINK_DIS, + OPS_LINK_DETECT, + OPS_LINK_PORTADDR = 4, + OPS_LINK_BIF, + OPS_LINK_P2P_ATTR, + OPS_PEER_ACCESS_CAN = 7, + OPS_PEER_ACCESS_EN, + OPS_PEER_ACCESS_DIS, + OPS_LINK_INIT, + OPS_LINK_PORT_INIT, +}; + +enum COLL_OPS { + COLL_BROADCAST = 0, + COLL_REDUCE, + COLL_ALLGATHER, + COLL_REDUCESCATTER, + COLL_ALLREDUCE, + COLL_MAX_OPS +}; + +enum COMMANDS { + CMD_MAGIC = 0, + CMD_SYNC, + MODE_TYPE, + CMD_ID, + PORT, + DEV_ID, + RDEV_ID, + MSGIN_LEN, + MSGOUT_LEN, +}; + +struct scp_msg_ack { + int retval; // FW irq return value + int status; // simple status value + char payload[0]; // complex struct return +}; + +#define C2CSCP_MSG_HEAD (8) + +typedef uint64_t TAdeviceptr; //!< TANG device pointer + +#define TAdevice_nullptr (TAdeviceptr)0 + +typedef struct TAdevice_s* TAdevice; //!< TANG device +typedef struct TActx_s* TAcontext; //!< TANG context +typedef struct TAfunc_s* TAfunction; //!< TANG function handle +typedef struct TAevent_s* TAevent; //!< TANG event handle +typedef struct TAstream_s* TAstream; //!< TANG stream handle +typedef struct TAmodule_s* TAmodule; //!< TANG module handle +typedef struct TAvariable_s* TAvariable; //!< TANG variable +typedef struct TAgraph_s* TAgraph; //!< TANG graph handle +typedef struct TAgraphExec_s* TAgraphExec; //!< TANG graph exec handle +typedef struct TAgraphNode_s* TAgraphNode; //!< TANG graph node + +typedef struct TAdsoWrapper_s { + uintptr_t data[20]; +} TAdsoWrapper_t; + +/** + * @brief Stream flags. + * @sa __s2StreamFlags. + */ +typedef enum TAstream_flags_e { + TA_STREAM_DEFAULT = 0x0, //!< The default stream creation flag. + TA_STREAM_NON_BLOCKING = 0x1, //!< The non blocking stream creation flag. + //! TA_STREAM_LEGACY = 0x2, //!< The legacy stream creation flag. + //! //!< This flag can only be used internally. + //! //!< User use this flag will cause + //! //!< ::TANG_ERROR_INVALID_VALUE error. +} TAstream_flags; + +typedef enum TAstreamCaptureMode_e { + TA_STREAM_CAPTURE_MODE_GLOBAL = 0, + TA_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + TA_STREAM_CAPTURE_MODE_RELAXED = 2, +} TAstreamCaptureMode; + +typedef enum TAstreamCaptureStatus_e { + TA_STREAM_CAPTURE_STATUS_NONE = 0, + TA_STREAM_CAPTURE_STATUS_ACTIVE = 1, + TA_STREAM_CAPTURE_STATUS_INVALIDATED = 2, +} TAstreamCaptureStatus; + +typedef struct TAgraphInfo_s { + int nr_nodes; +} TAgraphInfo; + +typedef struct TAeventTimestamp_s { + uint64_t comp; + uint64_t comp_sw; + uint64_t create; + uint64_t enqueue; + uint64_t writeq_beg; + uint64_t writeq_end; +} TAeventTimestamp; + +typedef enum TAevent_record_flags_e { + //!< The default record flag + TA_EVENT_RECORD_DEFAULT = 0, + + //!< Require hardware event + TA_EVENT_RECORD_HW = 0x0100, + + //!< Require software event + TA_EVENT_RECORD_SW = 0x0200, + + //!< Allow waiting while allocating hardware event. + TA_EVENT_RECORD_ALLOW_BLOCKING = 0x0400, +} TAevent_record_flags; + +typedef enum TAevent_flags_e { + TA_EVENT_DISABLE_TIMING = 0x02, + TA_EVENT_INTERPROCESS = 0x04, +} TAevent_flags; + +typedef enum TAevent_sync_flags_e { + //!< The default synchronization behaviour. + //!< 1. If the event has not been recorded, + //!< taEventSynchronize will return imediately. + //!< 2. If the event has been recorded, + //!< taEventSynchronize will block until the + //!< event is done. + TA_EVENT_SYNC_DEFAULT = 0x00, + + //!< Block until the event is recorded and done. + TA_EVENT_SYNC_RECORDED_AND_DONE = 0x01, +} TAevent_sync_flags; + +#define TA_IPC_HANDLE_SIZE 64U + +struct TAipcMemHandle_s { + unsigned long reserved[TA_IPC_HANDLE_SIZE / sizeof(unsigned long)]; +}; +typedef struct TAipcMemHandle_s TAipcMemHandle; + +struct TAipcEventHandle_s { + unsigned long reserved[TA_IPC_HANDLE_SIZE / sizeof(unsigned long)]; +}; +typedef struct TAipcEventHandle_s TAipcEventHandle; + +enum TAipcMem_flags_e { + TA_IPC_MEM_LAZY_ENABLE_PEER_ACCESS = 0x01, +}; +typedef enum TAipcMem_flags_e TAipcMem_flags; + +#define TA_LAUNCH_PARAM_END 0 +#define TA_LAUNCH_PARAM_INVALIDATE_L1P5 1 +#define TA_LAUNCH_PARAM_ICACHE_FLUSH 2 +#define TA_LAUNCH_PARAM_WORK_MODE 3 +#define TA_LAUNCH_PARAM_MAX_ACTIVE_BLOCK_COUNT_PER_CU 4 +#define TA_LAUNCH_PARAM_SHARE_MEM_MIRROR 5 +#define TA_LAUNCH_PARAM_CLST_DIMX 6 +#define TA_LAUNCH_PARAM_CLST_DIMY 7 +#define TA_LAUNCH_PARAM_CLST_DIMZ 8 + +struct TAextraLaunchParam_s { + unsigned long type; + union { + unsigned long val; + void *ptr; + }; +}; +typedef struct TAextraLaunchParam_s TAextraLaunchParam; + +typedef enum TAmemorytype_e { + TA_MEMORYTYPE_HOST = 0x01, + TA_MEMORYTYPE_DEVICE = 0x02, + TA_MEMORYTYPE_ARRAY = 0x04, + TA_MEMORYTYPE_UNIFIED = 0x05, +} TAmemorytype; + +typedef enum TApointer_attribute_e { + /**< The ::TAcontext on which a pointer is allocated and registered */ + TA_POINTER_ATTRIBUTE_CONTEXT = 1, + + /**< The ::TAmemorytype describing the physical location of a pointer */ + TA_POINTER_ATTRIBUTE_MEMORY_TYPE = 2, + + TA_POINTER_ATTRIBUTE_DEVICE_POINTER = 3, + TA_POINTER_ATTRIBUTE_HOST_POINTER = 4, + TA_POINTER_ATTRIBUTE_DEVICE_ORDINAL = 9, +} TApointer_attribute; + +typedef enum TAfunction_attribute_e { + // The maximum number of threads per block, beyond which + // a lanuch of the function would fail. + // This value depends on the function and the device. + TA_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0, + + // The number of bytes statically allocated shared memory. + TA_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES = 1, + + // The number of bytes of user allocated constant memory. + // This attribute is not implemented in pt200 + TA_FUNC_ATTRIBUTE_CONST_SIZE_BYTES = 2, + + // The number of bytes of local memory used by each thread of the function. + TA_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES = 3, + + // The number of registers used by each thread of this function. + TA_FUNC_ATTRIBUTE_NUM_REGS = 4, + + // The maximum size of dynamically allocated shared memory that + // can be used by this function. + TA_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8, +} TAfunction_attribute; + +typedef enum TAgraphNodeType_enum { + TA_GRAPH_NODE_TYPE_KERNEL = 0, + TA_GRAPH_NODE_TYPE_MEMCPY = 1, + TA_GRAPH_NODE_TYPE_MEMSET = 2, + TA_GRAPH_NODE_TYPE_HOST = 3, + TA_GRAPH_NODE_TYPE_GRAPH = 4, + TA_GRAPH_NODE_TYPE_EMPTY = 5, + TA_GRAPH_NODE_TYPE_WAIT_EVENT = 6, + TA_GRAPH_NODE_TYPE_EVENT_RECORD = 7, +} TAgraphNodeType; + +typedef enum TAmoduleSymbolType_enum { + TA_MODULE_SYMBOL_FUNCTION = 1, + TA_MODULE_SYMBOL_VARIABLE = 2, +} TAmoduleSymbolType; + +typedef struct TAmoduleSymbolHandle_s { + union { + TAfunction function; + TAvariable variable; + }; +#ifdef __cplusplus + TAmoduleSymbolHandle_s() + : function(nullptr) {} + + explicit TAmoduleSymbolHandle_s(TAfunction func) + : function(func) {} + + explicit TAmoduleSymbolHandle_s(TAvariable var) + : variable(var) {} +#endif // __cplusplus +} TAmoduleSymbolHandle; + +/** + * @brief TAmodule symbols iteration call back function type. + * @note If the function returns true, the iteration will stop. + * Always returns false to iterate all symbols. + */ +// typedef bool (*TAmoduleSymbolIterateFn)(TAmoduleSymbolType symbolType, +// const char* symbolName, +// TAmoduleSymbolHandle symbolHandle, +// void* userData); + +typedef void (*TAhostFn)(void* userData); + +typedef struct TANG_HOST_NODE_PARAMS { + TAhostFn fn; + void* userData; +} TANG_HOST_NODE_PARAMS; + +typedef struct TANG_KERNEL_NODE_PARAMS_s { + TAfunction func; + + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + + void** kernelParams; + void** extra; +} TANG_KERNEL_NODE_PARAMS; + +#define L2_CACHE_MP_CNT 4 +#define L2_CACHE_MPX_SUBCNT 16 +#define WARP_SIZE 32 +#define BUFFER_SIZE 32 +#define IPATH_PERF_CNT 10 +#define EXT_PERF_CNT 13 +#define SHM_PERF_CNT 15 +#define CLUSTER_CNT 12 +#define SUBCORE_CNT 8 +#define CLUSTER_CNT_HALF 6 +#define CLUSTER_CNT_3 3 + +typedef struct ptProfileInfo_s { + uint32_t blockDim_x; + uint32_t blockDim_y; + uint32_t blockDim_z; + uint32_t gridDim_x; + uint32_t gridDim_y; + uint32_t gridDim_z; + int __s2Stream_t; + int __s2Context_t; + int regs; + int device; + int SSMem; + int DSMem; + uint32_t max_bkcnt; + uint32_t thread_cnt; + float waves_per_sm; + uint32_t max_warps_per_sched; + float perf_blk_occupation; + uint32_t achieved_active_warps_per_sm; + uint32_t blk_limit_reg; + uint32_t blk_limit_shared_mem; + uint32_t local_memory_size; + uint32_t blk_shm_size; + char funcName[512]; + uint64_t knlTime; + uint32_t knlSubTime[96]; + uint64_t ipathcnt[38]; //!< ipath-10 ext-13 shm-15 perf + uint64_t l2cache[6]; //!< [1]Monitor Read Hit Counter [Offset: 0x610] + //!< [2]Monitor Cacheable Read Request Counter [Offset: 0x608] + //!< [3]Monitor Read Request Counter [Offset: 0x600] + //!< [4]Monitor Cacheable Write Request Counter[Offset: 0x60c] + //!< [5]Monitor Write Request Counter [Offset: 0x604] + //!< [6]Monitor Write Hit Counter [Offset: 0x614] + uint64_t l1p5cache[6]; //!< [1]Monitor Read Hit Counter [Offset: 0x610] + //!< [2]Monitor Cacheable Read Request Counter [Offset: 0x608] + //!< [3]Monitor Read Request Counter [Offset: 0x600] + //!< [4]Monitor Cacheable Write Request Counter[Offset: 0x60c] + //!< [5]Monitor Write Request Counter [Offset: 0x604] + //!< [6]Monitor Write Hit Counter [Offset: 0x614] + uint32_t clock; + uint32_t knlInfo[3]; //!< blk_cnt\knl_rcvd_cnt\knl_cmpl_cnt + uint32_t l2HitRate[L2_CACHE_MP_CNT*L2_CACHE_MPX_SUBCNT][4]; + uint32_t l1p5HitRate[L2_CACHE_MPX_SUBCNT*(CLUSTER_CNT)][4]; + uint64_t extDetail[CLUSTER_CNT*SUBCORE_CNT][EXT_PERF_CNT]; + uint64_t shmDetail[CLUSTER_CNT*SUBCORE_CNT][SHM_PERF_CNT]; + float perf_achieved_warp_occupation; + uint32_t l1invalid; + uint32_t warp_regfile_size; +} ptProfileInfo; + +#ifndef TA_STREAM_LEGACY +#define TA_STREAM_LEGACY ((TAstream)0x01) +#endif //! TA_STREAM_LEGACY + +#ifndef TA_STREAM_PER_THREAD +#define TA_STREAM_PER_THREAD ((TAstream)0x02) +#endif //! TA_STREAM_PER_THREAD + +/** + * If set, host memory is page locked. + * Flag for ::taMemHostAlloc() + */ +#define TA_MEMHOSTALLOC_DEFAULT 0x00 + +/** + * If set, host memory is portable between TANG contexts. + * Flag for ::taMemHostAlloc() + */ +#define TA_MEMHOSTALLOC_PORTABLE 0x01 + +/** + * If set, host memory is mapped into TANG address space and + * ::taMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::taMemHostAlloc() + */ +#define TA_MEMHOSTALLOC_DEVICEMAP 0x02 + +/** + * If set, host memory is allocated as write-combined - fast to write, + * faster to DMA, slow to read except via SSE4 streaming load instruction + * (MOVNTDQA). + * Flag for ::taMemHostAlloc() + */ +#define TA_MEMHOSTALLOC_WRITECOMBINED 0x04 + +/** + * If set allocate memory from device side and map it to the user space. + * + */ +#define TA_MEMHOSTALLOC_MAP_DEVICE_MEMORY 0x100 + +/** + * If set, host memory is page locked. + * Flag for ::taMemHostRegister() + */ +#define TA_MEMHOSTREGISTER_DEFAULT 0x00 + +/** + * If set, host memory is portable between TANG contexts. + * Flag for ::taMemHostRegister() + */ +#define TA_MEMHOSTREGISTER_PORTABLE 0x01 + +/** + * If set, host memory is mapped into TANG address space and + * ::taMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::taMemHostRegister() + */ +#define TA_MEMHOSTREGISTER_DEVICEMAP 0x02 + +/** + * If set, the passed memory pointer is treated as pointing to some + * memory-mapped I/O space, e.g. belonging to a third-party PCIe device. + * Flag for ::taMemHostRegister() + */ +#define TA_MEMHOSTREGISTER_IOMEMORY 0x04 + +/** + * If set, the passed memory pointer is treated as pointing to memory that is + * considered read-only by the device. + * Flag for ::taMemHostRegister() + */ +#define TA_MEMHOSTREGISTER_READ_ONLY 0x08 + +/** + * @ingroup PT_ERROR error handling + * @{ + ************************************************/ +/** + * @brief Driver API error codes. + ************************************************/ +typedef enum TAresult_e { + /** + * @brief The API call returned with no errors. + * @note For asynchronous operations, \c TANG_SUCCESS + * just means the operation is ququed on the \c stream + * successfully. + */ + TANG_SUCCESS = 0, + + /** + * @brief This indicates one or more invalid parameters + * are passed to the API call. + */ + TANG_ERROR_INVALID_VALUE = 1, + + /** + * @brief This indicates the API call failed because + * it can not allocate enough memory to perform the requested + * operation. + */ + TANG_ERROR_OUT_OF_MEMORY = 2, + + /** + * @brief This indicates that the PT dirver has not been initialized + * with ::__taInit or thar initialization has failed. + */ + TANG_ERROR_NOT_INITIALIZED = 3, + + /** + * @brief This indicates that the PT driver is int the process of shutting + * down. + */ + TANG_ERROR_DEINITIALIZED = 4, + + //!< The device is remove for some reason. + //!< echo "1" > /sys/../remove + TANG_ERROR_DEVICE_REMOVED = 5, + + //!< The device is reseted. + //!< Example: enable or disable SR-IOV + TANG_ERROR_DEVICE_RESET = 6, + + //!< The operation is not allowed. + TANG_ERROR_NOT_PERMITTED = 7, + + //!< No such file or directroy + TANG_ERROR_NO_SUCH_FILE = 8, + + /** + * This indicates that a kernel launch is requesting resources that can + * never be satisfied by the current device. + */ + TANG_ERROR_INVALID_CONFIGURATION = 9, + + //!< Null pointer is passed as argument but it is not allowed. + TANG_ERROR_NULL_POINTER = 10, + + //!< The kernel mode driver is not compatible with current runtime. + TANG_ERROR_INCOMPATIBLE_DRIVER = 11, + + //!< Can allocate enough resources to perform the requested operation. + TANG_ERROR_OUT_OF_RESOURCES = 12, + + TANG_ERROR_TIMEOUT = 13, + + /** + * @brief This indicates the API call is not implemented + * and just a stub or for the given parameter(s) the function + * has not been implemented yet. + */ + TANG_ERROR_NOT_IMPLEMENTED = 99, + + /** + * @brief No available PT devices. + */ + TANG_ERROR_NO_DEVICE = 100, + + /** + * @brief Invalid device. + */ + TANG_ERROR_INVALID_DEVICE = 101, + + //!< Bad file descriptor. + TANG_ERROR_BAD_FD = 102, + + //!< Normal indicate some invariant are broken + TANG_ERROR_UNREACHABLE_CODE = 103, + + //!< More than one function use the same symbol name. + TANG_ERROR_DUPLICATE_FUNC_NAME = 198, + + //!< More than one global value use the same symbol name. + TANG_ERROR_DUPLICATE_VAR_NAME = 199, + + /** + * @brief + */ + TANG_ERROR_INVALID_IMAGE = 200, + + /** + * @brief This most frequently indicates there is + * no context bound to the current thread. + * This error code is also returned when an invalid + * context is passed to API call. + */ + TANG_ERROR_INVALID_CONTEXT = 201, + + /** + * @brief No context is bound to the calling thread. + */ + TANG_ERROR_NO_CONTEXT_BOUND = 202, + + /** + * @brief Invalid host address encountered. + */ + TANG_ERROR_ILLEGAL_HOST_ADDRESS = 203, + + //!< Context mismatch + TANG_ERROR_CONTEXT_MISMATCH = 204, + + /** + * This indicates that the ::TAlimit passed to the API call is not + * supported by the active device. + */ + TANG_ERROR_UNSUPPORTED_LIMIT = 215, + + /** + * @brief The key is not found! + * + */ + TANG_ERROR_NOT_FOUND = 301, + + // This indicates that a resource required by the API call + // is not in a valid state to perform the requested operation. + TANG_ERROR_ILLEGAL_STATE = 302, + + // This error indicates that the operation is not permitted + // then the stream is capturing. + TANG_ERROR_STREAM_CAPTURE_UNSUPPORTED = 303, + + // This error indicates that the current capture + // sequence on the stream has been invalidated + // due a previous error. + TANG_ERROR_STREAM_CAPTURE_INVALIDATED = 304, + + // This error indicates that the operation whould + // have resulted in a merge of two independent capture + // sequences. + TANG_ERROR_STREAM_CAPTURE_MERGE = 305, + + // This error indicates that the capture was not initiated in this stream. + TANG_ERROR_STREAM_CAPTURE_UNMATCHED = 306, + + // This error indicates that the capture sequence contains a fork that was + // not joined to the primary stream. + TANG_ERROR_STREAM_CAPTURE_UNJOINED = 307, + + // This error indicates that a dependency would have been created which + // crossed the capture sequence boundary. Only implicit in-stream ordering + // dependencies are allowed to cross the boundary. + TANG_ERROR_STREAM_CAPTURE_ISOLATION = 308, + + // This error indicates a disallowed implicit dependency on a current + // capture sequence from TA_STREAM_LEGACY. + TANG_ERROR_STREAM_CAPTURE_IMPLICIT = 309, + + // A stream capture sequence not initiated with the + // ::TA_STREAM_CAPTURE_MODE_RELAXED argument to taStreamBeginCapture was + // passed to ::cuStreamEndCapture in a different thread. + TANG_ERROR_STREAM_CAPTURE_WRONG_THREAD = 310, + + // This error indicates that the operation is not permitted on an event + // which was last recorded in a capturing stream. + TANG_ERROR_CAPTURED_EVENT = 311, + + /** + * @brief This indicates an invalid resource handle + * passed to a API call. + * In general, resource handles are opaque type like + * ::TAstream and ::TAcontext. + */ + TANG_ERROR_INVALID_HANDLE = 400, + + /** + * @brief This error code indicates asynchronous operations issued previously + * have not been completed yet. + */ + TANG_ERROR_NOT_READY = 600, + + /** + * @brief A load or store instruction on an invalid + * memory address occured when the device executing a + * kernel. + * This error makes the process is an inconsitant state. + * The process should be terminated and relanuched. + */ + TANG_ERROR_ILLEGAL_ADDRESS = 700, + + /** + * @brief resouce is not enougn for the kernel + * + */ + TANG_ERROR_LAUNCH_OUT_OF_RESOURCES = 701, + + /** + * @brief lanch kernel timeout + * + */ + TANG_ERROR_LAUNCH_TIMEOUT = 702, + + TANG_ERROR_PEER_ACCESS_ALREADY_ENABLED = 704, + TANG_ERROR_PEER_ACCESS_NOT_ENABLED = 705, + + /** + * @brief The premary for a context has been + * initialized. + */ + TANG_ERROR_PRIMARY_CONTEXT_ACTIVE = 708, + + /** + * @brief + * + */ + TANG_ERROR_NOT_SUPPORTED = 801, + + /** + * @brief This indicates that an unknown internal error has occurred. + */ + TANG_ERROR_UNKNOWN = 999, + + /** + * @brief context is destroyed or in destroying in kernel + * + */ + TANG_ERROR_CONTEXT_IS_DESTROYED = 3000, + + /** + * @brief context is not valid in kernel + * + */ + TANG_ERROR_CONTEXT_INVALID = 3001, + + /** + * @brief stream is destroyed or in destroying in kernel + * + */ + TANG_ERROR_STREAM_IS_DESTROYED = 3002, + + /** + * @brief stream is not valid in kernel + * + */ + TANG_ERROR_STREAM_INVALID = 3003, + + /** + * @brief event is destroyed or in destroying in kernel + * + */ + TANG_ERROR_EVENT_IS_DESTROYED = 3004, + + /** + * @brief event is not valid in kernel + * + */ + TANG_ERROR_EVENT_INVALID = 3005, + + /** + * @brief device memory is not enough for current operation + * + */ + TANG_ERROR_DEVICE_OUT_OF_MEMORY = 3006, + + /** + * @brief device memory is not found + * + */ + TANG_ERROR_DEVICE_MEMORY_NOT_FOUND = 3007, + + /** + * @brief pcie fatal error occured + * + */ + TANG_ERROR_PCIE_FATAL = 3012, + + /** + * @brief pcie non-fatal unrecovered error occured + * + */ + TANG_ERROR_PCIE_NON_FATAL_UNRECOVERED = 3013, + + /** + * @brief no more event exist + * + */ + TANG_ERROR_SCP_EVENT_NOT_EXIST = 3014, + + /** + * @brief record event failed + * + */ + TANG_ERROR_SCP_EVENT_RECORD_FAILED = 3015, + + /** + * @brief scp packet crc check failed + * + */ + TANG_ERROR_SCP_PACKET_CRC_FAILED = 3016, + + /** + * @brief scp dispatch send failed + * + */ + TANG_ERROR_SCP_DISP_SEND_FAILED = 3017, + + /** + * @brief sq write sequence error + * + */ + TANG_ERROR_SCP_SQ_WRITE_INVALID = 3018, + + /** + * @brief udrc pcie xdma packet invalid + * + */ + TANG_ERROR_UDRC_PCIE_DMA_CMD_PACKET_INVALID = 3019, + + /** + * @brief udrc mp dma packet invalid + * + */ + TANG_ERROR_UDRC_MP_DMA_CMD_PACKET_INVALID = 3020, + + /** + * @brief udrc reg packet invalid + * + */ + TANG_ERROR_UDRC_REG_CMD_PACKET_INVALID = 3021, + + /** + * @brief udrc reg access invalid + * + */ + TANG_ERROR_UDRC_REG_ACCESS_INVALID = 3022, + + /** + * @brief aiss cluster is not configured + * + */ + TANG_ERROR_AISS_VF_CTRL_CLUST_USR_NOT_ALLOCATED = 3023, + + /** + * @brief barrier is destroyed or in destroying in kernel + * + */ + TANG_ERROR_BARRIER_IS_DESTROYED = 3024, + + /** + * @brief barrier is not valid in kernel + * + */ + TANG_ERROR_BARRIER_INVALID = 3025, + + /** + * @brief one obj is destroyed or in destroying in kernel + * + */ + TANG_ERROR_IS_DESTROYED = 3026, + + /** + * @brief xdma C2H align mismath + * + */ + TANG_ERROR_XDMA_C2H_ALIGN_MISMATCH = 3300, + + /** + * @brief xdma C2H invalid magic stopped + * + */ + TANG_ERROR_XDMA_C2H_INVALID_MAGIC_STOPPED = 3301, + + /** + * @brief xdma C2H invalid Len + * + */ + TANG_ERROR_XDMA_C2H_INVALID_LEN = 3302, + + /** + * @brief xdma C2H decode error + * + */ + TANG_ERROR_XDMA_C2H_DECODE = 3303, + + /** + * @brief xdma C2H slave + * + */ + TANG_ERROR_XDMA_C2H_SLAVE = 3304, + + /** + * @brief xdma C2H descriptor unsupport request + * + */ + TANG_ERROR_XDMA_C2H_DESC_UNSUPPORT_REQUEST = 3305, + + /** + * @brief xdma C2H descriptor completer abort + * + */ + TANG_ERROR_XDMA_C2H_DESC_COMPLETER_ABORT = 3306, + + /** + * @brief xdma C2H descriptor parity + * + */ + TANG_ERROR_XDMA_C2H_DESC_PARITY = 3307, + + /** + * @brief xdma C2H descriptor header ep + * + */ + TANG_ERROR_XDMA_C2H_DESC_HEADER_EP = 3308, + + /** + * @brief xdma C2H descriptor unexpected comp + * + */ + TANG_ERROR_XDMA_C2H_DESC_UNEXPECTED_COMP = 3309, + + /** + * @brief xdma C2H timeout + * + */ + TANG_ERROR_XDMA_C2H_TIMEOUT = 3310, + + /** + * @brief xdma C2H unknown + * + */ + TANG_ERROR_XDMA_C2H_UNKNOWN = 3311, + + /** + * @brief xdma H2C align mismatch + * + */ + TANG_ERROR_XDMA_H2C_ALIGN_MISMATCH = 3350, + + /** + * @brief xdma H2C invalid magic stopped + * + */ + TANG_ERROR_XDMA_H2C_INVALID_MAGIC_STOPPED = 3351, + + /** + * @brief xdma H2C invalid len + * + */ + TANG_ERROR_XDMA_H2C_INVALID_LEN = 3352, + + /** + * @brief xdma H2C read unsupport request + * + */ + TANG_ERROR_XDMA_H2C_READ_UNSUPPORT_REQUEST = 3353, + + /** + * @brief xdma H2C read completer abort + * + */ + TANG_ERROR_XDMA_H2C_READ_COMPLETER_ABORT = 3354, + + /** + * @brief xdma H2C read parity + * + */ + TANG_ERROR_XDMA_H2C_READ_PARITY = 3355, + + /** + * @brief xdma H2C read header ep + * + */ + TANG_ERROR_XDMA_H2C_READ_HEADER_EP = 3356, + + /** + * @brief xdma H2C read unexpected comp + * + */ + TANG_ERROR_XDMA_H2C_READ_UNEXPECTED_COMP = 3357, + + /** + * @brief xdma H2C decode error + * + */ + TANG_ERROR_XDMA_H2C_DECODE = 3358, + + /** + * @brief xdma H2C slave + * + */ + TANG_ERROR_XDMA_H2C_SLAVE = 3359, + + /** + * @brief xdma H2C descriptor unsupport request + * + */ + TANG_ERROR_XDMA_H2C_DESC_UNSUPPORT_REQUEST = 3360, + + /** + * @brief xdma H2C descriptor completer abort + * + */ + TANG_ERROR_XDMA_H2C_DESC_COMPLETER_ABORT = 3361, + + /** + * @brief xdma H2C descriptor parity + * + */ + TANG_ERROR_XDMA_H2C_DESC_PARITY = 3362, + + /** + * @brief xdma H2C descriptor header ep + * + */ + TANG_ERROR_XDMA_H2C_DESC_HEADER_EP = 3363, + + /** + * @brief xdma H2C descriptor unexpected com + * + */ + TANG_ERROR_XDMA_H2C_DESC_UNEXPECTED_COMP = 3364, + + /** + * @brief xdma H2C timeout + * + */ + TANG_ERROR_XDMA_H2C_TIMEOUT = 3365, + + /** + * @brief xdma H2C unknown + * + */ + TANG_ERROR_XDMA_H2C_UNKNOWN = 3366, + /** + * @brief gpu profling share mem out of size + * + */ + TANG_ERROR_GPU_PROFLING_SHAMEM_OUT_OF_SIZE = 3367, + + /** + * @brief The requested ipc mem is destroied. + * @sa taIpcOpenMemHandle + */ + TANG_ERROR_IPC_MEM_DESTROIED = 3368, +} TAresult; + +/** + * @brief Get the TANG SDK Runtime version. + * + * @param runtimeVersion - Returned runtime version number. + * @return int + * ::TANG_SUCCESS - \p runtimeVersion is a non-null poiner. + * ::TANG_ERROR_INVALID_VALUE - \p runtimeVersion is a null pointer. + * @deleted Do not use this function. + ******************************************************/ +TAresult TA_API taRuntimeGetVersion(int* runtimeVersion); + +/** + * @brief Get the TANG SDK Driver version. + * + * @param driverVersion - Returned driver version number. + * @return int + * ::TANG_SUCCESS - \p driverVersion is a non-null poiner. + * ::TANG_ERROR_INVALID_VALUE - \p driverVersion is a null pointer. + ******************************************************/ +TAresult TA_API taDriverGetVersion(int* driverVersion); + +/** + * @brief Get the kernel mode driver version. + * + * @param kernelDriverVersion + * @return int + */ +TAresult TA_API taKernelDriverGetVersion(int* kernelDriverVersion); + +/** + * @brief Get error description. + * + * @param error + * @param ppstr - Returned Null-terminated string. + * @return int + * ::TANG_SUCCESS - \p error is a valid error code. + * ::TANG_ERROR_INVALID_VALUE - \p error is an invalid + * error code. + ******************************************************/ +TAresult TA_API taGetErrorString(TAresult error, char const** ppstr); + +/** + * @brief Get the string representation of an error code. + * + * @param error + * @param ppstr - Returned Null-terminated string. + * @return int + * ::TANG_SUCCESS - \p error is a valid error code. + * ::TANG_ERROR_INVALID_VALUE - \p error is an invalid + * error code. + ******************************************************/ +TAresult TA_API taGetErrorName(TAresult error, char const** ppstr); +/** @} PT_ERROR */ + +/** + * @brief Initilaize driver module. + * + * This function initialize driver module. + * @param flags Initialization flags for driver API + * @return + * ::TANG_SUCCESS + *******************************************************/ +TAresult TA_API taInit(unsigned int flags); +// TAresult TA_API taDeinit(void); + +/** + * @brief __taDeviceGet + * Get a handle to a compute device + * @param device Pointer to a device handle. + * @param ordinal The device number to get handle for + * @return int + ********************************************************/ +TAresult TA_API taDeviceGet(TAdevice* device, int ordinal); + +/** + * @brief Wait for all work completed + * + * @param device + * @return int + ********************************************************/ +TAresult TA_API taDeviceSynchronize(TAdevice device); + +/** + * @brief Synchronize with the current device of the calling thread. + * + * @warning Not a public API, may change in the future. + * @return int + * ::TANG_SUCCESS + */ +TAresult TA_API __taDeviceSynchronizeCurrent(void); + +/** + * @brief Reset the current device. + * Only the calling process is impacted. + * @return int + * @warning This is a dangerous API. The caller must + * ensure that all resources allocated from the current + * device will not be used again. The most difficult to + * handle is TAcontext which may be pushed onto thread's context + * stack. + * BE CAREFUL. + **********************************************************/ +TAresult TA_API taDeviceReset(void); + +/** + * @brief Returns a handle to a compute device. + * + * @param device - device handle + * @param pciBusId - PCI Bus ID + * @return int + * ::TANG_SUCCESS - \p error is a valid error code. + * ::TANG_ERROR_INVALID_DEVICE - \p error is an invalid + * error code. + * ::TANG_ERROR_INVALID_VALUE - \p error is an invalid + * error code. + */ +TAresult TA_API taDeviceGetByPCIBusId(TAdevice* device, const char* pciBusId); + +/** + * @brief Returns a PCI Bus Id string for the device. + * + * @param pciBusId - PCI Bus ID + * @param len - Maximum length of pciBusId name string + * @param device - device handle + * @return int + * ::TANG_SUCCESS - \p error is a valid error code. + * ::TANG_ERROR_INVALID_DEVICE - \p error is an invalid + * error code. + * ::TANG_ERROR_INVALID_VALUE - \p error is an invalid + * error code. + */ +TAresult TA_API taDeviceGetPCIBusId(char* pciBusId, int len, TAdevice device); + +/** + * Limits + */ +typedef enum TAlimit_enum { + TA_LIMIT_STACK_SIZE = 0x00, /**< GPU thread stack size */ + TA_LIMIT_PRINTF_FIFO_SIZE = 0x01, /**< GPU printf FIFO size */ + TA_LIMIT_MALLOC_HEAP_SIZE = 0x02, /**< GPU malloc heap size */ + TA_LIMIT_DEV_RUNTIME_SYNC_DEPTH = + 0x03, /**< GPU device runtime launch synchronize depth */ + TA_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT = + 0x04, /**< GPU device runtime pending launch count */ + TA_LIMIT_MAX_L2_FETCH_GRANULARITY = + 0x05, /**< A value between 0 and 128 that indicates the maximum fetch + granularity of L2 (in Bytes). This is a hint */ + TA_LIMIT_MAX +} TAlimit; + +/** + * @brief Get Resource limits of current context + * + * @param [out] pValue + * @param [in] limit + * @return int + * ::TANG_SUCCESS - \p error is a valid error code. + * ::TANG_ERROR_INVALID_VALUE - \p error is an invalid + * error code. + * ::TANG_ERROR_UNSUPPORTED_LIMIT - \p error is an invalid + * error code. + * + */ +TAresult TA_API taCtxGetLimit(size_t* pValue, TAlimit limit); +TAresult TA_API taCtxSetLimit(TAlimit limit, size_t value); + +/** + * @brief Query detail limit information. + * Not a plublic interface. + * + * @param context + * @param limit + * @param pCurrent The current value of the \p limit + * @param pLimit The limit of the \p limit + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_INVALID_VALUE + */ +TAresult TA_API __taCtxQueryLimit(TAcontext context, + TAlimit limit, + size_t* pCurrent, + size_t* pLimit); + +/** + * TA device attributes + */ +typedef enum taDeviceAttr { + TA_DEV_ATTR_SHARED_MEM_PER_BLOCK = 0, //!< sharedMemPerBlock + TA_DEV_ATTR_REGS_PER_BLOCK, //!< regsPerBlock + TA_DEV_ATTR_WARP_SIZE, //!< warpSize + TA_DEV_ATTR_MEM_PITCH, //!< memPitch + TA_DEV_ATTR_MAX_THREADS_PER_BLOCK, //!< maxThreadsPerBlock + TA_DEV_ATTR_MAX_THREADS_DIM_X, //!< maxThreadsDimX + TA_DEV_ATTR_MAX_THREADS_DIM_Y, //!< maxThreadsDimY + TA_DEV_ATTR_MAX_THREADS_DIM_Z, //!< maxThreadsDimZ + TA_DEV_ATTR_MAX_GRID_SIZE_X, //!< maxGridSizeX + TA_DEV_ATTR_MAX_GRID_SIZE_Y, //!< maxGridSizeY + TA_DEV_ATTR_MAX_GRID_SIZE_Z, //!< maxGridSizeZ + TA_DEV_ATTR_CLOCK_RATE, //!< clockRate + TA_DEV_ATTR_TOTAL_CONST_MEM, //!< totalConstMem + TA_DEV_ATTR_MULTIPROCESSOR_COUNT, //!< multiProcessorCount + TA_DEV_ATTR_MAX_BLOCKS_PER_MULTIPROCESSOR, //!< maxBlocksPerMultiProcessor + TA_DEV_ATTR_ASYNC_ENGINE_COUNT, //!< asyncEngineCount + TA_DEV_ATTR_MEMORY_CLOCK_RATE, //!< memoryClockRate + TA_DEV_ATTR_MEMORY_BUS_WIDTH, //!< memoryBusWidth + TA_DEV_ATTR_L2_CACHE_SIZE, //!< l2CacheSize + TA_DEV_ATTR_MAX_THREADS_PER_MULTIPROCESSOR, //!< maxThreadsPerMultiProcessor + TA_DEV_ATTR_GLOBAL_L1_CACHE_SUPPORTED, //!< globalL1CacheSupported + TA_DEV_ATTR_LOCAL_L1_CACHE_SUPPORTED, //!< localL1CacheSupported + TA_DEV_ATTR_SHARED_MEM_PER_MULTIPROCESSOR, //!< sharedMemPerMultiprocessor + TA_DEV_ATTR_REGS_PER_MULTIPROCESSOR, //!< regsPerMultiprocessor + TA_DEV_ATTR_STREAM_PRIORITIES_SUPPORTED, //!< streamPrioritiesSupported + TA_DEV_ATTR_CONCURRENT_KERNELS, //!< concurrentKernels + TA_DEV_ATTR_COMPUTE_PREEMPTION_SUPPORTED, //!< computePreemptionSupported + TA_DEV_ATTR_KERNEL_EXEC_TIMEOUT_ENABLED, //!< kernelExecTimeoutEnabled + TA_DEV_ATTR_ECC_ENABLED, //!< ECCEnabled + TA_DEV_ATTR_ACCESS_POLICY_MAX_WINDOW_SIZE, //!< accessPolicyMaxWindowSize + TA_DEV_ATTR_TCC_DRIVER, //!< tccDriver + TA_DEV_ATTR_SINGLE_TO_DOUBLE_PRECISION_PER_RATIO, //!< singleToDoublePrecisionPerfRatio + TA_DEV_ATTR_COOPERATIVE_LAUNCH, //!< cooperativeLaunch + TA_DEV_ATTR_COOPERATIVE_MULTI_DEVICE_LAUNCH, //!< cooperativeMultiDeviceLaunch + TA_DEV_ATTR_PERSISTING_L2_CACHE_MAX_SIZE, //!< persistingL2CacheMaxSize + TA_DEV_ATTR_CAN_MAP_HOST_MEMORY, //!< canMapHostMemory + TA_DEV_ATTR_UNIFIED_ADDRESSING, //!< unifiedAddressing + TA_DEV_ATTR_MANAGED_MEMORY, //!< managedMemory + TA_DEV_ATTR_CONCURRENT_MANAGED_ACCESS, //!< concurrentManagedAccess + TA_DEV_ATTR_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST, //!< directManagedMemAccessFromHost + TA_DEV_ATTR_PAGEABLE_MEMORY_ACCESS, //!< pageableMemoryAccess + TA_DEV_ATTR_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, //!< pageableMemoryAccessUsesHostPageTables + TA_DEV_ATTR_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, //!< canUseHostPointerForRegisteredMem + TA_DEV_ATTR_HOST_NATIVE_ATOMIC_SUPPORTED, //!< hostNativeAtomicSupported + TA_DEV_ATTR_CAN_FLUSH_REMOTE_WRITES, //!< canFlushRemoteWrites + TA_DEV_ATTR_GPU_OVERLAP, //!< gpuOverlap + TA_DEV_ATTR_INTEGRATED, //!< integrated + TA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, //!< maxSharedMemoryPerBlockOptin + TA_DEV_ATTR_GPU_DIRECT_RDMA_SUPPORTED, //!< gpuDirectRDMASupported + TA_DEV_ATTR_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS, //!< gpuDirectRDMAFlushWritesOptions + TA_DEV_ATTR_GPU_DIRECT_RDMA_WRITES_ORDERING, //!< gpuDirectRDMAWritesOrdering + TA_DEV_ATTR_MAJOR, //!< major + TA_DEV_ATTR_MINOR, //!< minor + TA_DEV_ATTR_PCI_BUS_ID, //!< pciBusID + TA_DEV_ATTR_PCI_DEVICE_ID, //!< pciDeviceID + TA_DEV_ATTR_PCI_DOMAIN_ID, //!< pciDomainID + TA_DEV_ATTR_IS_MULTI_GPU_BOARD, //!< isMultiGpuBoard + TA_DEV_ATTR_GPU_BOARD_GROUP_ID, //!< multiGpuBoardGroupID + TA_DEV_ATTR_COMPUTE_MODE, //!< computeMode + TA_DEV_ATTR_RESERVED_SHARED_MEMORY_PER_BLOCK, //!< reservedSharedMemoryPerBlock + TA_DEV_ATTR_SPARSE_TANG_ARRAY_SUPPORTED, //!< sparseTangArraySupported + TA_DEV_ATTR_HOST_REGISTER_SUPPORTED, //!< hostRegisterSupported + TA_DEV_ATTR_HOST_REGISTER_READ_ONLY_SUPPORTED, //!< hostRegisterReadOnlySupported + TA_DEV_ATTR_MEMORY_POOLS_SUPPORTED, //!< memoryPoolsSupported + TA_DEV_ATTR_MEMORY_POOL_SUPPORTED_HANDLE_TYPES, //!< memoryPoolSupportedHandleTypes + TA_DEV_ATTR_MAX +} taDeviceAttr; + +TAresult TA_API taDeviceGetAttribute(int* value, taDeviceAttr attr, TAdevice dev); + +TAresult TA_API taDeviceGetName(char* name, int len, TAdevice dev); +TAresult TA_API taDeviceGetUuid(char* uuid, TAdevice dev); +TAresult TA_API taDeviceTotalMem(size_t* bytes, TAdevice dev); + +/** + * @brief Get the number of available devices. + * + * @param count Returns the count of available devices. + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_NOT_INITIALIZED + ********************************************************/ +TAresult TA_API taDeviceGetCount(int* count); + +/** + * @brief Gets free and total memory of the current device. + * + * @param free - Returned free memory in bytes + * @param total - Returned total memory in bytes + * @return int + */ +TAresult TA_API taMemGetInfo(size_t* free, size_t* total); + +/** + * @brief Gets free and total memory of \p device. + * + * @param device + * @param free + * @param total + * @return int + */ +TAresult TA_API taDeviceMemGetInfo(TAdevice device, size_t* free, size_t* total); + +/** + * @brief Create a context. + * + * @param pctx - Returned newly created context. + * @param flags - Flags for creating the context. + * @param dev - The device ID. + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_OUT_OF_MEMORY + * @note When the context is no longer used, the caller + * should call \c taCtxDestroy to destroy the context. + * @sa taCtxDestroy + ********************************************************/ +TAresult TA_API taCtxCreate(TAcontext* pctx, unsigned int flags, TAdevice dev); + +/** + * @brief Destroy the context \p ctx. + * + * @param ctx + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_INVALID_CONTEXT + * @note It's the caller's responsibility ensure that the + * context \p ctx is no longer referenced by other objects. + * @taCtxCreate + ********************************************************/ +TAresult TA_API taCtxDestroy(TAcontext ctx); + +/** + * @brief Bind the context \p ctx to the calling thread. + * + * @param ctx - The context to be bound to the calling thread. + * @return int + * @note If \p ctx is NULL, this function just POPs the context + * stack of the calling thread if the context stack is not + * empty. + * If \p ctx is no NULL, this function replace the top of + * the stack if the context stack is no empty. + *********************************************************/ +TAresult TA_API taCtxSetCurrent(TAcontext ctx); + +/** + * @brief Get the current context bound to the calling + * thread. + * + * @param pctx - Returned context bound to the calling thread. + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_NO_CONTEXT_BOUND + ********************************************************/ +TAresult TA_API taCtxGetCurrent(TAcontext* pctx); + +/** + * @brief Query the current context bound to the calling + * thread. + * + * @param pCtx + * @return int + */ +TAresult TA_API taCtxQueryCurrent(TAcontext* pCtx); + +/** + * @brief Get the device ID of the current context bound + * to the call thread. + * + * @param dev - Returned device ID of the current context. + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_NO_CONTEXT_BOUND + ********************************************************/ +TAresult TA_API taCtxGetCurrentDevice(TAdevice* dev); + +TAresult TA_API taCtxGetDevice(TAcontext ctx, TAdevice* dev); +TAresult TA_API taCtxGetOrdinal(TAcontext ctx, int* ordinal); + +TAresult TA_API taCtxGetFunction(TAcontext ctx, + const void* hostFunc, + TAfunction* phFunc); + +TAresult TA_API taCtxRegisterFunction(TAcontext ctx, + const void* hostFunc, + TAfunction hFunc); + +TAresult TA_API taCtxGetBuiltInFunction(TAcontext ctx, + const char* funcName, + TAfunction* phFunc); + +TAresult TA_API taCtxRegisterBuiltInFunction(TAcontext ctx, + const char* funcName, + TAfunction hFunc); + +TAresult TA_API taCtxGetVariable(TAcontext ctx, + const void* hostVar, + TAvariable* hVar); + +TAresult TA_API taCtxRegisterVariable(TAcontext ctx, + const void* hostVar, + TAvariable hVar); + +TAresult TA_API taCtxPushCurrent(TAcontext ctx); +TAresult TA_API taCtxPopCurrent(TAcontext* ctx); +TAresult TA_API taCtxSynchronize(TAcontext ctx); + +/** + * @ingroup Primary Context Management. + * @{ + */ +/** + * @brief Retain the primary context of \p dev. + * + * @param pctx Pointer to receive the primary context handle. + * @param dev + * @return int + ********************************************************/ +TAresult TA_API taDevicePrimaryCtxRetain(TAcontext* pctx, TAdevice dev); + +/** + * @brief Release the primary context of \p dev. + * + * @param dev + * @return int + ********************************************************/ +TAresult TA_API taDevicePrimaryCtxRelease(TAdevice dev); + +/** + * @brief Reset the primary context of \p dev. + * + * @param dev + * @return int + ********************************************************/ +TAresult TA_API taDevicePrimaryCtxReset(TAdevice dev); + +/** + * @brief Get the state of primary context of device \p dev. + * + * @param dev Device to get primary context's state for. + * @param flags Pointer to receive the flags. + * @param active + * @return int + **********************************************************/ +TAresult TA_API taDevicePrimaryCtxGetState(TAdevice dev, + unsigned int* flags, + int* active); + +/** + * @brief Set flags for the primary context of the device \p dev. + * + * @param dev + * @param flags + * @return int + */ +TAresult TA_API taDevicePrimaryCtxSetFlags(TAdevice dev, unsigned int flags); +/** }@ */ + +TAresult TA_API taDeviceGetOrdinal(TAdevice dev, int* ordinal); + +/** + * @brief Allocate a block of memory in the device. + * + * @param dptr Receives the allocated device memory block handle. + * @param size + * @return int + * On success, zero is returned. + * @sa taMemFree + ***********************************************************/ +TAresult TA_API taMemAlloc(TAdeviceptr* dptr, size_t size); +TAresult TA_API taMemAlloc(TAdeviceptr* dptr, size_t size); +TAresult TA_API taMemAllocAsync(TAdeviceptr* dptr, size_t size, TAstream stream); +TAresult TA_API taMemAllocAsync_ptsz(TAdeviceptr* dptr, + size_t size, + TAstream stream); + +/** + * @brief Free a device memory block. + * + * @param dptr Device memory block handle. + * @return int + * On success, zero is returned. + * @sa taMemAlloc + ************************************************************/ +TAresult TA_API taMemFree(TAdeviceptr dptr); +TAresult TA_API taMemFreeAsync(TAdeviceptr dptr, TAstream hStream); +TAresult TA_API taMemFreeAsync_ptsz(TAdeviceptr dptr, TAstream hStream); + +/** + * @brief Allocate page locked host memory. + * + * @param hptr Pointer to the allocated page locked host memory + * @param sizeBytes Requested memory size + * @return int + * On success, zero is returned. + * @sa taMemFreeHost + ***********************************************************/ +TAresult TA_API taMemAllocHost(void** hptr, size_t sizeBytes); + +/** + * @brief Allocate page locked host memory. + * + * @param hptr Pointer to the allocated page locked host memory + * @param sizeBytes Requested memory size + * @param flags See below. + * flags: + * - #TA_MEMHOSTALLOC_PORTABLE Memory is considered registered by all + *contexts. + * - #TA_MEMHOSTALLOC_DEVICEMAP Map the allocation into the address space + *for the current device. + * - #TA_MEMHOSTALLOC_WRITECOMBINED Allocates the memory as write-combined (WC). + * TANG does not support IOMMU on device side, so flags of + *TA_MEMHOSTALLOC_DEVICEMAP and TA_MEMHOSTALLOC_WRITECOMBINED will always return + *false. + * @return int + * On success, zero is returned. + * @sa taMemFreeHost + ***********************************************************/ +TAresult TA_API taMemHostAlloc(void** hptr, size_t sizeBytes, unsigned int flags); +TAresult TA_API taMemHostGetDevicePointer(TAdeviceptr* pdptr, + void* pHost, + unsigned int flags); + +/** + * @brief Get the flags that were used for allocation. + * + * @param pFlags + * @param p + * @return int + */ +TAresult TA_API taMemHostGetFlags(unsigned int* pFlags, void* p); + +/** + * @brief Free page locked host memory. + * + * @param hptr Pointer to memory to be freed + * @return int + * On success, zero is returned. + * @sa taMemAllocHost, taMemHostAlloc + ************************************************************/ +TAresult TA_API taMemFreeHost(void* hptr); + +/** + * @brief Register host memory as page locked memory. + * + * @param hptr Pointer to host memory to be registered. + * @param sizeBytes Requested memory size + * @param flags See below. + * flags: + * - #TA_MEMHOSTREGISTER_PORTABLE Memory is considered registered by all + *contexts. + * - #TA_MEMHOSTREGISTER_DEVICEMAP Map the allocation into the address space for + *the current device. + * - #TA_MEMHOSTREGISTER_IOMEMORY The passed memory pointer is treated as + *pointing to some memory-mapped I/O space. + * - #TA_MEMHOSTREGISTER_READ_ONLY The passed memory pointer is treated as + *pointing to memory that is considered read-only by the device. TANG does not + *support IOMMU on device side, so flags of TA_MEMHOSTREGISTER_DEVICEMAP and + *TA_MEMHOSTREGISTER_IOMEMORY and TA_MEMHOSTREGISTER_READ_ONLY will always + *return false. + * @return int + * On success, zero is returned. + * @sa taMemHostUnregister + ***********************************************************/ +TAresult TA_API taMemHostRegister(void* hptr, size_t sizeBytes, unsigned int flags); + +/** + * @brief Un-register host pointer + * + * @param hptr Host pointer previously registered + * @return int + * On success, zero is returned. + * @sa taMemHostRegister + ************************************************************/ +TAresult TA_API taMemHostUnregister(void* hptr); + +/** + * @brief Returns information about a pointer. + * + * @param data - Pointer to the returned attribute value. + * @param attr - Pointer attribute to query + * @param ptr - Pointer to be queried. + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_INVALID_VALUE + */ +TAresult TA_API taPointerGetAttribute(void* data, + TApointer_attribute attr, + TAdeviceptr ptr); + +TAresult TA_API taMemset(TAdeviceptr dptr, int value, size_t size); +TAresult TA_API taMemset_ptds(TAdeviceptr dptr, int value, size_t size); +TAresult TA_API taMemsetAsync(TAdeviceptr dptr, + int value, + size_t size, + TAstream stream); +TAresult TA_API taMemsetAsync_ptsz(TAdeviceptr dptr, + int value, + size_t size, + TAstream stream); + +/** + * @brief Copy data from host memory to host memory. + * + * @param dstHost - Host destination data address. + * @param srcHost - Host source data address. + * @param size - The size in bytes of data to be copied. + * @return int + * ::TANG_SUCCESS + ************************************************************/ +TAresult TA_API taMemcpyH2H(void* dstHost, const void* srcHost, size_t size); + +/** + * @brief Copy data from host memory to host memory. + * + * @param dstHost - Host destination data address. + * @param srcHost - Host source data address + * @param size The size in bytes of data to be copied. + * @return int + ************************************************************/ +TAresult TA_API taMemcpyH2H_ptds(void* dstHost, + const void* srcHost, + size_t size); + + +/** + * @brief Copy data from host memory to device memory. + * + * @param dstDevice - Device destination data address. + * @param srcHost - Host source data address. + * @param size - The size in bytes of data to be copied. + * @return int + * ::TANG_SUCCESS + ************************************************************/ +TAresult TA_API taMemcpyH2D(TAdeviceptr dstDevice, const void* srcHost, size_t size); + +/** + * @brief Copy data from host memory to device memory. + * + * @param dstDevice Device destination data address + * @param srcHost Host source data address + * @param size The size in bytes of data to be copied. + * @return int + ************************************************************/ +TAresult TA_API taMemcpyH2D_ptds(TAdeviceptr dstDevice, + const void* srcHost, + size_t size); + +/** + * @brief Copy data from device to host. + * + * @param dstHost + * @param srcDevice + * @param size + * @return int + * ::TANG_SUCCESS + *************************************************************/ +TAresult TA_API taMemcpyD2H(void* dstHost, TAdeviceptr srcDevice, size_t size); + +/** + * @brief + * + * @param dstHost + * @param srcDevice + * @param size + * @return int + *************************************************************/ +TAresult TA_API taMemcpyD2H_ptds(void* dstHost, TAdeviceptr srcDevice, size_t size); + +/** + * @brief + * + * @param dstDevice + * @param srcDevice + * @param size + * @return int + * ::TANG_ERROR_NOT_IMPLEMENTED. + *************************************************************/ +TAresult TA_API taMemcpyD2D(TAdeviceptr dstDevice, + TAdeviceptr srcDevice, + size_t size); + +/** + * @brief + * + * @param dstDevice + * @param srcDevice + * @param size + * @return int + *************************************************************/ +TAresult TA_API taMemcpyD2D_ptds(TAdeviceptr dstDevice, + TAdeviceptr srcDevice, + size_t size); + +TAresult TA_API taMemcpyH2HAsync(void* dstHost, + const void* srcHost, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyH2HAsync_ptsz(void* dstHost, + const void* srcHost, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyH2DAsync(TAdeviceptr dstDevice, + const void* srcHost, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyH2DAsync_ptsz(TAdeviceptr dstDevice, + const void* srcHost, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyD2HAsync(void* dstHost, + TAdeviceptr srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyD2HAsync_ptsz(void* dstHost, + TAdeviceptr srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyD2DAsync(TAdeviceptr dstDevice, + TAdeviceptr srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyD2DAsync_ptsz(TAdeviceptr dstDevice, + TAdeviceptr srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taStreamCreate(TAstream* pStream, unsigned int flags); +TAresult TA_API taStreamCreateWithPriority(TAstream* pstream, + unsigned int flags, + int priority); +TAresult TA_API taStreamGetPriority(TAstream hstream, int* priority); +TAresult TA_API taStreamGetPriority_ptsz(TAstream hstream, int* priority); +TAresult TA_API taStreamGetFlags(TAstream hstream, unsigned int* priority); +TAresult TA_API taStreamGetFlags_ptsz(TAstream hstream, unsigned int* priority); +TAresult TA_API taStreamGetId(TAstream hstream, int* pId); +TAresult TA_API taStreamGetId_ptsz(TAstream hstream, int* pId); +TAresult TA_API taStreamDestroy(TAstream hStream); + +/********************************************************* + ********************************************************/ +#ifndef TANGRT_DEVICE_P2P_ATTR_ENUM +#define TANGRT_DEVICE_P2P_ATTR_ENUM +/** + * TANG Device P2P attributes + * TODO: This is design bug fix this. Remove this + */ +typedef enum tangDeviceP2PAttr { + tangDevP2PAttrPerformanceRank = 1, + tangDevP2PAttrAccessSupported = 2, + tangDevP2PAttrNativeAtomicSupported = 3, + tangDevP2PAttrTangArrayAccessSupported = 4, +} tangDeviceP2PAttr; +#endif // TANGRT_DEVICE_P2P_ATTR_ENUM + +TAresult TA_API taStreamC2Ctransfers(TAstream hStream, + uint32_t* cmd, + uint32_t cmd_count, + uint64_t device_addr, + uint32_t mem_size); + +TAresult TA_API taDeviceGetP2PAttribute(int* value, + tangDeviceP2PAttr attr, + int srcDevice, + int dstDevice); +TAresult TA_API taDeviceGetPeerPointer(int device, + int port, + void* peerAddr, + void** accessAddr); +TAresult TA_API taDeviceCanAccessPeer(int* canAccessPeer, + int device, + int peerDevice); +TAresult TA_API taDeviceEnablePeerAccess(int peerDevice, unsigned int flags); +TAresult TA_API taDeviceDisablePeerAccess(int peerDevice); +TAresult TA_API taMemcpyPeer(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count); + +TAresult TA_API taMemcpyPeerAsync(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + TAstream stream); + +//!< @ingroup Memory between peers. +//!< Copy data by HDMA +//!< @{{{ +TAresult TA_API taMemcpyPeer_v2(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + TAdevice srcDevice, + size_t size); + +TAresult TA_API taMemcpyPeer_v2_ptds(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + TAdevice srcDevice, + size_t size); + +TAresult TA_API taMemcpyPeerAsync_v2(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + TAdevice srcDevice, + size_t size, + TAstream hStream); + +TAresult TA_API taMemcpyPeerAsync_v2_ptsz(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + TAdevice srcDevice, + size_t size, + TAstream hStream); + +TAresult TA_API taMemcpyFromPeerAsync(TAdeviceptr dst, + TAdeviceptr src, + TAdevice srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyToPeerAsync(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyFromPeerAsync_ptsz(TAdeviceptr dst, + TAdeviceptr src, + TAdevice srcDevice, + size_t size, + TAstream stream); + +TAresult TA_API taMemcpyToPeerAsync_ptsz(TAdeviceptr dst, + TAdevice dstDevice, + TAdeviceptr src, + size_t size, + TAstream stream); +//!< }}}@ + +TAresult TA_API taStreamWaitEvent(TAstream hStream, + TAevent hEvent, + unsigned int flags); +TAresult TA_API taStreamWaitEvent_ptsz(TAstream hStream, + TAevent hEvent, + unsigned int flags); +TAresult TA_API taStreamSynchronize(TAstream hStream); +TAresult TA_API taStreamSynchronize_ptsz(TAstream hStream); +TAresult TA_API taStreamQuery(TAstream hStream); +TAresult TA_API taStreamQuery_ptsz(TAstream hStream); + +TAresult TA_API taStreamBeginCapture(TAstream hStream, TAstreamCaptureMode mode); +TAresult TA_API taStreamBeginCapture_ptsz(TAstream hStream, + TAstreamCaptureMode mode); +TAresult TA_API taThreadExchangeStreamCaptureMode(TAstreamCaptureMode* mode); +TAresult TA_API taStreamEndCapture(TAstream hStream, TAgraph* phGraph); +TAresult TA_API taStreamEndCapture_ptsz(TAstream hStream, TAgraph* phGraph); +TAresult TA_API taStreamIsCapturing(TAstream hStream, + TAstreamCaptureStatus* captureStatus); +TAresult TA_API taStreamIsCapturing_ptsz(TAstream hStream, + TAstreamCaptureStatus* captureStatus); +TAresult TA_API taStreamGetCaptureInfo(TAstream hStream, + TAstreamCaptureStatus* pStatus, + unsigned long long* pId, + TAgraph* pGraph, + const TAgraphNode** deps, + size_t* numDeps); +TAresult TA_API taStreamGetCaptureInfo_ptsz(TAstream hStream, + TAstreamCaptureStatus* pStatus, + unsigned long long* pId, + TAgraph* pGraph, + const TAgraphNode** deps, + size_t* numDeps); + +TAresult TA_API taGraphInstantiateWithFlags(TAgraphExec* phGraphExec, + TAgraph hGraph, + unsigned long long flags); +TAresult TA_API taGraphLaunch(TAgraphExec hGraphExec, TAstream hStream); +TAresult TA_API taGraphLaunch_ptsz(TAgraphExec hGraphExec, TAstream hStream); +TAresult TA_API taGraphDestroy(TAgraph hGraph); +TAresult TA_API taGraphExecDestroy(TAgraphExec hGraphExec); +TAresult TA_API taGraphGetInfo(TAgraph hGraph, TAgraphInfo* pInfo); +TAresult TA_API taGraphCreate(TAgraph* phGraph, unsigned int flags); + +TAresult TA_API taGraphAddHostNode(TAgraphNode* phGraphNode, + TAgraph hGraph, + const TAgraphNode* dependencies, + size_t numDependencies, + const TANG_HOST_NODE_PARAMS* nodeParams); + +TAresult TA_API taGraphAddKernelNode(TAgraphNode* phGraphNode, + TAgraph hGraph, + const TAgraphNode* dependencies, + size_t numDependencies, + const TANG_KERNEL_NODE_PARAMS* nodeParams); + +TAresult TA_API taEventCreate(TAevent* phEvent, unsigned int flags); +TAresult TA_API taEventDestroy(TAevent hEvent); +TAresult TA_API taEventRecord(TAevent hEvent, TAstream hStream); +TAresult TA_API taEventRecord_ptsz(TAevent hEvent, TAstream hStream); +TAresult TA_API taEventRecordWithFlags(TAevent hEvent, + TAstream hStream, + unsigned int flags); +TAresult TA_API taEventRecordWithFlags_ptsz(TAevent hEvent, + TAstream hStream, + unsigned int flags); +TAresult TA_API taEventSynchronize(TAevent hEvent); +TAresult TA_API taEventElapsedTime(float* pMilliseconds, + TAevent hStart, + TAevent hEnd); +TAresult TA_API taEventQuery(TAevent hEvent); +TAresult TA_API taEventQueryTimestamp(TAevent hEvent, TAeventTimestamp* pTs); + +TAresult TA_API taEventSynchronizeWithFlags(TAevent hEvent, unsigned int flags); + +void TA_API taDsoWrapperInit(TAdsoWrapper_t* dso); +void TA_API taDsoWrapperDeinit(TAdsoWrapper_t* dso); + +TAresult TA_API taGetBuiltinModule(TAmodule* phModule, const char* name); + +TAresult TA_API taModuleLoad(TAmodule* phModule, const char* filename); +TAresult TA_API taModuleLoadData(TAmodule* phModule, const void* image, size_t size); +TAresult TA_API taModuleUnload(TAmodule hModule); + +TAresult TA_API taModuleLoadFatBinaryManaged(TAmodule* phModule, + const void* fatbin, + const char* fatbinInfo, + TAdsoWrapper_t* dso); + +TAresult TA_API taModuleUnloadManaged(TAmodule hModule); + +/** + * @brief Get the module symbol type name. + * + * @param name The pointer to receive the name of the type. + * @param type The symbol type. + * @return TAresult + * ::TANG_SUCCESS if the type is a valid value. + * ::TANG_ERROR_INVALID_VALUE if the type is not a valid type. + */ +TAresult TA_API taModuleSymbolTypeGetName(char const** name, TAmoduleSymbolType type); + +/** + * @brief Iterate symbols in \p hmod. + * + * @param hmod + * @param fn The call back function. Return true will cause the + * iteration to stop. + * @param userData + * @return TAresult + */ +// TAresult TA_API taModuleIterateSymbols(TAmodule hmod, +// TAmoduleSymbolIterateFn fn, +// void* userData); + +TAresult TA_API taModuleGetFunction(TAfunction* hfunc, + TAmodule hmod, + const char* name); + +TAresult TA_API taFuncGetAttribute(int* pi, + TAfunction_attribute attr, + TAfunction hfunc); + +TAresult TA_API taFuncGetModule(TAmodule* hmod, TAfunction func); + +TAresult TA_API taFunctionGetAddress(TAfunction func, TAdeviceptr* address); + +TAresult TA_API taFunctionGetNumArgs(TAfunction func, size_t* numArgs); + +TAresult TA_API taFunctionGetInfo(TAfunction func, + TAdeviceptr* address, + size_t* lenArgs, + size_t* thread_regfile_size, + size_t* shared_regfile_base, + size_t* shared_regfile_size, + size_t* warp_regfile_size, + size_t* local_memory_size, + size_t* static_shared_mem_size, + size_t* shared_memory_mirror, + size_t* max_threads_per_block, + size_t* max_dynamic_shared_mem_size_per_block, + size_t* max_block_count); + +TAresult TA_API taModuleGetVariable(TAvariable* hVar, + TAmodule hMod, + const char* varName); + +TAresult TA_API taVariableGetInfo(TAvariable hVar, + TAdeviceptr* address, + size_t* size); + +TAresult TA_API taVariableCopyFromDevice(TAvariable hVar, + TAdeviceptr src, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyFromDevice_ptds(TAvariable hVar, + TAdeviceptr src, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyFromDeviceAsync(TAvariable hVar, + TAdeviceptr src, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyFromDeviceAsync_ptsz(TAvariable hVar, + TAdeviceptr src, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyFromHost(TAvariable hVar, + const void* src, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyFromHost_ptds(TAvariable hVar, + const void* src, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyFromHostAsync(TAvariable hVar, + const void* src, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyFromHostAsync_ptsz(TAvariable hVar, + const void* src, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyToDevice(TAdeviceptr dst, + TAvariable hVar, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyToDevice_ptds(TAdeviceptr dst, + TAvariable hVar, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyToDeviceAsync(TAdeviceptr dst, + TAvariable hVar, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyToDeviceAsync_ptsz(TAdeviceptr dst, + TAvariable hVar, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyToHost(void* dst, + TAvariable hVar, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyToHost_ptds(void* dst, + TAvariable hVar, + size_t size, + size_t offset); + +TAresult TA_API taVariableCopyToHostAsync(void* dst, + TAvariable hVar, + size_t size, + size_t offset, + TAstream stream); + +TAresult TA_API taVariableCopyToHostAsync_ptsz(void* dst, + TAvariable hVar, + size_t size, + size_t offset, + TAstream stream); + +/** + * @brief Enqueue A raw SCP command packet onto stream. + * + * @param stream The stream. + * @param regs The SCP command packet. + * @param size The size of the command packet in byte. + * @return int + * @warning A raw SCP command packet needs four bytes aligned. + * The \p size must be integral multiple of 4. + ****************************************************************/ +TAresult TA_API taEnqueueCommand(TAstream stream, void* regs, size_t size); + +TAresult TA_API taEnqueueCommand_ptsz(TAstream stream, void* regs, size_t size); + +/** + * @brief Launch a kernel function. + * + * @param func + * @param gridX + * @param gridY + * @param gridZ + * @param blockX + * @param blockY + * @param blockZ + * @param sharedMemBytes + * @param stream + * @param funcParams + * @param extra + * @code {.cpp} + * + * @endcode + * @param extra + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_OUT_OF_MEMORY + * ::TANG_ERROR_NOT_IMPLEMENTED + * @sa taModuleGetFunction + * @sa TA_STREAM_LEGACY + ********************************************************/ +TAresult TA_API taLaunchFunction(TAfunction func, + unsigned int gridX, + unsigned int gridY, + unsigned int gridZ, + unsigned int blockX, + unsigned int blockY, + unsigned int blockZ, + unsigned int sharedMemBytes, + TAstream stream, + void** funcParams, + void** extra); + +TAresult TA_API taLaunchFunction_ptsz(TAfunction func, + unsigned int gridX, + unsigned int gridY, + unsigned int gridZ, + unsigned int blockX, + unsigned int blockY, + unsigned int blockZ, + unsigned int sharedMemBytes, + TAstream stream, + void** funcParams, + void** extra); + +TAresult TA_API taLaunchKernel(TAfunction func, + unsigned int gridX, + unsigned int gridY, + unsigned int gridZ, + unsigned int blockX, + unsigned int blockY, + unsigned int blockZ, + unsigned int sharedMemBytes, + TAstream stream, + void** funcParams, + void** extra); + +TAresult TA_API taLaunchKernel_ptsz(TAfunction func, + unsigned int gridX, + unsigned int gridY, + unsigned int gridZ, + unsigned int blockX, + unsigned int blockY, + unsigned int blockZ, + unsigned int sharedMemBytes, + TAstream stream, + void** funcParams, + void** extra); + +//!< fn(usrData) +TAresult TA_API taLaunchHostFunc(TAstream hStream, TAhostFn fn, void* userData); +TAresult TA_API taLaunchHostFunc_ptsz(TAstream hStream, TAhostFn fn, void* userData); + +typedef void (*TAstreamCallback)(TAstream hStream, + TAresult status, + void* userData); + +//!< callback(hStream, status, userData); +TAresult TA_API taStreamAddCallback(TAstream hStream, + TAstreamCallback callback, + void* userData, + unsigned int flags); + +TAresult TA_API taStreamAddCallback_ptsz(TAstream hStream, + TAstreamCallback callback, + void* userData, + unsigned int flags); + +//!< proxy(proxy_data, func, func_data, error) +TAresult TA_API taLaunchHostFuncProxy(TAstream stream, + void* proxy, + void* proxy_data, + void* func, + void* func_data); + +TAresult TA_API taLaunchHostFuncProxy_ptsz(TAstream stream, + void* proxy, + void* proxy_data, + void* func, + void* func_data); + +TAresult TA_API taOccupancyMaxActiveBlocksPerMultiprocessor(int* numBlocks, + TAfunction func, + int blockSize, + size_t dynamicSMemSize); +TAresult TA_API taProfilerStart(); +TAresult TA_API taProfilerStop(); + +/** + * @brief Gets an interprocess communication memory handle from device memory + * allocated by tangMalloc or cuMemAlloc + * + * @param pHandle + * @param dptr + * @return + * ::TANG_SUCCESS + * ::TANG_ERROR_INVALID + * ::TANG_ERROR_OUT_OF_MEMORY + * @sa + * ::taIpcOpenMemHandle + * ::taIpcCloseMemHandle + */ +TAresult TA_API taIpcGetMemHandle(TAipcMemHandle* pHandle, TAdeviceptr dptr); + +/** + * @brief Opens an interprocess communication memory handle exported from another + * process and map it into the current context and returns a device pointer. + * + * @param pdptr + * @param handle + * @param flags + * @return + * ::TANG_SUCCESS + * ::TANG_ERROR_IPC_MEM_DESTROIED + * ::TANG_ERROR_OUT_OF_MEMORY + * @sa + * ::taIpcGetMemHandle + * ::taIpcCloseMemHandle + */ +TAresult TA_API taIpcOpenMemHandle(TAdeviceptr *pdptr, + TAipcMemHandle handle, + unsigned int flags); + +/** + * @brief Unmap the memory got from taIpcOpenMemHandle. + * + * @param dptr + * @return int + */ +TAresult TA_API taIpcCloseMemHandle(TAdeviceptr dptr); + +/** + * @brief Gets an interprocess event handle. The event must be created with + * ::TA_EVENT_INTERPROCESS flag set. + * + * @param pHandle + * @param event + * @return int + */ +TAresult TA_API taIpcGetEventHandle(TAipcEventHandle* pHandle, TAevent event); + +/** + * @brief Opens an interprocess event handle for the calling process. + * + * Opens an interprocess event handle exported from another process with + * ::taIpcGetEventHandle. + * + * Use ::taEventDestroy to free the event. + * @param phEvent + * @param handle + * @return int + * ::TANG_SUCCESS + * ::TANG_ERROR_OUT_MEMORY + */ +TAresult TA_API taIpcOpenEventHandle(TAevent* phEvent, TAipcEventHandle handle); + +/** + * @brief launch engine collectives witch stream. + * + * @param devId which ptpu + * @param collType engine collectives type + * @param devAddr HBM address + * @param size params length + * @param stream for sync + * @return int + */ +TAresult TA_API taStreamEngineCollAssign(int devId, int collType, uint64_t devAddr, int size, TAstream stream); + +// enum TAqueryInfoType_enum { +// TA_QUERY_INFO_MEMORY_USAGE = 0x01, +// }; +// typedef enum TAqueryInfoType_enum TAqueryInfoType; +// +// TAresult TA_API taCtxQueryMemoryUsage(int64_t* pTotal, TAcontext context); + +TAresult TA_API taGetExportTable(void** pTable, void* args); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // _TANG_H_ + diff --git a/third_party/sunrise/backend/include/tang_compiler_api.h b/third_party/sunrise/backend/include/tang_compiler_api.h new file mode 100755 index 000000000..2c2f4ecd4 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_compiler_api.h @@ -0,0 +1,223 @@ +/* +Copyright declaration. +*/ + +#ifndef _TANG_RT_TANG_COMPILER_API_H +#define _TANG_RT_TANG_COMPILER_API_H + +#include "tang_rt/driver_types.h" +#include "tang_rt/host_defines.h" +#include "tang_rt/vector_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief C compliant register fatbinary API + * + * @param [in] fatTangbin - pointer of tang fatbin data + * + * @returns tang fatbin handle registered in runtime + * + */ +TANGRT_API_PUBLIC void** __tangRegisterFatBinary(void* fatTangbin); + +/** + * @brief + * + * @param fatbinHandle + */ +TANGRT_API_PUBLIC void __tangRegisterFatBinaryEnd(void** fatbinHandle); + +/** + * @brief C compliant unregister fatbinary API + * + * @param [in] fatTangbinHandle - tang fatbin handle registered in runtime + * + * @returns void + * + */ +TANGRT_API_PUBLIC void __tangUnregisterFatBinary(void** fatTangbinHandle); + +/** + * @brief Register fatbinary + * + * @param fatbinWrapper + * @return void** + * @example + * @code{.cpp} + * //!< Generated by compiler + * const unsigned long fatbinText[] = { 0x12344567, ..., 0xabcd1235 }; + * static __tangFatbinaryWrapper fatbinWrapper = { + * .version = 0, + * .fatbin = &fatbinText[0], + * .size = sizeof(fatbinText), + * }; + * static void** fatbinHandle; + * static void __tang_RegisterAll(void) __attribute__((constructor)) + * static void __tang_UnregisterAll(void); + * static void __tang_RegisterAll(void) { + * fatbinHandle = __tangRegisterFatBinary_v2(&fatbinWrapper); + * __tangRegisterFunction((const void*)vecadd, + * (const char*)"_Z6vecaddPiS_S_", + * fatbinHandle); + * __tangRegisterFatBinaryEnd(fatbinHandle); + * atexit(__tang_UnregisterAll); + * } + * static void __tang_UnregisterAll(void) { + * __tangUnregisterFatBinary_v2(fatbinHandle); + * } + * @endcode + */ +TANGRT_API_PUBLIC void** __tangRegisterFatBinary_v2(void* fatbinWrapper); + +/** + * @brief Unregister fatbinary + * + * @param fatTangbinHandle + * @return TANGRT_API_PUBLIC + */ +TANGRT_API_PUBLIC void __tangUnregisterFatBinary_v2(void** fatTangbinHandle); + +TANGRT_API_PUBLIC int __tangInitModule(void** fatTangbinHandle); + +/** + * @brief C compliant set fatbinary info API + * + * @param [in] fatbinHandle - tang fatbin handle registered in runtime + * @param [in] info - pointer of tang fatbin info + * + * @returns void + * + */ +TANGRT_API_PUBLIC void __tangSetFatBinaryInfo(void** fatbinHandle, + const char* info); + +/** + * @brief C compliant register function API + * + * @param [in] hostFunc - the pointer of the host stub function + * @param [in] deviceFuncName - the name string of device function. + * @param [in] fatbinHandle - tang fatbin handle registered in runtime + * + * @returns void + * + */ +TANGRT_API_PUBLIC void __tangRegisterFunction(const void* hostFunc, + const char* deviceFuncName, + void** fatbinHandle); + +/** + * @brief Register device variable. + * + * @param fatbinHandle The fatbin handle returned by \c __tangRegisterFatBinary + * @param hostVar The corresponding host variable address used as the key. + * @param deviceVarAddress The device variable address. + * @param deviceVarName The symbol name of the devcie variable. + * @param ext + * @param size The size of variable in bytes. + * @param constant If the variable is in const memory ? + * @param global From the host's point of view, device variable is always local. + * Thus, param "global" is always 0. + * @remark + * For code `__device__ int xyz;` the compiler should generate the following + * code. + * @code{.cpp} + * static int xyz; + * __tangRegisterVariable(fatbinHandle, + * &xyz, (char*)"xyz", "xyz", + * 0, sizeof(int), 0, 0); + * @endcode + */ +TANGRT_API_PUBLIC void __tangRegisterVariable(void** fatbinHandle, + const void* hostVar, + char* deviceVarAddress, + const char* deviceVarName, + int ext, + size_t size, + int constant, + int global); + +/** + * @brief C compliant push call configuration API + * + * @param [in] gridDim - number of blocks in a grid + * @param [in] blockDim - number of threads in a block + * @param [in] sharedMemBytes - Amount of dynamic shared memory to allocate for + * this kernel. The Kernel can access this with TANG_DYNAMIC_SHARED. + * @param [in] stream - Stream where the kernel should be dispatched. May be 0, + * in which case the default stream is used with associated synchronization + * rules. + * + * @returns #tangSuccess + * + */ +TANGRT_API_PUBLIC tangError_t +__tangPushCallConfiguration(dim3 gridDim, + dim3 blockDim, + size_t sharedMemBytes __dparm(0), + tangStream_t stream __dparm(0)); + +/** + * @brief C compliant pop call configuration API + * + * @param [out] gridDim - number of blocks in a grid + * @param [out] blockDim - number of threads in a block + * @param [out] sharedMemBytes - Amount of dynamic shared memory to allocate for + * this kernel. The Kernel can access this with TANG_DYNAMIC_SHARED. + * @param [out] stream - Stream where the kernel should be dispatched. May be + * 0, in which case the default stream is used with associated synchronization + * rules. + * + * @returns #tangSuccess + * + */ +TANGRT_API_PUBLIC tangError_t __tangPopCallConfiguration(dim3* gridDim, + dim3* blockDim, + size_t* sharedMemBytes, + tangStream_t* stream); + +/** + * @brief C compliant kernel launch API + * + * @param [in] hostFunc - the pointer of the host stub function + * @param [in] gridDim - number of blocks in a grid + * @param [in] blockDim - number of threads in a block + * @param [in] args - kernel arguments + * @param [in] numArgs - number of kernel arguments + * @param [in] sharedMemBytes - Amount of dynamic shared memory to allocate for + * this kernel. The Kernel can access this with TANG_DYNAMIC_SHARED. + * @param [in] stream - Stream where the kernel should be dispatched. May be 0, + * in which case th default stream is used with associated synchronization + * rules. + * + * @returns #tangSuccess, #tangErrorInvalidValue, tangInvalidDevice + * + */ +TANGRT_API_PUBLIC tangError_t tangLaunchKernel(const void* hostFunc, + dim3 gridDim, + dim3 blockDim, + void** args, + size_t numArgs, + size_t sharedMemBytes __dparm(0), + tangStream_t stream __dparm(0)); + +TANGRT_API_PUBLIC tangError_t +tangLaunchKernel_ptsz(const void* hostFunc, + dim3 gridDim, + dim3 blockDim, + void** args, + size_t numArgs, + size_t sharedMemBytes __dparm(0), + tangStream_t stream __dparm(0)); + +#ifdef __cplusplus +} +#endif + +#if defined(__TANGRT_API_PER_THREAD_DEFAULT_STREAM) +#define tangLaunchKernel __TANGRT_API_PTSZ(tangLaunchKernel) +#endif //! __TANGRT_API_PER_THREAD_DEFAULT_STREAM + +#endif //! _TANG_RT_TANG_COMPILER_API_H diff --git a/third_party/sunrise/backend/include/tang_rt/device_assert.h b/third_party/sunrise/backend/include/tang_rt/device_assert.h new file mode 100755 index 000000000..691fc8f5b --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/device_assert.h @@ -0,0 +1,43 @@ +#ifndef _TANGRT_DEVICE_ASSERT_H_ +#define _TANGRT_DEVICE_ASSERT_H_ + +#include + +#include + +#include "tang_rt/device_functions.h" + +extern "C" { +// #pragma push_macro("size_t") +// #define size_t unsigned +__device__ void __assertfail(const char *__message, + const char *__file, + unsigned __line, + const char *__function, + unsigned __charSize) +//__attribute__((noreturn)) +{ + __pt_printf("%d: block: [%d,%d,%d], thread: [%d,%d,%d] Assertion failed.\n", + __line, + blockIdx.x, + blockIdx.y, + blockIdx.z, + threadIdx.x, + threadIdx.y, + threadIdx.z); + asm volatile("exit\n\t" ::: "memory"); +} +// #undef size_t +// #pragma pop_macro("size_t") + +// In order for standard assert() macro on linux to work we need to +// provide device-side __assert_fail() +__device__ static inline void __assert_fail(const char *__message, + const char *__file, + unsigned __line, + const char *__function) { + __assertfail(__message, __file, __line, __function, sizeof(char)); +} +} // end extern "C" + +#endif diff --git a/third_party/sunrise/backend/include/tang_rt/device_functions.h b/third_party/sunrise/backend/include/tang_rt/device_functions.h new file mode 100755 index 000000000..f95c0112b --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/device_functions.h @@ -0,0 +1,8 @@ +#ifndef _TANGRT_DEVICE_FUNCTIONS_H_ +#define _TANGRT_DEVICE_FUNCTIONS_H_ +#include + +#include "tang_rt/fmt.hpp" + +#endif //!< _TANGRT_DEVICE_FUNCTIONS_H_ + diff --git a/third_party/sunrise/backend/include/tang_rt/driver_types.h b/third_party/sunrise/backend/include/tang_rt/driver_types.h new file mode 100755 index 000000000..0b76cc20c --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/driver_types.h @@ -0,0 +1,1281 @@ +/* +Copyright declaration. +*/ +#ifndef _TANG_RT_DRIVER_TYPES_H_ +#define _TANG_RT_DRIVER_TYPES_H_ +#include "tang_rt/vector_types.h" + +#define tangHostAllocDefault 0x00 /**< Default page-locked allocation flag */ +#define tangHostAllocPortable \ + 0x01 /**< Pinned memory accessible by all TANG contexts */ +#define tangHostAllocMapped 0x02 /**< Map allocation into device space */ +#define tangHostAllocWriteCombined 0x04 /**< Write-combined memory */ +#define tangHostAllocMapDeviceMemory \ + 0x100 /**< Allocate device memory and map it to userspace */ + +#define tangHostRegisterDefault \ + 0x00 /**< Default host memory registration flag */ +#define tangHostRegisterPortable \ + 0x01 /**< Pinned memory accessible by all TANG contexts */ +#define tangHostRegisterMapped \ + 0x02 /**< Map registered memory into device space */ +#define tangHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */ +#define tangHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */ + +/**< Default behavior */ +#define tangOccupancyDefault 0x00 +/**< Assume global caching is enabled and cannot be automatically turned off */ +#define tangOccupancyDisableCachingOverride 0x01 +/* + * TANG error types + */ +// Developer note - when updating these, update the tangGetErrorString +// functions. + +/** + * Not consider Cooperative Launch, Peer Access, Pitch, Texture, Graph, JIT, + * Managed Memory, Multi-thread(compute mode) right now + */ + +enum tangError { + /** + * Some of the tangError are not supported, + * such as tangErrorInvalidPitchValue, tangErrorInvalidHostPointer, etc + * If we develop tools to convert tang program to tang, + * tangTangErrorToTangError should redirect them to tangErrorUnknown. + * Need check. + */ + + /** + * The API call returned with no errors. In the case of query calls, this + * also means that the operation being queried is complete (see + * ::tangEventQuery() and ::tangStreamQuery()). + * Need check. + */ + tangSuccess = 0, + + /** + * This indicates that one or more of the parameters passed to the API call + * is not within an acceptable range of values. + */ + tangErrorInvalidValue = 1, + + /** + * The API call failed because it was unable to allocate enough memory to + * perform the requested operation. + * In cuda: cudaErrorMemoryAllocation + * In hip: hipErrorOutOfMemory + */ + tangErrorMemoryAllocation = 2, + + /** + * The API call failed because the TANG driver and runtime could not be + * initialized. + * In cuda: cudaErrorInitializationError + * In hip: hipErrorNotInitialized + */ + tangErrorInitializationError = 3, + + /** + * This indicates that a TANG Runtime API call cannot be executed because + * it is being called during process shut down, at a point in time after + * Tang runtime has been deinitialized. + * In cuda: cudaErrorCudartUnloading + * In hip: hipErrorDeinitialized + * hip declares it but not use. + * Maybe hip will not meet such a point. + * Need check. + */ + tangErrorDeinitialized = 4, + + /** + * This indicates that the device is removed. + */ + tangErrorDeviceRemoved = 5, + + /** + * This indicated that the device is reset. + */ + tangErrorDeviceReset = 6, + + /** + * This operation is not allowed + */ + tangErrorNotPermitted = 7, + + /** + * The specified file or directory is not found + */ + tangErrorNoSuchFile = 8, + + /** + * This indicates that a kernel launch is requesting resources that can + * never be satisfied by the current device. Requesting more shared memory + * per block than the device supports will trigger this error, as will + * requesting too many threads or blocks. See ::tangDeviceProp for more + * device limitations. + */ + tangErrorInvalidConfiguration = 9, + + /** + * @brief Null pointer is passed as argument but it is disallowed. + */ + tangErrorNullPointer = 10, + + tangErrorOutOfResources = 11, + + /** + * This indicates that the symbol name/identifier passed to the API call + * is not a valid name or identifier. + */ + tangErrorInvalidSymbol = 13, + + /** + * This indicates that at least one device pointer passed to the API call is + * not a valid device pointer. + * Note: This error is deprecated from CUDA 10.1, + * but hip still use it + */ + tangErrorInvalidDevicePointer = 17, + + /** + * This indicates that the direction of the memcpy passed to the API call is + * not one of the types specified by ::tangMemcpyKind. + * hip declares it but not use. + * But it seems useful. + * Need check. + */ + tangErrorInvalidMemcpyDirection = 21, + + /** + * This indicates that the installed TANG driver version is mismatch with the + * runtime version. + * hip declares it but not use. + * But it seems useful. + * Need check. + */ + tangErrorInsufficientDriver = 35, + + /** + * The device function being invoked (usually via ::tangLaunchKernel()) was + * not previously configured via the ::tangConfigureCall() function + * nor provided with ".kernel_info" in ELF. + */ + tangErrorMissingConfiguration = 52, + + /** + * The requested device function does not exist or is not compiled for the + * proper device architecture. + */ + tangErrorInvalidDeviceFunction = 98, + + /** + * This error indicates the attempted operation is not implemented. + */ + tangErrorNotImplemented = 99, + + /** + * This indicates that no TANG-capable devices were detected by the installed + * TANG driver. Call to tangGetDeviceCount returned 0 devices. + */ + tangErrorNoDevice = 100, + + /** + * This indicates that the device ordinal supplied by the user does not + * correspond to a valid TANG device. + * DeviceID must be in range 0...#compute-devices. + */ + tangErrorInvalidDevice = 101, + + /** + * This indicates that the device kernel image is invalid. + * In cuda: cudaErrorInvalidKernelImage + * In hip: hipErrorInvalidImage + */ + tangErrorInvalidKernelImage = 200, + + /** + * This most frequently indicates that there is no context bound to the + * current thread. This can also be returned if the context passed to an + * API call is not a valid handle (such as a context that has had + * ::taCtxDestroy() invoked on it). + * In cuda: cudaErrorDeviceUninitialized + * In hip:hipErrorInvalidContext + */ + tangErrorInvalidContext = 201, + + /** + * This indicates that there is no kernel image available that is suitable + * for the device. This can occur when a user specifies code generation + * options for a particular TANG source file that do not include the + * corresponding device configuration. + * In cuda: cudaErrorNoKernelImageForDevice + * In hip: hipErrorNoBinaryForGpu + * Need check. + */ + tangErrorNoKernelImageForDevice = 209, + + /** + * This indicates that the ::tangLimit passed to the API call is not + * supported by the active device. + */ + tangErrorUnsupportedLimit = 215, + + /** + * This indicates that a call tried to access an exclusive-thread device that + * is already in use by a different thread. + * In cuda: cudaErrorDeviceAlreadyInUse + * In hip: hipErrorContextAlreadyInUse + */ + tangErrorContextAlreadyInUse = 216, + + /** + * A PTX compilation failed. The runtime may fall back to compiling PTX if + * an application does not contain a suitable binary for the current device. + * In cuda: cudaErrorInvalidPtx + * In hip: hipErrorInvalidKernelFile + */ + tangErrorInvalidKernelFile = 218, + + /** + * When launch kernel, unable to find the corresponding fatbinary. + */ + tangErrorFatBinaryNotFound = 300, + + /** + * This indicates that the file specified was not found. + */ + tangErrorFileNotFound = 301, + + /** + * This indicates that a link to a shared object failed to resolve. + */ + tangErrorSharedObjectSymbolNotFound = 302, + + /** + * This indicates that initialization of a shared object failed. + */ + tangErrorSharedObjectInitFailed = 303, + + /** + * This error indicates that an OS call failed. + * hip declares it but not use. + * But it seems useful. + * Need check. + */ + tangErrorOperatingSystem = 304, + + tangErrorIllegalState = 305, + + tangErrorStreamCaptureUnsupported = 306, + + tangErrorStreamCaptureInvalidated = 307, + + tangErrorStreamCaptureMerge = 308, + + tangErrorStreamCaptureUnmatched = 309, + + tangErrorStreamCaptureUnjoined = 310, + + tangErrorStreamCaptureIsolation = 311, + + tangErrorStreamCaptureImplicit = 312, + + tangErrorStreamCaptureWrongThread = 313, + + tangErrorCapturedEvent = 314, + + /** + * This indicates that a resource handle passed to the API call was not + * valid. Resource handles are opaque types like ::tangStream_t and + * ::tangEvent_t. + * In cuda: cudaErrorInvalidResourceHandle + * In hip: hipErrorInvalidHandle + */ + tangErrorInvalidResourceHandle = 400, + + /** + * This indicates that a named symbol was not found. Examples of symbols + * are global/constant variable names, texture names, and surface names. + * In cuda: cudaErrorSymbolNotFound + * In hip: tangErrorNotFound + */ + tangErrorSymbolNotFound = 500, + + /** + * This indicates that asynchronous operations issued previously have not + * completed yet. This result is not actually an error, but must be indicated + * differently than ::tangSuccess (which indicates completion). Calls that + * may return this value include ::tangEventQuery() and ::tangStreamQuery(). + */ + tangErrorNotReady = 600, + + /** + * The device encountered a load or store instruction on an invalid memory + * address. This leaves the process in an inconsistent state and any further + * TANG work will return the same error. To continue using TANG, the process + * must be terminated and relaunched. hip declares it but not use. But it + * seems useful. Need check. + */ + tangErrorIllegalAddress = 700, + + /** + * This indicates that a launch did not occur because it did not have + * appropriate resources. Although this error is similar to + * ::tangErrorInvalidConfiguration, this error usually indicates that the + * user has attempted to pass too many arguments to the device kernel, or the + * kernel launch specifies too many threads for the kernel's register count. + */ + tangErrorLaunchOutOfResources = 701, + + /** + * This indicates that the device kernel took too long to execute. This can + * only occur if timeouts are enabled - see the device property + * \ref ::tangDeviceProp::kernelExecTimeoutEnabled "kernelExecTimeoutEnabled" + * for more information. + * This leaves the process in an inconsistent state and any further TANG work + * will return the same error. To continue using TANG, the process must be + * terminated and relaunched. hip declares it but not use. But it seems + * useful. Need check. + */ + tangErrorLaunchTimeOut = 702, + + tangErrorPeerAccessAlreadyEnabled = 704, + tangErrorPeerAccessNotEnabled = 705, + + /** + * This error indicates that the memory range passed to ::tangHostRegister() + * has already been registered. + * hip declares it but not use. + * But it seems useful. + * Need check. + */ + tangErrorHostMemoryAlreadyRegistered = 712, + + /** + * This error indicates that the pointer passed to ::tangHostUnregister() + * does not correspond to any currently registered memory region. + * hip declares it but not use. + * But it seems useful. + * Need check. + */ + tangErrorHostMemoryNotRegistered = 713, + + /** + * An exception occurred on the device while executing a kernel. Common + * causes include dereferencing an invalid device pointer and accessing + * out of bounds shared memory. Less common cases can be system specific - + * more information about these cases can be found in the system specific user + * guide. This leaves the process in an inconsistent state and any further + * TANG work will return the same error. To continue using TANG, the process + * must be terminated and relaunched. + */ + tangErrorLaunchFailure = 719, + + /** + * This error indicates the attempted operation is not supported + * on the current system or device. + */ + tangErrorNotSupported = 801, + + /** + * This indicates that an unknown internal error has occurred. + */ + tangErrorUnknown = 999, + + /** + * @brief context is destroyed or in destroying in kernel + * + */ + tangErrorContextIsDestroyed = 3000, + + /** + * @brief context is not valid in kernel + * + */ + tangErrorContextInvalid = 3001, + + /** + * @brief stream is destroyed or in destroying in kernel + * + */ + tangErrorStreamIsDestroyed = 3002, + + /** + * @brief stream is not valid in kernel + * + */ + tangErrorStreamInvalid = 3003, + + /** + * @brief event is destroyed or in destroying in kernel + * + */ + tangErrorEventIsDestroyed = 3004, + + /** + * @brief event is not valid in kernel + * + */ + tangErrorEventInvalid = 3005, + + /** + * @brief device memory is not enough for current operation + * + */ + tangErrorDeviceOutOfMemory = 3006, + + /** + * @brief device memory is not found + * + */ + tangErrorDeviceMemoryNotFound = 3007, + + /** + * @brief pcie fatal error occured + * + */ + tangErrorPcieFatal = 3012, + + /** + * @brief pcie non-fatal unrecovered error occured + * + */ + tangErrorPcieNonFatalUnrecovered = 3013, + + /** + * @brief no more event exist + * + */ + tangErrorScpEventNotExist = 3014, + + /** + * @brief record event failed + * + */ + tangErrorSCPEventRecordFailed = 3015, + + /** + * @brief scp packet crc check failed + * + */ + tangErrorSCPCrcPacketFailed = 3016, + + /** + * @brief scp dispatch send failed + * + */ + tangErrorSCPDispSendFailed = 3017, + + /** + * @brief sq write sequence error + * + */ + tangErrorSCPSqWriteFailed = 3018, + + /** + * @brief udrc pcie xdma packet invalid + * + */ + tangErrorUdrcPcieDmaPacketInvalid = 3019, + + /** + * @brief udrc mp dma packet invalid + * + */ + tangErrorUdrcMpDmaPacketInvalid = 3020, + + /** + * @brief udrc reg packet invalid + * + */ + tangErrorUdrcRegPacketInvalid = 3021, + + /** + * @brief udrc reg access invalid + * + */ + tangErrorUdrcRegAcessInvalid = 3022, + + /** + * @brief aiss cluster is not configured + * + */ + tangErrorAissClusterUsrNotAllocated = 3023, + + /** + * @brief barrier is destroyed or in destroying in kernel + * + */ + tangErrorBarrierDestroyed = 3024, + + /** + * @brief barrier is not valid in kernel + * + */ + tangErrorBarrierInvalid = 3025, + + /** + * @brief one obj is destroyed or in destroying in kernel + * + */ + tangErrorDestroyed = 3026, + + /** + * @brief xdma C2H align mismath + * + */ + tangErrorXdmaC2HAlignMismatch = 3300, + + /** + * @brief xdma C2H invalid magic stopped + * + */ + tangErrorXdmaC2HInvalidMagicStopped = 3301, + + /** + * @brief xdma C2H invalid Len + * + */ + tangErrorXdmaC2HInvalidLen = 3302, + + /** + * @brief xdma C2H decode error + * + */ + tangErrorXdmaC2HDecode = 3303, + + /** + * @brief xdma C2H slave + * + */ + tangErrorXdmaC2HSlave = 3304, + + /** + * @brief xdma C2H descriptor unsupport request + * + */ + tangErrorXdmaC2HDescUnsupportRequest = 3305, + + /** + * @brief xdma C2H descriptor completer abort + * + */ + tangErrorXdmaC2HDescCompleterAbort = 3306, + + /** + * @brief xdma C2H descriptor parity + * + */ + tangErrorXdmaC2HDescParity = 3307, + + /** + * @brief xdma C2H descriptor header ep + * + */ + tangErrorXdmaC2HDescHeaderEp = 3308, + + /** + * @brief xdma C2H descriptor unexpected comp + * + */ + tangErrorXdmaC2HDescUnexpectedComp = 3309, + + /** + * @brief xdma C2H timeout + * + */ + tangErrorXdmaC2HTimeout = 3310, + + /** + * @brief xdma C2H unknown + * + */ + tangErrorXdmaC2HUnknown = 3311, + + /** + * @brief xdma H2C align mismatch + * + */ + tangErrorXdmaH2CAlignMismatch = 3350, + + /** + * @brief xdma H2C invalid magic stopped + * + */ + tangErrorXdmaH2CInvaildMagicStopped = 3351, + + /** + * @brief xdma H2C invalid len + * + */ + tangErrorXdmaH2CInvalidLen = 3352, + + /** + * @brief xdma H2C read unsupport request + * + */ + tangErrorXdmaH2CReadUnSupportRequest = 3353, + + /** + * @brief xdma H2C read completer abort + * + */ + tangErrorXdmaH2CReadCompleterAbort = 3354, + + /** + * @brief xdma H2C read parity + * + */ + tangErrorXdmaH2CReadParity = 3355, + + /** + * @brief xdma H2C read header ep + * + */ + tangErrorXdmaH2CReadHeaderEp = 3356, + + /** + * @brief xdma H2C read unexpected comp + * + */ + tangErrorXdmaH2CReadUnExpectedComp = 3357, + + /** + * @brief xdma H2C decode error + * + */ + tangErrorXdmaH2CDecode = 3358, + + /** + * @brief xdma H2C slave + * + */ + tangErrorXdmaH2CSlave = 3359, + + /** + * @brief xdma H2C descriptor unsupport request + * + */ + tangErrorXdmaH2CDescUnSupportRequest = 3360, + + /** + * @brief xdma H2C descriptor completer abort + * + */ + tangErrorXdmaH2CDescCompleterAbort = 3361, + + /** + * @brief xdma H2C descriptor parity + * + */ + tangErrorXdmaH2CDescParity = 3362, + + /** + * @brief xdma H2C descriptor header ep + * + */ + tangErrorXdmaH2CDescHeaderEp = 3363, + + /** + * @brief xdma H2C descriptor unexpected com + * + */ + tangErrorXdmaH2CDescUnExpectedComp = 3364, + + /** + * @brief xdma H2C timeout + * + */ + tangErrorXdmaH2CTimeout = 3365, + + /** + * @brief xdma H2C unknown + * + */ + tangErrorXdmaH2CUnknown = 3366, + + /** + * This indicates that the IOCTL operation of TANG driver is failed. + * Added by TANG. hip and tang do not use. + * Need check, avoid to use the save number with cuda and hip. + */ + tangErrorDriverIoctlFailed = 10000, +}; + +/** + * TANG memory copy types + */ +typedef enum tangMemcpyKind { + tangMemcpyHostToHost = 0, /**< Host -> Host */ + tangMemcpyHostToDevice = 1, /**< Host -> Device */ + tangMemcpyDeviceToHost = 2, /**< Device -> Host */ + tangMemcpyDeviceToDevice = 3, /**< Device -> Device */ + /** + * Direction of the transfer is inferred from the pointer values. + * Requires unified virtual addressing, thus tang doesn't support. + */ + // tangMemcpyDefault = 4 +} tangMemcpyKind; + +enum tangMemoryType { + tangMemoryTypeUnregistered = 0, /**< Unregistered memory */ + tangMemoryTypeHost = 1, /**< Host memory */ + tangMemoryTypeDevice = 2, /**< Device memory */ + tangMemoryTypeManaged = 3, /**< Managed memory */ +}; + +struct tangPointerAttributes { + enum tangMemoryType type; + + int device; + + void* devicePointer; + + void* hostPointer; +}; + +/** + * TANG function attributes + */ +typedef struct tangFuncAttributes { + /** + * The size in bytes of statically-allocated shared memory per block + * required by this function. This does not include dynamically-allocated + * shared memory requested by the user at runtime. + */ + size_t sharedSizeBytes; + + /** + * The size in bytes of user-allocated constant memory required by this + * function. + */ + // PT devices use global memory to perform as constant memory, + // and constant memory belongs to the module, rather than the function + size_t constSizeBytes; + + /** + * The size in bytes of local memory used by each thread of this function. + */ + size_t localSizeBytes; + + /** + * The maximum number of threads per block, beyond which a launch of the + * function would fail. This number depends on both the function and the + * device on which the function is currently loaded. + */ + int maxThreadsPerBlock; + + /** + * The number of registers used by each thread of this function. + */ + int numRegs; + + /** + * The PTX virtual architecture version for which the function was + * compiled. This value is the major PTX version * 10 + the minor PTX + * version, so a PTX version 1.3 function would return the value 13. + */ + int ptxVersion; + + /** + * The binary architecture version for which the function was compiled. + * This value is the major binary version * 10 + the minor binary version, + * so a binary version 1.3 function would return the value 13. + */ + int binaryVersion; + + /** + * The attribute to indicate whether the function has been compiled with + * user specified option "-Xptxas --dlcm=ca" set. + */ + int cacheModeCA; + + /** + * The maximum size in bytes of dynamic shared memory per block for + * this function. Any launch must have a dynamic shared memory size + * smaller than this value. + */ + int maxDynamicSharedSizeBytes; + + /** + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets the shared memory carveout preference, in percent of + * the maximum shared memory. Refer to + * ::tangDevAttrSHARED_MEM_PER_MULTIPROCESSOR. This is only a hint, and the + * driver can choose a different ratio if required to execute the function. + * See ::tangFuncSetAttribute + * + * PT devices do not suppport to config L1 cache/shared memory. + */ + int preferredShmemCarveout; +} tangFuncAttributes; + +/** + * TANG function attributes that can be set using ::tangFuncSetAttribute + */ +typedef enum tangFuncAttribute { + tangFuncAttributeMaxDynamicSharedMemorySize = + 8, /**< Maximum dynamic shared memory size */ + tangFuncAttributePreferredSharedMemoryCarveout = + 9, /**< Preferred shared memory-L1 cache split */ + tangFuncAttributeMax +} tangFuncAttribute; + +/** + * TANG function cache configurations + * @warning On PT2 devices, L1 cache and shared memory are separated, + * thus these hints and controls are ignored. + */ +typedef enum tangFuncCache { + tangFuncCachePreferNone, ///< no preference for shared memory or L1 (default) + tangFuncCachePreferShared, ///< prefer larger shared memory and smaller L1 + ///< cache + tangFuncCachePreferL1, ///< prefer larger L1 cache and smaller shared memory + tangFuncCachePreferEqual, ///< prefer equal size L1 cache and shared memory +} tangFuncCache; + +/** + * TANG shared memory configuration + * @warning On PT2 devices, shard memory bank size is fix to 4-bytes, + * thus these hints and controls are ignored. + */ +typedef enum tangSharedMemConfig { + tangSharedMemBankSizeDefault, ///< The compiler selects a device-specific + ///< value for the banking. + tangSharedMemBankSizeFourByte, ///< Shared mem is banked at 4-bytes intervals + ///< and performs best when adjacent threads + ///< access data 4 bytes apart. + tangSharedMemBankSizeEightByte ///< Shared mem is banked at 8-byte intervals + ///< and performs best when adjacent threads + ///< access data 4 bytes apart. +} tangSharedMemConfig; + +/** + * TANG Limits + */ +enum tangLimit { + tangLimitStackSize = 0x00, /**< GPU thread stack size */ + tangLimitPrintfFifoSize = 0x01, /**< GPU printf FIFO size */ + tangLimitMallocHeapSize = 0x02, /**< GPU malloc heap size */ + tangLimitDevRuntimeSyncDepth = + 0x03, /**< GPU device runtime synchronize depth */ + tangLimitDevRuntimePendingLaunchCount = + 0x04, /**< GPU device runtime pending launch count */ + tangLimitMaxL2FetchGranularity = + 0x05 /**< A value between 0 and 128 that indicates the +// maximum fetch granularity of L2 (in Bytes). This is a hint */ +}; + +enum tangEventFlags_e { + tangEventDefault = 0x00, + tangEventDisableTiming = 0x02, + tangEventInterprocess = 0x04, +}; + +enum tangEventSyncFlags_e { + tangEventSyncDefault = 0x00, + + //!< Block until the event is recorded and done. + tangEventSyncRecordedAndDone = 0x01, +}; + +enum tangEventRecordFlags_e { + //!< The default recording mode. + tangEventRecordDefault = 0x00, + + //!< Always use the hardware event. + tangEventRecordHW = 0x0100, + + //!< Always use the software event. + tangEventRecordSW = 0x0200, + + //!< Allow to blockinig the calling thread is resource is not available. + tangEventRecordBlockingAllowed = 0x0400, +}; + +/** + * TANG Memory Advise values + */ +// enum tangMemoryAdvise {}; + +/** + * TANG range attributes + */ +// enum tangMemRangeAttribute {}; + +typedef enum tangStreamCaptureMode_e { + tangStreamCaptureModeGlobal = 0, + tangStreamCaptureModeThreadLocal = 1, + tangStreamCaptureModeRelaxed = 2, +} tangStreamCaptureMode; + +typedef enum tangStreamCaptureStatus_e { + tangStreamCaptureStatusNone = 0, + tangStreamCaptureStatusActive = 1, + tangStreamCaptureStatusInvalidated = 2, +} tangStreamCaptureStatus; + +/** + * TANG device attribute enum + */ +typedef enum tangDeviceAttr { + tangDevAttrMaxSharedMemPerBlock = 0, //!< sharedMemPerBlock + tangDevAttrMaxRegsPerBlock, //!< regsPerBlock + tangDevAttrWarpSize, //!< warpSize + tangDevAttrMemPitch, //!< memPitch + tangDevAttrMaxThreadsPerBlock, //!< maxThreadsPerBlock + tangDevAttrMaxBlockDimX, //!< maxThreadsDimX + tangDevAttrMaxBlockDimY, //!< maxThreadsDimY + tangDevAttrMaxBlockDimZ, //!< maxThreadsDimZ + tangDevAttrMaxGridDimX, //!< maxGridSizeX + tangDevAttrMaxGridDimY, //!< maxGridSizeY + tangDevAttrMaxGridDimZ, //!< maxGridSizeZ + tangDevAttrClockRate, //!< clockRate + tangDevAttrTotalConstantMemory, //!< totalConstMem + tangDevAttrMultiProcessorCount, //!< multiProcessorCount + tangDevAttrMaxBlocksPerMultiProcessor, //!< maxBlocksPerMultiProcessor + tangDevAttrAsyncEngineCount, //!< asyncEngineCount + tangDevAttrMemoryClockRate, //!< memoryClockRate + tangDevAttrGlobalMemoryBusWidth, //!< memoryBusWidth + tangDevAttrL2CacheSize, //!< l2CacheSize + tangDevAttrMaxThreadsPerMultiProcessor, //!< maxThreadsPerMultiProcessor + tangDevAttrGlobalL1CacheSupported, //!< globalL1CacheSupported + tangDevAttrLocalL1CacheSupported, //!< localL1CacheSupported + tangDevAttrMaxSharedMemoryPerMultiprocessor,//!< sharedMemPerMultiprocessor + tangDevAttrMaxRegistersPerMultiprocessor, //!< regsPerMultiprocessor + tangDevAttrStreamPrioritiesSupported, //!< streamPrioritiesSupported + tangDevAttrConcurrentKernels, //!< concurrentKernels + tangDevAttrComputePreemptionSupported, //!< computePreemptionSupported + tangDevAttrKernelExecTimeout, //!< kernelExecTimeoutEnabled + tangDevAttrEccEnabled, //!< ECCEnabled + tangDevAttrMaxAccessPolicyWindowSize, //!< accessPolicyMaxWindowSize + tangDevAttrTccDriver, //!< tccDriver + tangDevAttrSingleToDoublePrecisionPerfRatio,//!< singleToDoublePrecisionPerfRatio + tangDevAttrCooperativeLaunch, //!< cooperativeLaunch + tangDevAttrCooperativeMultiDeviceLaunch, //!< cooperativeMultiDeviceLaunch + tangDevAttrMaxPersistingL2CacheSize, //!< persistingL2CacheMaxSize + tangDevAttrCanMapHostMemory, //!< canMapHostMemory + tangDevAttrUnifiedAddressing, //!< unifiedAddressing + tangDevAttrManagedMemory, //!< managedMemory + tangDevAttrConcurrentManagedAccess, //!< concurrentManagedAccess + tangDevAttrDirectManagedMemAccessFromHost, //!< directManagedMemAccessFromHost + tangDevAttrPageableMemoryAccess, //!< pageableMemoryAccess + tangDevAttrPageableMemoryAccessUsesHostPageTables, //!< pageableMemoryAccessUsesHostPageTables + tangDevAttrCanUseHostPointerForRegisteredMem, //!< canUseHostPointerForRegisteredMem + tangDevAttrHostNativeAtomicSupported, //!< hostNativeAtomicSupported + tangDevAttrCanFlushRemoteWrites, //!< canFlushRemoteWrites + tangDevAttrGpuOverlap, //!< gpuOverlap + tangDevAttrIntegrated, //!< integrated + tangDevAttrMaxSharedMemoryPerBlockOptin, //!< maxSharedMemoryPerBlockOptin + tangDevAttrGPUDirectRDMASupported, //!< gpuDirectRDMASupported + tangDevAttrGPUDirectRDMAFlushWritesOptions, //!< gpuDirectRDMAFlushWritesOptions + tangDevAttrGPUDirectRDMAWritesOrdering, //!< gpuDirectRDMAWritesOrdering + tangDevAttrComputeCapabilityMajor, //!< major + tangDevAttrComputeCapabilityMinor, //!< minor + tangDevAttrPciBusId, //!< pciBusID + tangDevAttrPciDeviceId, //!< pciDeviceID + tangDevAttrPciDomainId, //!< pciDomainID + tangDevAttrIsMultiGpuBoard, //!< isMultiGpuBoard + tangDevAttrMultiGpuBoardGroupID, //!< multiGpuBoardGroupID + tangDevAttrComputeMode, //!< computeMode + tangDevAttrReservedSharedMemoryPerBlock, //!< reservedSharedMemoryPerBlock + tangDevAttrSparseTangArraySupported, //!< sparseTangArraySupported + tangDevAttrHostRegisterSupported, //!< hostRegisterSupported + tangDevAttrHostRegisterReadOnlySupported, //!< hostRegisterReadOnlySupported + tangDevAttrMemoryPoolsSupported, //!< memoryPoolsSupported + tangDevAttrMemoryPoolSupportedHandleTypes, //!< memoryPoolSupportedHandleTypes + tangDevAttrMax +} tangDeviceAttr; + +#ifndef TANGRT_DEVICE_P2P_ATTR_ENUM +#define TANGRT_DEVICE_P2P_ATTR_ENUM +/** + * TANG Device P2P attributes + */ +enum tangDeviceP2PAttr { + tangDevP2PAttrPerformanceRank = 1, + tangDevP2PAttrAccessSupported = 2, + tangDevP2PAttrNativeAtomicSupported = 3, + tangDevP2PAttrTangArrayAccessSupported = 4, +}; +#endif // TANGRT_DEVICE_P2P_ATTR_ENUM + +/** + * + * TANG device properties + * Inconsistent with cudaDeviceProp + * Most field unused. + * Need check. + */ +typedef struct tangDeviceProp { + char name[256]; ///< Device name. + char uuid[16]; ///< a 16-byte unique identifier + uint64_t totalGlobalMem; ///< size of global memory region (in bytes). + int sharedMemPerBlock; ///< the maximum amount of shared memory + ///< available to a thread block in bytes. + int regsPerBlock; ///< the maximum number of 32-bit registers available to a + ///< thread block. + int warpSize; ///< the warp size in threads. + int memPitch; ///< the maximum pitch in bytes allowed by the memory copy + ///< functions + int maxThreadsPerBlock; ///< the maximum number of threads per block. + int maxThreadsDim[3]; ///< Max number of threads in each dimension (XYZ) of a + ///< block. + int maxGridSize[3]; ///< Max grid dimensions (XYZ). + int clockRate; ///< Max clock frequency of the multiProcessors in khz. + int totalConstMem; ///< the total amount of constant memory available on + ///< the device in bytes. + int multiProcessorCount; ///< Number of multi-processors (compute units). + int maxBlocksPerMultiProcessor; ///< the number of multiprocessors on the + ///< device + int asyncEngineCount; ///< + int memoryClockRate; ///< Max global memory clock frequency in khz. + int memoryBusWidth; ///< Global memory bus width in bits. + int l2CacheSize; ///< L2 cache size. + int maxThreadsPerMultiProcessor; ///< Maximum resident threads per + ///< multi-processor. + int globalL1CacheSupported; ///< whether the device supports caching of + ///< globals in L1 cache + int localL1CacheSupported; ///< whether the device supports caching of locals + ///< in L1 cache + int sharedMemPerMultiprocessor; ///< Maximum Shared Memory Per Multiprocessor. + int regsPerMultiprocessor; ///< the maximum amount of shared memory available + ///< to a multiprocessor in bytes + int streamPrioritiesSupported; ///< whether the device supports stream + ///< priorities + int concurrentKernels; ///< Device can possibly execute multiple kernels + ///< concurrently. + int computePreemptionSupported; ///< whether the device supports Compute + ///< Preemption + int kernelExecTimeoutEnabled; ///< Run time limit for kernels executed on the + ///< device + int ECCEnabled; ///< Device has ECC support enabled + int accessPolicyMaxWindowSize; ///< the maximum value of + ///< tangAccessPolicyWindow::num_bytes + int tccDriver; ///< whether device is a Tesla device using TCC driver + int singleToDoublePrecisionPerfRatio; ///< the ratio of single precision + ///< performance (in floating-point + ///< operations per second) to double + ///< precision performance + int cooperativeLaunch; ///< whether the device supports launching cooperative + ///< kernels via tangLaunchCooperativeKernel + int cooperativeMultiDeviceLaunch; ///< whether the device supports launching + ///< cooperative kernels via + ///< tangLaunchCooperativeKernelMultiDevice + int persistingL2CacheMaxSize; ///< L2 cache's maximum persisting lines size + ///< in bytes + int canMapHostMemory; ///< Check whether TANG can map host memory + int unifiedAddressing; ///< whether the device shares a unified address space + ///< with the host and 0 otherwise + int managedMemory; ///< whether the device supports allocating managed memory + ///< on this system, or 0 if it is not supported + int concurrentManagedAccess; ///< whether the device can coherently access + ///< managed memory concurrently with the CPU + int directManagedMemAccessFromHost; ///< whether the host can directly access + ///< managed memory on the device without + ///< migration + int pageableMemoryAccess; ///< whether the device supports coherently + ///< accessing pageable memory without calling + ///< tangHostRegister on it + int pageableMemoryAccessUsesHostPageTables; ///< whether the device accesses + ///< pageable memory via the + ///< host's page tables + int canUseHostPointerForRegisteredMem; ///< whether the device can access + ///< host registered memory at the + ///< same virtual address as the CPU + + int hostNativeAtomicSupported; ///< Link between the device and the host + ///< supports native atomic operations + int canFlushRemoteWrites; ///< Device supports flushing of outstanding remote + ///< writes + int gpuOverlap; ///< Device can possibly copy memory and execute a kernel + ///< concurrently + int integrated; ///< Device is integrated with host memory + int maxSharedMemoryPerBlockOptin; ///< The maximum optin shared memory per + ///< block. This value may vary by chip. + ///< See ::tangFuncSetAttribute + int gpuDirectRDMASupported; ///< Device supports GPUDirect RDMA APIs + int gpuDirectRDMAFlushWritesOptions; ///< The returned attribute shall be + ///< interpreted as a bitmask, where the + ///< individual bits are listed in the + ///< ::tangFlushGPUDirectRDMAWritesOptions + ///< enum + int gpuDirectRDMAWritesOrdering; ///< GPUDirect RDMA writes to the device do + ///< not need to be flushed for consumers + ///< within the scope indicated by the + ///< returned attribute. See + ///< ::tangGPUDirectRDMAWritesOrdering for + ///< the numerical values returned here. + int major; ///< the major revision numbers defining the device's compute + ///< capability + int minor; ///< the minor revision numbers defining the device's compute + ///< capability + int pciBusID; ///< PCI Bus ID. + int pciDeviceID; ///< PCI Device ID. + int pciDomainID; ///< PCI Domain ID + int isMultiGpuBoard; ///< whether device is on a multi-GPU board. + int multiGpuBoardGroupID; ///< a unique identifier for a group of devices + ///< associated with the same board + int computeMode; ///< the compute mode that the device is currently in + int reservedSharedMemoryPerBlock; ///< Shared memory reserved by TANG driver + ///< per block in bytes + int sparseTangArraySupported; ///< Device supports sparse arrays and sparse + ///< mipmapped arrays + int hostRegisterSupported; ///< Device supports host memory registration via + ///< ::tangHostRegister + int hostRegisterReadOnlySupported; ///< Device supports using the + ///< ::tangHostRegister flag + ///< tangHostRegisterReadOnly to register + ///< memory that must be mapped as + ///< read-only to the GPU + int memoryPoolsSupported; ///< Device supports using the ::tangMallocAsync + ///< and ::tangMemPool family of APIs + int memoryPoolSupportedHandleTypes; ///< Handle types supported with mempool + ///< based IPC +} __attribute__((packed)) tangDeviceProp; + +/** + * TANG launch parameters + */ +/* +struct __device_builtin__ tangLaunchParams { + void* func; ///< Device function symbol + dim3 gridDim; ///< Grid dimentions + dim3 blockDim; ///< Block dimentions + void **args; ///< Arguments + int sharedMem; ///< Shared memory + tangStream_t stream; ///< Stream identifier +}; +*/ +/******************************************************************************* + * * + * SHORTHAND TYPE DEFINITION USED BY RUNTIME API * + * * + *******************************************************************************/ + +/** + * TANG Error types + */ +typedef enum tangError tangError_t; + +/** + * @brief TANG Device + * @sa TAdevice + */ +typedef struct TAdevice_s* tangDevice_t; + +/** + * @brief TANG context + * @sa TAcontext + */ +typedef struct TActx_s* tangContext_t; + +/** + * @brief TANG stream + * @sa TAstream + */ +typedef struct TAstream_s* tangStream_t; + +/** + * @brief TANG event + * @sa TAevent + */ +typedef struct TAevent_s* tangEvent_t; + +/** + * @brief TANG function + * @sa TAfunction + */ +typedef struct TAfunc_s* tangFunction_t; + +/** + * @brief TANG graph & executable graph handle + * @sa tangStreamBeginCapture + * @sa tangStreamEndCapture + * @sa tangGraphLaunch + * @sa tangGraphInstantiate + */ +typedef struct TAgraph_s* tangGraph_t; +typedef struct TAgraphExec_s* tangGraphExec_t; +typedef struct TAgraphNode_s* tangGraphNode_t; + +typedef void (*tangHostFn_t)(void* userData); + +typedef struct tangHostNodeParams_s { + tangHostFn_t fn; + void* userData; +} tangHostNodeParams; + +typedef struct tangKernelNodeParams_s { + void* func; /**< Kernel to launch */ + dim3 gridDim; /**< Grid dimensions */ + dim3 blockDim; /**< Block dimensions */ + + /**< Dynamic shared memory size per thread block in bytes */ + unsigned int sharedMemBytes; + + /**< Kernel parameters */ + void** kernelParams; + void** extra; +} tangKernelNodeParams; + +typedef struct tangGraphInfo_s { + int nr_nodes; +} tangGraphInfo; + +typedef struct tangEventTimestamp_s { + uint64_t comp; + uint64_t comp_sw; + uint64_t create; + uint64_t enqueue; + uint64_t writeq_beg; + uint64_t writeq_end; +} tangEventTimestamp; + +struct tangLanuchParams { + void* func; + dim3 gridDim; + dim3 blockDim; + void** args; + size_t sharedMemBytes; + tangStream_t stream; +}; + +/** + * @brief tangFatbinaryWrapper + * TANGCC will provides the following asm code on x86_64 platform. + * @code{.s} + * __tang_fatbin_wrapper: + * .long 0 # 0x0, version, 4 bytes + * .zero 4 # padding space, 4 bytes + * .quad .L_Z9vectorAddPfS_S_.11 # fatbin + * .zero 160 # data[20] + * .size __tang_fatbin_wrapper, 176 + * @endcode + */ +struct __tangFatbinaryWrapper { + int version; + const void* fatbin; + // The TANGCC does not reserve space for size. + // The size will be parsed from fatbin + // unsigned long size; + struct { + uintptr_t data[20]; + } dso; +}; + +#define TANG_IPC_HANDLE_SIZE 64U +#define TANG_IPC_MEM_HANDLE_SIZE 64U + +#define tangIpcMemLazyEnablePeerAccess 0x01 + +typedef struct tangIpcMemHandle_s { + unsigned long reserved[TANG_IPC_MEM_HANDLE_SIZE / sizeof(unsigned long)]; +} tangIpcMemHandle_t; + +typedef struct tangIpcEventHandle_s { + unsigned long reserved[TANG_IPC_HANDLE_SIZE / sizeof(unsigned long)]; +} tangIpcEventHandle_t; + +#endif //! _TANG_RT_DRIVER_TYPES_H_ diff --git a/third_party/sunrise/backend/include/tang_rt/fmt.hpp b/third_party/sunrise/backend/include/tang_rt/fmt.hpp new file mode 100755 index 000000000..1928ae095 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/fmt.hpp @@ -0,0 +1,1097 @@ +#ifndef _TANGRT_FMT_HPP_ +#define _TANGRT_FMT_HPP_ + +#include + +#include +#include +#include + +// #define PT_PRINTF_ENDMARKER +#define PT_PRINTF_READY + +#if !defined(__TANGC_MAJOR__) && !defined(__device__) +#define __device__ +#endif //!< __TANGC_MAJOR__ + +#ifndef unlikely +#define unlikely(x) __builtin_expected(!!(x), 0) +#endif //!< unlikely + +#ifndef lower_32_bit +#define lower_32_bit(x) (((uint64_t)x) & 0xffffffff) +#endif + +#ifndef upper_32_bit +#define upper_32_bit(x) (((uint64_t)x) >> 32) +#endif + +namespace tangrt { +namespace fmt { + +static constexpr unsigned kArgAlignment = 4; +static constexpr unsigned kArgAlignmentMask = ~(kArgAlignment - 1); + +enum ArgId { + ArgId_None = 0, + + ArgId_char = 1, + ArgId_schar = 2, + ArgId_uchar = 3, + + ArgId_short = 5, + ArgId_ushort = 6, + + ArgId_int = 7, + ArgId_uint = 8, + + ArgId_long = 9, + ArgId_ulong = 10, + + ArgId_longlong = 11, + ArgId_ulonglong = 12, + + ArgId_float = 13, + ArgId_double = 14, + ArgId_long_double = 15, + + //!< nullptr + ArgId_nullptr = 20, + + //!< generic pointer type. + ArgId_pointer = 21, + + //!< char* ptr = nullptr; + ArgId_char_nullptr = 22, + ArgId_schar_nullptr = 23, + ArgId_uchar_nullptr = 24, + + ArgId_char_pointer = 25, + ArgId_schar_pointer = 26, + ArgId_uchar_pointer = 27, + + //!< char[] + ArgId_char_array = 28, + ArgId_schar_array = 29, + ArgId_uchar_array = 30, +}; + +static inline const char* GetArgIdName(const int id) { +#define _case(x, str) \ + case x: \ + return str + + switch (id) { + _case(ArgId_char, "char"); + _case(ArgId_schar, "signed char"); + _case(ArgId_uchar, "unsigned char"); + + _case(ArgId_short, "short"); + _case(ArgId_ushort, "unsigned short"); + _case(ArgId_int, "int"); + _case(ArgId_uint, "unsigned int"); + _case(ArgId_long, "long"); + _case(ArgId_ulong, "unsigned long"); + _case(ArgId_longlong, "long long"); + _case(ArgId_ulonglong, "unsigned long long"); + + _case(ArgId_float, "float"); + _case(ArgId_double, "double"); + _case(ArgId_long_double, "long double"); + + _case(ArgId_nullptr, "nullptr"); + _case(ArgId_pointer, "pointer"); + + _case(ArgId_char_nullptr, "char nullptr"); + _case(ArgId_schar_nullptr, "schar nullptr"); + _case(ArgId_uchar_nullptr, "uchar nullptr"); + + _case(ArgId_char_pointer, "char pointer"); + _case(ArgId_schar_pointer, "schar pointer"); + _case(ArgId_uchar_pointer, "uchar pointer"); + + _case(ArgId_char_array, "char array"); + _case(ArgId_schar_array, "schar array"); + _case(ArgId_uchar_array, "uchar array"); + default: + return "None"; + } +#undef _case +} + +static __device__ inline unsigned ArgAlign(unsigned int x) { + return (x + kArgAlignment - 1) & ~(kArgAlignment - 1); +} + +static __device__ inline unsigned int ArgTraitsStrLen(const char* s) { + const char* p = s; + while (*p) { + ++p; + } + return p - s; +} + +namespace detail { + +union u32c4_u { + char c[4]; + uint32_t u32; + + __device__ u32c4_u(char ch0, char ch1 = 0, char ch2 = 0, char ch3 = 0) + : c{ch0, ch1, ch2, ch3} {} +}; + +template +__device__ void FundamentalFillImpl(const uint32_t id, + const T t, + uint32_t* buf, + uint32_t& pos, + const uint32_t mask) { + static_assert(N == 1 || N == 2 || N == 4, ""); + static_assert(!std::is_same::value, + "This helper function is not suitable for char*."); + + union { + uint32_t u[N]; + T t; + } x; + x.t = t; + buf[pos++ & mask] = (sizeof(T) << 16) | id; + for (unsigned int i = 0; i < N; ++i) { + buf[pos++ & mask] = x.u[i]; + } +} + +template +__device__ void FundamentalFill(const uint32_t id, + const T t, + uint32_t* buf, + uint32_t& pos, + const uint32_t mask) { + FundamentalFillImpl(id, t, buf, pos, mask); +} + +#if 0 +template <> +__device__ void FundamentalFill(const uint32_t id, const float t, uint32_t* buf, + uint32_t& pos, const uint32_t mask) +{ + buf[pos++ & mask] = (sizeof(float) << 16) | id; + union + { + uint32_t u[2]; + float f; + } x; + x.f = t; + buf[pos++ & mask] = x.u[0]; +} + +//!< to avoid strict aliasing +//!< -fno-strict-aliasing +template <> +__device__ void FundamentalFill(const uint32_t id, const double t, uint32_t* buf, + uint32_t& pos, const uint32_t mask) +{ + buf[pos++ & mask] = (sizeof(double) << 16) | id; + union + { + uint32_t u[2]; + double d; + } x; + x.d = t; + buf[pos++ & mask] = x.u[0]; + buf[pos++ & mask] = x.u[1]; +} +#endif + +static __device__ inline void StringFillData(const char* s, + const uint32_t sizeBytes, + uint32_t* const buf, + uint32_t& pos, + const uint32_t mask) { + const uint32_t* s32 = (const uint32_t*)s; + for (unsigned int i = 0; i < sizeBytes / sizeof(uint32_t); ++i) { + buf[pos++ & mask] = s32[i]; + } + switch (sizeBytes & (4 - 1)) { + case 1: { + detail::u32c4_u x(s[sizeBytes - 1]); + buf[pos++ & mask] = x.u32; + break; + }; + case 2: { + detail::u32c4_u x(s[sizeBytes - 2], s[sizeBytes - 1]); + buf[pos++ & mask] = x.u32; + break; + }; + case 3: { + detail::u32c4_u x(s[sizeBytes - 3], s[sizeBytes - 2], s[sizeBytes - 1]); + buf[pos++ & mask] = x.u32; + break; + } + } +} // StringFillData function + +} // namespace detail + +template +struct FmtTraits; + +template +struct FmtTraits { + static __device__ unsigned int FmtLength(const char*) { + return 4 + ArgAlign(N); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + //!< Fill header + buf[pos++ & mask] = (N << 16) | ArgId_schar_array; + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template +struct FmtTraits { + static __device__ unsigned int FmtLength(const char*) { + return 4 + ArgAlign(N); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + //!< Fill header + buf[pos++ & mask] = (N << 16) | ArgId_schar_array; + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template <> +struct FmtTraits { + static __device__ unsigned int FmtLength(const char* s) { + return 4 + (s ? ArgAlign(ArgTraitsStrLen(s)) : 0); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + if (!s) { + buf[pos++ & mask] = ArgId_char_nullptr; + return; + } + const uint32_t N = ArgTraitsStrLen(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template <> +struct FmtTraits { + static __device__ unsigned int FmtLength(const char* s) { + return 4 + (s ? ArgAlign(ArgTraitsStrLen(s)) : 0); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t pos, + uint32_t mask) { + if (!s) { + buf[pos++ & mask] = ArgId_char_nullptr; + return; + } + const uint32_t N = ArgTraitsStrLen(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template +struct ArgTraits; + +// template +// struct ArgTraits { +// static const int id = ArgId_None; +// +// static unsigned int ArgLength(T&& t) { +// return 4 + sizeof(T); +// } +//}; + +template <> +struct ArgTraits { + static const int id = ArgId_char; + + static __device__ unsigned int ArgLength(const signed char t) { + return ArgAlign(4); + } + + static __device__ void Fill(const char ch, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = ch << 16 | id; + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_schar; + + static __device__ unsigned int ArgLength(const signed char t) { + return ArgAlign(sizeof(t)); + } + + static __device__ void Fill(const signed char ch, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = ch << 16 | id; + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_uchar; + + static __device__ unsigned int ArgLength(const unsigned char t) { + return ArgAlign(sizeof(t)); + } + + static __device__ void Fill(const unsigned char ch, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = ch << 16 | id; + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_short; + + static __device__ unsigned int ArgLength(const short int s) { + return ArgAlign(sizeof(s)); + } + + static __device__ void Fill(const short int ch, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = (ch << 16) | id; + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_ushort; + + static __device__ unsigned int ArgLength(const unsigned short int s) { + return ArgAlign(sizeof(s)); + } + + static __device__ void Fill(const unsigned short ch, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = (ch << 16) | id; + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_int; + + static __device__ unsigned int ArgLength(const int i) { + return 4 + sizeof(i); + } + + static __device__ void Fill(const int s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_int, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_uint; + + static __device__ unsigned int ArgLength(const unsigned int i) { + return 4 + sizeof(i); + } + + static __device__ void Fill(const unsigned int s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_uint, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const long id = ArgId_long; + + static __device__ unsigned int ArgLength(const long l) { + return 4 + sizeof(l); + } + + static __device__ void Fill(const long s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_long, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_ulong; + + static __device__ unsigned int ArgLength(const unsigned long l) { + return 4 + sizeof(l); + } + + static __device__ void Fill(const unsigned long s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_ulong, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const long id = ArgId_longlong; + + static __device__ unsigned int ArgLength(const long long l) { + return 4 + sizeof(l); + } + + static __device__ void Fill(const long long s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_longlong, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_ulonglong; + + static __device__ unsigned int ArgLength(const unsigned long long l) { + return sizeof(l) + 4; + } + + static __device__ void Fill(const unsigned long long s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_ulonglong, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_float; + + static __device__ unsigned int ArgLength(const float l) { + return sizeof(l) + 4; + } + + static __device__ void Fill(const float s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_float, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_double; + + static __device__ unsigned int ArgLength(const double d) { + return sizeof(d) + 4; + } + + static __device__ void Fill(const double s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_double, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_long_double; + + static __device__ unsigned int ArgLength(const long double d) { + return sizeof(d) + 4; + } + + static __device__ void Fill(const long double s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_long_double, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_nullptr; + + static __device__ unsigned int ArgLength(...) { return 4; } + + static __device__ void Fill(const std::nullptr_t, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = ArgId_nullptr; + } +}; + +template +struct ArgTraits { + static const int id = ArgId_pointer; + + static __device__ unsigned int ArgLength(...) { return 4 + sizeof(T*); } + + static __device__ void Fill(T* const s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_pointer, s, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_char_pointer; + + static __device__ unsigned int ArgLength(const char* s) { + return 4 + (s ? ArgAlign(ArgTraitsStrLen(s)) + 8 : 0); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + if (!s) { + buf[pos++ & mask] = ArgId_char_nullptr; + return; + } + const uint32_t N = ArgTraitsStrLen(s); + buf[pos++ & mask] = (N << 16) | id; + buf[pos++ & mask] = lower_32_bit(s); + buf[pos++ & mask] = upper_32_bit(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_char_pointer; + + static __device__ unsigned int ArgLength(const char* s) { + return 4 + (s ? ArgAlign(ArgTraitsStrLen(s)) + 8 : 0); + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + if (!s) { + buf[pos++ & mask] = ArgId_char_nullptr; + return; + } + const uint32_t N = ArgTraitsStrLen(s); + buf[pos++ & mask] = (N << 16) | id; + buf[pos++ & mask] = lower_32_bit(s); + buf[pos++ & mask] = upper_32_bit(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template +struct ArgTraits { + static const int id = ArgId_char_array; + + static __device__ unsigned int ArgLength(const char* s) { + return 4 + ArgAlign(N) + 8; + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = (N << 16) | id; + buf[pos++ & mask] = lower_32_bit(s); + buf[pos++ & mask] = upper_32_bit(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template +struct ArgTraits { + static const int id = ArgId_schar_array; + + static __device__ unsigned int ArgLength(const char* s) { + return 4 + ArgAlign(N) + 8; + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = (N << 16) | id; + buf[pos++ & mask] = lower_32_bit(s); + buf[pos++ & mask] = upper_32_bit(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +template +struct ArgTraits { + static const int id = ArgId_uchar_array; + + static __device__ unsigned int ArgLength(const char* s) { + return 4 + ArgAlign(N) + 8; + } + + static __device__ void Fill(const char* s, + uint32_t* buf, + uint32_t& pos, + uint32_t mask) { + buf[pos++ & mask] = (N << 16) | id; + buf[pos++ & mask] = lower_32_bit(s); + buf[pos++ & mask] = upper_32_bit(s); + detail::StringFillData(s, N, buf, pos, mask); + } +}; + +// template +// struct ArgTraits { +// using type = T; +// +// static const int id = ArgId_any_array; +//}; + +// template +// struct ArgTraits { +// static const int id = ArgId_uchar_array; +//}; + +namespace detail { +template +__device__ unsigned SumArgsLength(Args&&... args); + +template <> +__device__ constexpr unsigned SumArgsLength() { + return 0; +} + +template +__device__ unsigned SumArgsLength(T&& t, Args&&... args) { + typedef typename std::remove_reference::type _type; + typedef typename std::remove_cv<_type>::type type; + + return ArgTraits::ArgLength(std::forward(t)) + + SumArgsLength(std::forward(args)...); +} + +struct FifoInfo { + uint32_t get; + uint32_t put; + + //!< num words + uint32_t size; + + uint32_t fifoSize; + uint64_t fifoAddress; +}; + +template +struct ArgTraitsHasFillFifoInfo { + template + static auto Check(int) -> decltype(&U::FillFifoInfo); + + template + static void Check(...); + + static const bool value = !std::is_same(0)), void>::value; +}; + +template +struct ArgTraitsFillProxy; + +template +struct ArgTraitsFillProxy { + template + __device__ static void Fill(U&& u, + uint32_t* fifobuf, + const FifoInfo& msgInfo, + uint32_t& pos, + uint32_t& mask) { + ArgTraits::FillFifoInfo(std::forward(u), fifobuf, msgInfo, pos, mask); + } +}; + +template +struct ArgTraitsFillProxy { + template + __device__ static void Fill(U&& u, + uint32_t* fifobuf, + const FifoInfo& msgInfo, + uint32_t& pos, + uint32_t& mask) { + ArgTraits::Fill(std::forward(u), fifobuf, pos, mask); + } +}; + +template +__device__ void FillArgs(uint32_t* fifobuf, + const FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask, + Args&&... args); + +template <> +__device__ inline void FillArgs(uint32_t* fifobuf, + const FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) {} + +template +__device__ void FillArgs(uint32_t* fifobuf, + const FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask, + T&& t, + Args&&... args) { + typedef typename std::remove_reference::type _type; + typedef typename std::remove_cv<_type>::type type; + + ArgTraitsFillProxy>::value>:: + Fill(t, fifobuf, msgInfo, pos, mask); + //ArgTraits::Fill(t, fifobuf, pos, mask); + FillArgs(fifobuf, msgInfo, pos, mask, std::forward(args)...); +} + +template +struct CountOfArgs; + +template <> +struct CountOfArgs<> { + static const int value = 0; +}; + +template +struct CountOfArgs { + static const int value = CountOfArgs::value + 1; +}; + +} // namespace detail + +namespace debug { + +//!< The get the __pt_printf load. +struct MsgGet {}; + +//!< The begin position of the current __pt_printf. +struct MsgBeg {}; + +//!< The number words the current __pt_printf will consume. +struct MsgSize {}; + +//!< Print the begin address of the current __pt_printf fifo. +struct FifoAddress {}; + +//!< Print the size of the print fifo. +struct FifoSize {}; + +} // namespace debug + +template <> +struct ArgTraits { + static const int id = ArgId_uint; + + __device__ static unsigned int ArgLength(debug::MsgBeg put) { return 8; } + + __device__ static void FillFifoInfo(debug::MsgBeg, + uint32_t* fifobuf, + const detail::FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_uint, msgInfo.put, fifobuf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_uint; + + __device__ static unsigned int ArgLength(debug::MsgGet put) { return 8; } + + __device__ static void FillFifoInfo(debug::MsgGet, + uint32_t* fifobuf, + const detail::FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_uint, msgInfo.get, fifobuf, pos, mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_uint; + + __device__ static unsigned int ArgLength(debug::MsgSize s) { return 8; } + + __device__ static void FillFifoInfo(debug::MsgSize, + uint32_t* fifobuf, + const detail::FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_uint, + msgInfo.size, + fifobuf, + pos, + mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_uint; + + __device__ static unsigned int ArgLength(debug::FifoSize s) { return 8; } + + __device__ static void FillFifoInfo(debug::FifoSize, + uint32_t* fifobuf, + const detail::FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_uint, + msgInfo.fifoSize, + fifobuf, + pos, + mask); + } +}; + +template <> +struct ArgTraits { + static const int id = ArgId_pointer; + + __device__ static unsigned int ArgLength(debug::FifoAddress s) { return 12; } + + __device__ static void FillFifoInfo(debug::FifoAddress, + uint32_t* fifobuf, + const detail::FifoInfo& msgInfo, + uint32_t& pos, + uint32_t mask) { + detail::FundamentalFill(ArgId_pointer, + msgInfo.fifoAddress, + fifobuf, + pos, + mask); + } +}; + +struct fifo { + unsigned int put __attribute__((aligned(128))); + unsigned int mask; + uint32_t* data; + + unsigned int ready __attribute__((aligned(128))); + + unsigned int get __attribute__((aligned(128))); +}; + +extern "C" struct fifo __ptPrintfFifo; + +inline __device__ struct fifo* __ptSelectPrintfFifo(void) { + //unsigned int bidx = threadIdx.z * (blockDim.x * blockDim.y) + + // threadIdx.y * blockDim.x + threadIdx.x; + //return &__ptPrintfFifo + (widx / 32) & 0x01; +#ifdef __TANGC_MAJOR__ + return &__ptPrintfFifo + (__phywarpid() & 0x01); +#else + return &__ptPrintfFifo; +#endif //!< __TANGC_MAJOR__ +} + +inline __device__ uint32_t __ptPrintfFifoAlloc(struct fifo* fifo, + uint32_t n, + uint32_t* pGet) { + [[maybe_unused]] unsigned int tmp; + unsigned int newPut; + unsigned int oldPut; + unsigned int avail; + unsigned int get; + unsigned int mask = fifo->mask; + + if (n >= mask) { + return std::numeric_limits::max(); + } + do { +#ifdef __TANGC_MAJOR__ + oldPut = *((volatile unsigned int*)&fifo->put); + // oldPut = __ldcg(&fifo->put); +#else + oldPut = __atomic_load_n(&fifo->put, __ATOMIC_RELAXED); +#endif //!< __TANGC_MAJOR__ + +//#ifdef __TANGC_MAJOR__ +// __threadfence_memory(); +//#endif //!< __TANGC_MAJOR__ + +#ifdef __TANGC_MAJOR__ + // get = *((volatile unsigned int*)&fifo->get); + get = __ldcg(&fifo->get); +#else + get = __atomic_load_n(&fifo->get, __ATOMIC_RELAXED); +#endif //!< __TANGC_MAJOR__ + + // avail = mask - ((oldPut - get) & mask); + avail = (get - oldPut - 1) & mask; + if (avail < n) { + return std::numeric_limits::max(); + } + newPut = oldPut + n; + // newPut = (oldPut + n) & mask; +#ifdef __TANGC_MAJOR__ + tmp = atomicCAS(&fifo->put, oldPut, newPut); + // Compiler group provides this solution. +# if 1 + asm volatile("loop 1, 0, 1, 500000\n\tnop"); +# else + __stvm_bar_sync0(); +# endif + } while (oldPut != tmp); +#else + } while (!__atomic_compare_exchange_n(&fifo->put, + &oldPut, + newPut, + true, + __ATOMIC_RELAXED, + __ATOMIC_RELAXED)); +#endif //!< __TANGC_MAJOR__ + + *pGet = get; + return oldPut; +} + +inline __device__ void __ptPrintfFifoUpdateReady(struct fifo* fifo, + uint32_t orig_pos, + uint32_t pos) { + [[maybe_unused]] unsigned int tmp; + unsigned int oldReady = orig_pos; +#ifdef __TANGC_MAJOR__ + //! Make sure all writes before this call happens before + //! all writes after this call. + //! __syncthreads(); + __threadfence_memory(); +#endif //!< __TANGC_MAJOR__ + do { +#ifdef __TANGC_MAJOR__ + tmp = atomicCAS(&fifo->ready, oldReady, pos); +# if 1 + asm volatile("loop 1, 0, 1, 500000\n\tnop"); +# else + __stvm_bar_sync0(); +# endif + } while (oldReady != tmp); +#else + } while (!__atomic_compare_exchange_n(&fifo->ready, + &oldReady, + pos, + true, + __ATOMIC_RELEASE, + __ATOMIC_RELAXED)); +#endif //!< __TANGC_MAJOR__ +} + +template +__device__ void __pt_printf(Fmt&& fmt, Args&&... args) { + typedef typename std::remove_reference::type _fmt_type; + typedef typename std::remove_cv<_fmt_type>::type fmt_type; + + auto fifo = __ptSelectPrintfFifo(); + + // struct fifo = __ptPrintfFifo; + // numWords: the number bytes of the message; + // countOfArgs: the number of args + // fmt: fmt data + // arg[numArgs]: arg data + // endMarker: maybe not required + + unsigned int numFmtWords = FmtTraits::FmtLength(fmt) / 4; + unsigned int numArgWords = + detail::SumArgsLength(std::forward(args)...) / 4; + + static_assert(detail::CountOfArgs::value == sizeof...(Args), ""); + + // unsigned int countOfArgs = detail::CountOfArgs::value; + unsigned int countOfArgs = sizeof...(Args); + +#ifdef PT_PRINTF_ENDMARKER + unsigned int numMsgWords = 3 + numFmtWords + numArgWords; +#else + unsigned int numMsgWords = 2 + numFmtWords + numArgWords; +#endif //!< PT_PRINTF_ENDMARKER + + // align message to 128byte, 32uint32_t boundary. +#if 0 + numMsgWords = (numMsgWords + 31) & ~31; +#endif + + uint32_t get; + uint32_t pos = __ptPrintfFifoAlloc(fifo, numMsgWords, &get); + if (pos == std::numeric_limits::max()) { + return; + } + + const detail::FifoInfo fifoInfo = { + .get = get, + .put = pos, + .size = numMsgWords, + .fifoSize = fifo->mask + 1, + .fifoAddress = (uint64_t)fifo->data, + }; + + uint32_t const orig_pos = pos; + uint32_t const mask = fifo->mask; + uint32_t* const fifobuf = (uint32_t*)fifo->data; + + fifobuf[pos++ & mask] = numMsgWords; + fifobuf[pos++ & mask] = countOfArgs; + + FmtTraits::Fill(std::forward(fmt), fifobuf, pos, mask); + detail::FillArgs(fifobuf, fifoInfo, pos, mask, std::forward(args)...); + +#ifdef PT_PRINTF_ENDMARKER + fifobuf[pos++ & mask] = numMsgWords; +#endif //!< PT_PRINTF_ENDMARKER +#if 0 + while ((pos - orig_pos) < numMsgWords) { + fifobuf[pos++ & mask] = orig_pos; + } +#endif + + __ptPrintfFifoUpdateReady(fifo, orig_pos, (orig_pos + numMsgWords)); +} + +} // namespace fmt +} // namespace tangrt + +using tangrt::fmt::__pt_printf; + +#endif //!< _TANGRT_FMT_HPP_ diff --git a/third_party/sunrise/backend/include/tang_rt/host_defines.h b/third_party/sunrise/backend/include/tang_rt/host_defines.h new file mode 100755 index 000000000..f7639fb98 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/host_defines.h @@ -0,0 +1,101 @@ +/* +Copyright declaration. +*/ + +#ifndef _TANG_RT_INCLUDE_HOST_DEFINES_H_ +#define _TANG_RT_INCLUDE_HOST_DEFINES_H_ +#include +#include + +#ifdef __cplusplus +#define __dparm(x) = x +#else +#define __dparm(x) +#endif + +#ifdef __TANGC_MAJOR__ + +#ifndef __device__ +#define __device__ __Tdevice__ +#define __Tdevice__ __attribute__((Tdevice)) +#endif + +#ifndef __global__ +#define __global__ __Tglobal__ +#define __Tglobal__ __attribute__((Tglobal)) +#endif + +#ifndef __constant__ +#define __constant__ __Tconstant__ +#define __Tconstant__ __attribute__((Tconstant)) +#endif + +#ifndef __host__ +#define __host__ __Thost__ +#define __Thost__ __attribute__((Thost)) +#endif + +#ifndef __shared__ +#define __shared__ __Tshared__ +#define __Tshared__ __attribute__((Tshared)) +#endif + +#ifndef __forceinline__ +#define __forceinline__ __inline__ __attribute__((always_inline)) +#endif + +#endif + +#if defined(_MSC_VER) +#define TANGRT_DEPRECATED __declspec(deprecated) +#define TANGRT_API_EXPORT __declspec(dllexport) +#define TANGRT_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TANG_DEPRECATED __attribute__((deprecated)) +#define TANG_API_EXPORT __attribute__((visibility("default"))) +#define TANG_API_IMPORT __attribute__((visibility("default"))) +#else +#define TANG_DEPRECATED +#define TANG_API_EXPORT +#define TANG_API_IMPORT +#endif // unknown compiler, may needs extra care. + +#if defined(tangrt_shared_EXPORTS) +#define TANGRT_API_PUBLIC TANG_API_EXPORT +#elif !defined(__TANGRT_API_VERSION_INTERNAL) +#define TANGRT_API_PUBLIC TANG_API_IMPORT +#else +#define TANGRT_API_PUBLIC +#endif + +/************************************************** + * _ptds: Per-Thread-Default-Stream API use ptds to + * run commands. + * _ptsz suffix: Per-Thread-Stream-Zero API use ptds to + * run commands when the given stream is null. + * See the following code for details: + * @code + * tangError_t tangMemcpyAsync_ptsz(..., tangStream_t stream) { + * return tangMemcpyAsyncImpl(..., stream ? stream : TA_STREAM_PER_THREAD); + * } + * tangError_t tangMemcpyAsync(..., tangStream_t stream) { + * return tangMemcpyAsyncImpl(..., stream ? stream : TA_STREAM_LEGACY); + * } + * @endcode + **************************************************/ +#if defined(__TANGRT_API_PER_THREAD_DEFAULT_STREAM) +#define __TANGRT_API_PTDS(api) api##_ptds +#define __TANGRT_API_PTSZ(api) api##_ptsz +#else +#define __TANGRT_API_PTDS(api) api +#define __TANGRT_API_PTSZ(api) api +#endif //! __TANGRT_API_PER_THREAD_DEFAULT_STREAM + +#if defined(__TANGRT_API_VERSION_INTERNAL) +#undef __TANGRT_API_PTDS +#undef __TANGRT_API_PTSZ +#define __TANGRT_API_PTDS(api) api +#define __TANGRT_API_PTSZ(api) api +#endif // __TANGRT_API_VERSION_INTERNAL + +#endif //! _TANG_RT_INCLUDE_HOST_DEFINES_H_ diff --git a/third_party/sunrise/backend/include/tang_rt/vector_types.h b/third_party/sunrise/backend/include/tang_rt/vector_types.h new file mode 100755 index 000000000..e4aca338c --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/vector_types.h @@ -0,0 +1,35 @@ +/* +Copyright declaration. +*/ + +// cuda/include/vector_types.h + +#ifndef _TANG_RT_INCLUDE_VECTOR_TYPES_H_ +#define _TANG_RT_INCLUDE_VECTOR_TYPES_H_ + +#include "tang_rt/host_defines.h" + +/** + * Struct for data in 3D + */ + +#if defined(__DIM3_TYPE__) +typedef dim3 __DIM3_TYPE__; +#else +typedef struct dim3 { + unsigned x; ///< x + unsigned y; ///< y + unsigned z; ///< z +#ifdef __cplusplus +#if __cplusplus >= 201103L + constexpr dim3(unsigned _x = 1, unsigned _y = 1, unsigned _z = 1) + : x(_x), y(_y), z(_z) {} +#else + dim3(unsigned _x = 1, unsigned _y = 1, unsigned _z = 1) + : x(_x), y(_y), z(_z) {} +#endif //! __cplusplus >= 201103 +#endif //! __cplusplus +} dim3; +#endif //! no __DIM3_TYPE__ + +#endif //! _TANG_RT_INCLUDE_VECTOR_TYPES_H_ diff --git a/third_party/sunrise/backend/include/tang_rt/version.h b/third_party/sunrise/backend/include/tang_rt/version.h new file mode 100755 index 000000000..3ac846e76 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_rt/version.h @@ -0,0 +1,30 @@ +#ifndef _TANG_RUNTIME_VERSION_H_ +#define _TANG_RUNTIME_VERSION_H_ +#define TANG_VERSION_MAJOR 0 +#define TANG_VERSION_MINOR 13 +#define TANG_VERSION_PATCH 0 + +#define TANG_VERSION_GIT_SHA "" + +///////////////////////////////////////////////////////// + +#define TANGRT_VERSION_MAJOR 0 +#define TANGRT_VERSION_MINOR 13 +#define TANGRT_VERSION_PATCH 0 + +#define TANGRT_VERSION_GIT_SHA "04137493 Merge branch 'ln/bugfix/taStreamIsCapturing' into 'master'" + +///////////////////////////////////////////////////////// +#define TANGRT_TANGCC_VERSION_MAJOR 2 +#define TANGRT_TANGCC_VERSION_MINOR 2 + +#ifdef __TANGC_MAJOR__ +# if (TANGRT_TANGCC_VERSION_MAJOR <= 1) && (__TANGC_MAJOR__ >= 2) +#warning "the ptcc used is not compatible with the tang runtime library\nptcc less than 2.0.0 is required." +//#error "the ptcc used is not compatible with the tang runtime library\nptcc less than 2.0.0 is required." +//# elif (TANGRT_TANGCC_VERSION_MAJOR >= 2) && (__TANGC_MAJOR__ <= 1) +//#error "the ptcc used is not compatible with the tang runtime library\nptcc 2.0.0 or later is required." +# endif +#endif // __TANGC_MAJOR__ + +#endif //! _TANG_RUNTIME_VERSION_H_ diff --git a/third_party/sunrise/backend/include/tang_runtime.h b/third_party/sunrise/backend/include/tang_runtime.h new file mode 100755 index 000000000..3a58f7a02 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_runtime.h @@ -0,0 +1,32 @@ +#ifndef _TANG_RUNTIME_H_ +#define _TANG_RUNTIME_H_ +#include "tang_rt/version.h" +#include "tang_rt/driver_types.h" +#include "tang_rt/vector_types.h" +#include "tang_runtime_api.h" + +#ifndef TA_STREAM_LEGACY +#define TA_STREAM_LEGACY ((tangStream_t)0x01) +#endif //! TA_STREAM_LEGACY + +#ifndef TA_STREAM_PER_THREAD +#define TA_STREAM_PER_THREAD ((tangStream_t)0x02) +#endif //! TA_STREAM_PER_THREAD + +#ifndef tangStreamLegacy +#define tangStreamLegacy ((tangStream_t)0x01) +#endif //! tangStreamLegacy + +#ifndef tangStreamPerThread +#define tangStreamPerThread ((tangStream_t)0x02) +#endif //! tangStreamPerThread + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif //! _TANG_RUNTIME_H_ diff --git a/third_party/sunrise/backend/include/tang_runtime_api.h b/third_party/sunrise/backend/include/tang_runtime_api.h new file mode 100755 index 000000000..106c1b293 --- /dev/null +++ b/third_party/sunrise/backend/include/tang_runtime_api.h @@ -0,0 +1,1871 @@ +/* +Copyright declaration. +*/ +#ifndef _TANG_RUNTIME_API_H_ +#define _TANG_RUNTIME_API_H_ + +#include "tang_rt/driver_types.h" +#include "tang_rt/host_defines.h" +#include "tang_rt/vector_types.h" + +/** + * @brief Flags that can be used with tangStreamCreateWithFlags + * @{ + */ +#define tangStreamDefault 0x00 ///< Default stream creation flags +#define tangStreamNonBlocking \ + 0x01 ///< Stream does not implicitly synchronize with null stream + +//! Flags that can be used with tangEventCreateWithFlags: +#define tangEventDefault 0x0 ///< Default flags +#define tangEventBlockingSync \ + 0x1 ///< Waiting will yield CPU. Power-friendly and usage-friendly but may + ///< increase latency. +#define tangEventDisableTiming \ + 0x2 ///< Disable event's capability to record timing information. May + ///< improve performance. +#define tangEventInterprocess \ + 0x4 ///< Event can support IPC. @warning - not supported right now. + +//! Flags that can be used with tangStreamWaitEvent: +#define tangEventWaitDefault 0x00 ///< Default stream creation flags +#define tangEventWaitExternal \ + 0x01 ///< Event is captured in the graph as an external event node when + ///< performing stream capture. @warning - not supported right now. + +/** + * @brief enum values that can be used with tangStreamCreateWithPriority and + * tangDeviceGetStreamPriorityRange + * @{ + */ +enum stream_priority { + priority_high = -2, + priority_middle = -1, + priority_normal = 0, + priority_low = 1 +}; + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Version Management + * @{ + */ + +/** + * @brief Returns the TANG SDK Runtime version. + * + * @param [out] runtimeVersion + * + * @returns #tangSuccess, #tangErrorInavlidValue + * + * @warning The TANG SDK runtime version does not correspond to an exact CUDA + * SDK runtime revision. + * + * @see tangDriverGetVersion + */ +tangError_t TANGRT_API_PUBLIC tangRuntimeGetVersion(int* runtimeVersion); + +/** + * @brief Returns the TANG SDK Driver version. + * + * @param [out] driverVersion + * + * @returns #tangSuccess, #tangErrorInavlidValue + * + * @warning The TANG SDK driver veriosn does not correspond to an exact CUDA SDK + * driver revision. + * + * @see tangRuntimeGetVersion + */ +tangError_t TANGRT_API_PUBLIC tangDriverGetVersion(int* driverVersion); + +// end doxygen Error +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Error Handling + * @{ + */ + +/** + * @brief Return last error returned by any TANG runtime API call and resets the + * stored error code to #tangSuccess + * + * @returns return code from last TANG called from the active host thread + * + * Returns the last error that has been returned by any of the runtime calls in + * the same host thread, and then resets the saved error to #tangSuccess. + * + * @see tangGetErrorString, tangGetLastError, tangPeakAtLastError, tangError_t + */ +tangError_t TANGRT_API_PUBLIC tangGetLastError(void); + +/** + * @brief Return last error returned by any TANG runtime API call. + * + * @return #tangSuccess + * + * Returns the last error that has been returned by any of the runtime calls in + * the same host thread. Unlike tangGetLastError, this function does not reset + * the saved error code. + * + * @see tangGetErrorString, tangGetLastError, tangPeakAtLastError, tangError_t + */ +tangError_t TANGRT_API_PUBLIC tangPeekAtLastError(void); + +/** + * @brief Return name of the specified error code in text form. + * + * @param tang_error Error code to convert to name. + * @return const char pointer to the NULL-terminated error name + * + * @see tangGetErrorString, tangGetLastError, tangPeakAtLastError, tangError_t + */ +TANGRT_API_PUBLIC const char* tangGetErrorName(tangError_t tang_error); + +/** + * @brief Return handy text string message to explain the error which occurred + * + * @param tangError Error code to convert to string. + * @return const char pointer to the NULL-terminated error string + * + * @warning : on HCC, this function returns the name of the error (same as + * tangGetErrorName) + * + * @see tangGetErrorName, tangGetLastError, tangPeakAtLastError, tangError_t + */ +TANGRT_API_PUBLIC const char* tangGetErrorString(tangError_t tangError); + +// end doxygen Error +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Device Management + * @{ + */ + +/** + * @brief Return number of compute-capable devices. + * + * @param [output] count Returns number of compute-capable devices. + * + * @returns #tangSuccess, #tangErrorNoDevice + * + * + * Returns in @p *count the number of devices that have ability to run compute + * commands. If there are no such devices, then @ref tangGetDeviceCount will + * return #tangErrorNoDevice. If 1 or more devices can be found, then + * tangGetDeviceCount returns #tangSuccess. + */ +tangError_t TANGRT_API_PUBLIC tangGetDeviceCount(int* count); + +/** + * @brief Set default device to be used for subsequent tang API calls from this + * thread. + * + * @param[in] deviceId Valid device in range 0...tangGetDeviceCount(). + * + * Sets @p device as the default device for the calling host thread. Valid + * device id's are 0... (tangGetDeviceCount()-1). + * + * Many TANG APIs implicitly use the "default device" : + * + * - Any device memory subsequently allocated from this host thread (using + * tangMalloc) will be allocated on device. + * - Any streams or events created from this host thread will be associated with + * device. + * - Any kernels launched from this host thread (using tangLaunchKernel) will be + * executed on device (unless a specific stream is specified, in which case the + * device associated with that stream will be used). + * + * This function may be called from any host thread. Multiple host threads may + * use the same device. This function does no synchronization with the previous + * or new device, and has very little runtime overhead. Applications can use + * tangSetDevice to quickly switch the default device before making a TANG + * runtime call which uses the default device. + * + * The default device is stored in thread-local-storage for each thread. + * Thread-pool implementations may inherit the default device of the previous + * thread. A good practice is to always call tangSetDevice at the start of TANG + * coding sequency to establish a known standard device. + * + * @returns #tangSuccess, #tangErrorInvalidDevice, #tangErrorDeviceAlreadyInUse + * + * @see tangGetDevice, tangGetDeviceCount + */ +tangError_t TANGRT_API_PUBLIC tangSetDevice(int deviceId); + +/** + * @brief Return the default device id for the calling host thread. + * + * @param [out] device *device is written with the default device + * + * TANG maintains an default device for each thread using thread-local-storage. + * This device is used implicitly for TANG runtime APIs called by this thread. + * tangGetDevice returns in * @p device the default device for the calling host + * thread. + * + * @returns #tangSuccess, #tangErrorInvalidDevice, #tangErrorInvalidValue + * + * @see tangSetDevice, tangGetDevicesizeBytes + */ +tangError_t TANGRT_API_PUBLIC tangGetDevice(int* deviceId); + +/** + * @brief Waits on all active streams on current device + * + * When this command is invoked, the host thread gets blocked until all the + * commands associated with streams associated with the device. TANG does not + * support multiple blocking modes (yet!). + * + * @returns #tangSuccess + * + * @see tangSetDevice, tangDeviceReset + */ +tangError_t TANGRT_API_PUBLIC tangDeviceSynchronize(void); + +/** + * @brief Returns device properties. + * + * @param [out] props written with device properties + * @param [in] deviceId which device to query for information + * + * @return #tangSuccess, #tangErrorInvalidDevice + * @bug HCC always returns 0 for maxThreadsPerMultiProcessor + * @bug HCC always returns 0 for regsPerBlock + * @bug HCC always returns 0 for l2CacheSize + * + * Populates tangGetDeviceProperties with information for the specified device. + */ +TANGRT_API_PUBLIC tangError_t tangGetDeviceProperties(tangDeviceProp* props, + int deviceId); + +/** + * @brief Query for a specific device attribute. + * + * @param [out] value pointer to value to return + * @param [in] attr attribute to query + * @param [in] deviceId which device to query for information + * + * @returns #tangSuccess, #tangErrorInvalidDevice, #tangErrorInvalidValue + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetAttribute(int* value, + tangDeviceAttr attr, + int deviceId); + +/** + * @brief tangDeviceGetPeerPointer. + * @via port to convert addr access a peer device's memory. + * + * @param[in ] s2 device index + * @param[in ] ptlink used port index. + * @param[in ] memory address alloc in peer device; + * @param[out ] the pointer conver peerAddr to accessAddr; + * @return #tangSuccess, #tangErrorInvalidValue + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetPeerPointer(int srcDevice, + int port, + void* peerAddr, + void** accessAddr); + +/** + * @brief tangDeviceGetP2PAttribute. + * @Queries attributes of the link between two devices. + * + * @param[out ] returned value of the requested attribute + * @param[in ] the supported attributes. + * @param[in ] source device of the target link. + * @param[in ] destination device of the target link. + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorInvalidDevice + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetP2PAttribute(int* value, + tangDeviceP2PAttr attr, + int srcDevice, + int dstDevice); + +/** + * @brief tangDeviceCanAccessPeer. + * @Queries if a device may directly access a peer device's memory. + * + * @param[out ] canAccessPeer return value, 1: success, 0: false. + * @param[in ] device local device id. + * @param[in ] peerDevice remote device id; + * @return #tangSuccess, #tangErrorInvalidDevice + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceCanAccessPeer(int* canAccessPeer, + int device, + int peerDevice); + +/** + * @brief tangDeviceEnablePeerAccess. + * @Enables direct access to memory allocations on a peer device. + * + * @param[in ] peerDevice remote device id. + * @param[in ] flags set 0; + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorInvalidDevice, + * #tangErrorPeerAccessAlreadyEnabled + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceEnablePeerAccess(int peerDevice, + unsigned int flags); + +/** + * @brief tangMemcpyPeer. + * @memory copy from a device to a peer device. + * + * @param[in ] dst, dst device memory point; + * @param[in ] dstDevice, dst device id; + * @param[in ] src, src device memory point; + * @param[in ] srcDevice, dst device id; + * @param[in ] size of memory copy in bytes; + * @return #tangSuccess + * + */ +TANGRT_API_PUBLIC tangError_t tangMemcpyPeer(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count); + +/** + * @brief tangMemcpyPeerAsync. + * @memory copy from a device to a peer device. + * + * @param[in ] dst, dst device memory point; + * @param[in ] dstDevice, dst device id; + * @param[in ] src, src device memory point; + * @param[in ] srcDevice, dst device id; + * @param[in ] size of memory copy in bytes; + * @param[in ] stream, used stream; + * @return #tangSuccess + * + */ +TANGRT_API_PUBLIC tangError_t tangMemcpyPeerAsync(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + tangStream_t stream); + +TANGRT_API_PUBLIC tangError_t tangMemcpyPeer_v2(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count); + +TANGRT_API_PUBLIC tangError_t tangMemcpyPeer_v2_ptds(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count); + +TANGRT_API_PUBLIC tangError_t tangMemcpyPeerAsync_v2(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + tangStream_t stream); + +TANGRT_API_PUBLIC tangError_t tangMemcpyPeerAsync_v2_ptsz(void* dst, + int dstDevice, + const void* src, + int srcDevice, + size_t count, + tangStream_t stream); + +/** + * @brief tangDeviceDisablePeerAccess. + * @Disables direct access to memory allocations on a peer device. + * + * @param[in ] peerDevice remote device id; + * @return #tangSuccess, #tangErrorInvalidDevice, #tangErrorPeerAccessNotEnabled + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceDisablePeerAccess(int peerDevice); + +/** + * @brief Get Resource limits of current device + * + * @param [out] pValue + * @param [in] limit + * + * @returns #tangSuccess, #tangErrorUnsupportedLimit, #tangErrorInvalidValue + * Note: Currently, only tangLimitMallocHeapSize is available + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetLimit(size_t* pValue, + enum tangLimit limit); + +TANGRT_API_PUBLIC tangError_t tangDeviceSetLimit(enum tangLimit limit, + size_t value); + +TANGRT_API_PUBLIC tangError_t tangDeviceReset(void); + +/** + * @brief Returns a handle to a compute device. + * @param [out] device handle + * @param [in] PCI Bus ID + * + * @returns #tangSuccess, #tangErrorInavlidDevice, #tangErrorInvalidValue + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetByPCIBusId(int* device, + const char* pciBusId); + +/** + * @brief Returns a PCI Bus Id string for the device. + * + * @param [out] pciBusId - PCI Bus ID + * @param [in] len - Maximum length of pciBusId name string + * @param [in] deviceId - device handle + * @returns #tangSuccess, #tangErrorInavlidDevice, #tangErrorInvalidValue + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetPCIBusId(char* pciBusId, + int len, + int deviceId); + +/** + * @brief Set L1/Shared cache partition. + * + * @param [in] config + * + * @returns #tangSuccess, #tangErrorNotInitialized + * Note: On PT2 devices, L1 cache and shared memory are separated. + * Thus these hints and controls are ignored on those architectures. + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceSetCacheConfig(tangFuncCache config); + +/** + * @brief Set Cache configuration for a specific function + * + * @param [in] config + * + * @returns #tangSuccess, #tangErrorNotInitialized + * Note: On PT2 devices, L1 cache and shared memory are separated. + * Thus these hints and controls are ignored on those architectures. + * + */ +TANGRT_API_PUBLIC tangError_t tangDeviceGetCacheConfig(tangFuncCache* config); + +/** + * @brief The bank width of shared memory on current device is set + * + * @param [in] config + * + * @returns #tangSuccess, #tangErrorInvalidValue, #tangErrorNotInitialized + * + * Note: On PT2 devices, shard memory bank size is fix to 4-bytes. + * Thus these hints and controls are ignored on those architectures. + * + */ +TANGRT_API_PUBLIC tangError_t +tangDeviceSetSharedMemConfig(tangSharedMemConfig config); + +/** + * @brief Returns bank width of shared memory for current device + * + * @param [out] config + * + * @returns #tangSuccess, #tangErrorInvalidValue, #tangErrorNotInitialized + * + * Note: On PT2 devices, shard memory bank size is fix to 4-bytes. + * Thus these hints and controls are ignored on those architectures. + * + */ +TANGRT_API_PUBLIC tangError_t +tangDeviceGetSharedMemConfig(tangSharedMemConfig* config); + +/** + * @brief Returns numerical values that correspond to the least and greatest + * stream priority. + * + * @param[out] leastPriority pointer in which value corresponding to least + * priority is returned. + * @param[out] greatestPriority pointer in which value corresponding to greatest + * priority is returned. + * + * Returns in *leastPriority and *greatestPriority the numerical values that + * correspond to the least and greatest stream priority respectively. Stream + * priorities follow a convention where lower numbers imply greater priorities. + * The range of meaningful stream priorities is given by * [*greatestPriority, + * *leastPriority]. If the user attempts to create a stream with a priority + * value that is outside the the meaningful range as specified by this API, the + * priority is automatically clamped to within the valid range. + */ +TANGRT_API_PUBLIC tangError_t +tangDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority); + +/** + * @brief Set a list of devices that can be used for TANG. + * + * @param[in] List of devices to try. + * @param[in] Number of devices in specified list. + * + * Sets a list of devices for TANG execution in priority order using device_arr. + * The parameter len specifies the number of elements in the list. TANG will try + * devices from the list sequentially until it finds one that works. If this + * function is not called, or if it is called with a len of 0, then TANG will go + * back to its default behavior of trying devices sequentially from a default + * list containing all of the available TANG devices in the system. If a + * specified device ID in the list does not exist, this function will return + * tangErrorInvalidDevice. If len is not 0 and device_arr is NULL or if len + * exceeds the number of devices in the system, then tangErrorInvalidValue is + * returned. + * + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorInvalidDevice + * + */ +TANGRT_API_PUBLIC tangError_t tangSetValidDevices(int* device_arr, int len); + +/** + * @brief Select compute-device which best matches criteria. + * + * @param[out] device Device with best match + * @param[in] properties Desired device properties + * + * @return #tangSuccess, #tangErrorInvalidValue + * + */ +TANGRT_API_PUBLIC tangError_t +tangChooseDevice(int* device, const tangDeviceProp* properties); + +// end doxygen Device +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Memory Management + * @{ + * + */ + +/** + * @brief Allocate memory on the default accelerator + * + * @param[out] pptr Pointer to the allocated memory + * @param[in] size Requested memory size + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and tangSuccess + * is returned. + * + * @return #tangSuccess, #tangErrorOutOfMemory, #tangErrorInvalidValue (bad + * context, null *ptr) + * + * @see tangMallocPitch, tangFree, tangMallocArray, tangFreeArray, + * tangMalloc3D, tangMalloc3DArray, tangHostFree, tangHostMalloc + */ +tangError_t TANGRT_API_PUBLIC tangMalloc(void** pptr, size_t sizeBytes); + +/** + * @brief Allocate memory. + * + * @param pptr + * @param sizeBytes + * @param hStream + * @return tangError_t + */ +tangError_t TANGRT_API_PUBLIC tangMallocAsync(void** pptr, + size_t sizeBytes, + tangStream_t hStream); + +tangError_t TANGRT_API_PUBLIC tangMallocAsync_ptsz(void** pptr, + size_t sizeBytes, + tangStream_t hStream); + +/** + * @brief Free memory allocated by the hcc tang memory allocation API. + * This API performs an implicit tangDeviceSynchronize() call. + * If pointer is NULL, the tang runtime is initialized and tangSuccess is + * returned. + * + * @param[in] ptr Pointer to memory to be freed + * @return #tangSuccess + * @return #tangErrorInvalidDevicePointer (if pointer is invalid, including + * host pointers allocated with tangHostMalloc) + * + * @see tangMalloc, tangMallocPitch, tangMallocArray, tangFreeArray, + * tangHostFree, tangMalloc3D, tangMalloc3DArray, tangHostMalloc + */ +tangError_t TANGRT_API_PUBLIC tangFree(void* ptr); + +/** + * @brief Free memory block async. + * + * @param ptr + * @param hStream + * @return tangError_t + */ +tangError_t TANGRT_API_PUBLIC tangFreeAsync(void* ptr, tangStream_t hStream); +tangError_t TANGRT_API_PUBLIC tangFreeAsync_ptsz(void* ptr, + tangStream_t hStream); + +/** + * @brief Allocate page locked host memory + * + * @param[out] pptr Pointer to the allocated page locked host memory + * @param[in] sizeBytes Requested memory size + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and tangSuccess + * is returned. + * + * @return #tangSuccess, #tangErrorOutOfMemory + * + */ +tangError_t TANGRT_API_PUBLIC tangMallocHost(void** pptr, size_t sizeBytes); + +/** + * @brief Allocate page locked host memory + * + * @param[out] pptr Pointer to the allocated page locked host memory + * @param[in] sizeBytes Requested memory size + * @param[in] flags See below. + * + * flags: + * - #tangHostAllocDefault Memory is page locked. + * - #tangHostAllocPortable Memory is considered registered by all + * contexts. + * - #tangHostAllocMapped Map the allocation into the address space for + * the current device. + * - #tangHostAllocWriteCombined Allocates the memory as write-combined (WC). + * TANG does not support IOMMU on device side, so flags of tangHostAllocMapped + * and tangHostAllocWriteCombined will always return false. + * + * If size is 0, no memory is allocated, *ptr returns nullptr, and tangSuccess + * is returned. + * + * @return #tangSuccess, #tangErrorOutOfMemory + */ +tangError_t TANGRT_API_PUBLIC tangHostAlloc(void** pptr, + size_t sizeBytes, + unsigned int flags); + +/** + * @brief Passes back the device pointer corresponding to the mapped, pinned + * host buffer allocated by tangHostAlloc(). Note: on PT2 devices, device + * pointer of mapped host memory requires 4-byte aligned access (because of the + * PCIE access mode). The start address assigned by tangHostAlloc is 4-byte + * aligned by default, but further use of the "offset" over the device pointer + * of mapped host memory should be careful. Access that is not 4-byte aligned + * may result in incorrect calculations. + * + * @param[out] pDevice Returned device pointer for mapped memory + * @param[in] pHost Requested host pointer mapping + * @param[in] flags Flags for extensions (must be 0 for now) + * + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorMemoryAllocation + */ +tangError_t TANGRT_API_PUBLIC tangHostGetDevicePointer(void** pDevice, + void* pHost, + unsigned int flags); + +/** + * @brief Passes back flags used to allocate pinned host memory allocated by + * tangHostAlloc. + * + * @param[out] pFlags Returned flags word + * @param[in] pHost Host pointer + * + * pFlags: + * - #tangHostAllocDefault Memory is page locked. + * - #tangHostAllocPortable Memory is considered registered by all + * contexts. + * - #tangHostAllocMapped Map the allocation into the address space for + * the current device. + * - #tangHostAllocWriteCombined Allocates the memory as write-combined (WC). + * TANG does not support IOMMU on device side, so flags of tangHostAllocMapped + * and tangHostAllocWriteCombined are not supported. + * + * @return #tangSuccess, #tangErrorOutOfMemory + */ +TANGRT_API_PUBLIC tangError_t tangHostGetFlags(unsigned int* pFlags, + void* pHost); + +/** + * @brief Free the page locked host memory allocated by the tang host memory + allocation API. + * + * @param[in] ptr Pointer to memory to be freed + * + * @return #tangSuccess, + * #tangErrorInvalidValue (if pointer is invalid, including device + pointers allocated with tangMalloc) + */ +tangError_t TANGRT_API_PUBLIC tangFreeHost(void* ptr); + +/** + * @brief Register host memory as page locked memory. + * + * @param[out] ptr Pointer to host memory to be registered. + * @param[in] sizeBytes Size of the host memory + * @param[in] flags See below. + * + * flags: + * - #tangHostRegisterDefault Memory is page locked. + * - #tangHostRegisterPortable Memory is considered registered by all contexts. + * - #tangHostRegisterMapped Map the allocation into the address space for + * the current device. + * - #tangHostRegisterIoMemory The passed memory pointer is treated as pointing + * to some memory-mapped I/O space. + * - #tangHostRegisterReadOnly The passed memory pointer is treated as pointing + * to memory that is considered read-only by the device. TANG does not support + * IOMMU on device side, so flags of tangHostRegisterMapped and + * tangHostRegisterIoMemory and tangHostRegisterReadOnly will always return + * false. + * + * @return #tangSuccess, #tangErrorOutOfMemory + * + * @see tangHostUnregister, tangHostGetFlags, tangHostGetDevicePointer + */ +tangError_t TANGRT_API_PUBLIC tangHostRegister(void* ptr, + size_t sizeBytes, + unsigned int flags); + +/** + * @brief Un-register host pointer + * + * @param[in] ptr Host pointer previously registered + * @return Error code + * + * @see tangHostRegister + */ +tangError_t TANGRT_API_PUBLIC tangHostUnregister(void* ptr); + +/** + * @brief Copy data from src to dst. + * + * It supports memory from host to device, + * device to host, device to device and host to host + * The src and dst must not overlap. + * + * For tangMemcpy, the copy is always performed by the current device (set by +tangSetDevice). + * For multi-gpu or peer-to-peer configurations, it is recommended to set the +current device to the + * device where the src data is physically located. For optimal peer-to-peer +copies, the copy + * device must be able to access the src and dst pointers (by calling +tangDeviceEnablePeerAccess + * with copy agent as the current device and src/dest as the peerDevice +argument. if this is not + * done, the tangMemcpy will still work, but will perform the copy using a +staging buffer on the + * host. Calling tangMemcpy with dst and src pointers that do not match the +tangMemcpyKind results + * in undefined behavior. + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] kind Memory copy type + * @return #tangSuccess, #tangErrorInvalidValue, +#tangErrorInvalidMemcpyDirection , #tangErrorDriverIoctlFailed + * + * @see tangArrayCreate, tangArrayDestroy, tangArrayGetDescriptor, +tangMemAlloc, tangMemAllocHost, + * tangMemAllocPitch, tangMemcpy2D, tangMemcpy2DAsync, tangMemcpy2DUnaligned, +tangMemcpyAtoA, + * tangMemcpyAtoD, tangMemcpyAtoH, tangMemcpyAtoHAsync, tangMemcpyDtoA, +tangMemcpyDtoD, + * tangMemcpyDtoDAsync, tangMemcpyDtoH, tangMemcpyDtoHAsync, tangMemcpyHtoA, +tangMemcpyHtoAAsync, + * tangMemcpyHtoDAsync, tangMemFree, tangMemFreeHost, tangMemGetAddressRange, +tangMemGetInfo, + * tangMemHostAlloc, tangMemHostGetDevicePointer + */ +tangError_t TANGRT_API_PUBLIC tangMemcpy(void* dst, + const void* src, + size_t sizeBytes, + tangMemcpyKind kind); + +tangError_t TANGRT_API_PUBLIC tangMemcpy_ptds(void* dst, + const void* src, + size_t sizeBytes, + tangMemcpyKind kind); + +/** + * @brief Copy data from src to dst asynchronously. + * + * It supports memory from host to device, + * device to host, device to device and host to host + * The src and dst must not overlap. + * + * For tangMemcpyAsync, the copy is always performed by the current device (set + * by tangSetDevice). For multi-gpu or peer-to-peer configurations, it is + * recommended to set the current device to the device where the src data is + * physically located. For optimal peer-to-peer copies, the copy device must be + * able to access the src and dst pointers (by calling + * tangDeviceEnablePeerAccess with copy agent as the current device and src/dest + * as the peerDevice argument. if this is not done, the tangMemcpyAsync will + * still work, but will perform the copy using a staging buffer on the host. + * Calling tangMemcpy with dst and src pointers that do not match the + * tangMemcpyKind results in undefined behavior. + * + * @param[out] dst Data being copy to + * @param[in] src Data being copy from + * @param[in] sizeBytes Data size in bytes + * @param[in] kind Memory copy type + * @param[in] stream Stream to execute copy on + * @return #tangSuccess, #tangErrorInvalidValue, + * #tangErrorInvalidMemcpyDirection, #tangErrorDriverIoctlFailed + * + * @see tangArrayCreate, tangArrayDestroy, tangArrayGetDescriptor, + * tangMemAlloc, tangMemAllocHost, tangMemAllocPitch, tangMemcpy2D, + * tangMemcpy2DAsync, tangMemcpy2DUnaligned, tangMemcpyAtoA, tangMemcpyAtoD, + * tangMemcpyAtoH, tangMemcpyAtoHAsync, tangMemcpyDtoA, tangMemcpyDtoD, + * tangMemcpyDtoDAsync, tangMemcpyDtoH, tangMemcpyDtoHAsync, tangMemcpyHtoA, + * tangMemcpyHtoAAsync, tangMemcpyHtoDAsync, tangMemFree, tangMemFreeHost, + * tangMemGetAddressRange, tangMemGetInfo, tangMemHostAlloc, + * tangMemHostGetDevicePointer + */ +tangError_t TANGRT_API_PUBLIC +tangMemcpyAsync(void* dst, + const void* src, + size_t sizeBytes, + tangMemcpyKind kind, + tangStream_t strem __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyAsync_ptsz(void* dst, + const void* src, + size_t sizeBytes, + tangMemcpyKind kind, + tangStream_t strem __dparm(nullptr)); + +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dest + * with the constant byte value value. + * + * @param[out] dst Data being filled + * @param[in] constant value to be set + * @param[in] sizeBytes Data size in bytes + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorNotInitialized + */ +tangError_t TANGRT_API_PUBLIC tangMemset(void* dst, + int value, + size_t sizeBytes); + +tangError_t TANGRT_API_PUBLIC tangMemset_ptds(void* dst, + int value, + size_t sizeBytes); + +/** + * @brief Fills the first sizeBytes bytes of the memory area pointed to by dev + * with the constant byte value value. + * + * tangMemsetAsync() is asynchronous with respect to the host, so the call may + * return before the memset is complete. The operation can optionally be + * associated to a stream by passing a non-zero stream argument. If stream is + * non-zero, the operation may overlap with operations in other streams. + * + * @param[out] dst Pointer to device memory + * @param[in] value - Value to set for each byte of specified memory + * @param[in] sizeBytes - Size in bytes to set + * @param[in] stream - Stream identifier + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorMemoryFree + */ +tangError_t TANGRT_API_PUBLIC +tangMemsetAsync(void* dst, + int value, + size_t sizeBytes, + tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemsetAsync_ptsz(void* dst, + int value, + size_t sizeBytes, + tangStream_t strem __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyFromSymbol(void* dst, + const void* symbol, + size_t count, + size_t offset __dparm(0), + tangMemcpyKind kind __dparm(tangMemcpyDeviceToHost)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyFromSymbol_ptds(void* dst, + const void* symbol, + size_t count, + size_t offset __dparm(0), + tangMemcpyKind kind __dparm(tangMemcpyDeviceToHost)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyToSymbol(const void* symbol, + const void* src, + size_t count, + size_t offset __dparm(0), + tangMemcpyKind kind __dparm(tangMemcpyHostToDevice)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyToSymbol_ptds(const void* symbol, + const void* src, + size_t count, + size_t offset __dparm(0), + tangMemcpyKind kind __dparm(tangMemcpyHostToDevice)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyToSymbolAsync(const void* symbol, + const void* src, + size_t count, + size_t offset, + tangMemcpyKind kind, + tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyToSymbolAsync_ptsz(const void* symbol, + const void* src, + size_t count, + size_t offset, + tangMemcpyKind kind, + tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyFromSymbolAsync(void* dst, + const void* symbol, + size_t count, + size_t offset, + tangMemcpyKind kind, + tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangMemcpyFromSymbolAsync_ptsz(void* dst, + const void* symbol, + size_t count, + size_t offset, + tangMemcpyKind kind, + tangStream_t stream __dparm(nullptr)); + +/** + * @brief Query memory info. + * Return snapshot of free memory, and total allocatable memory on the device. + * + * Returns in *free a snapshot of the current free memory. + * @returns #tangSuccess, #tangErrorInvalidDevice, #tangErrorInvalidValue + * @warning On HCC, the free memory only accounts for memory allocated by this + *process and may be optimistic. + **/ +tangError_t TANGRT_API_PUBLIC tangMemGetInfo(size_t* free, size_t* total); + +/** + * @brief Finds the address associated with a TANG symbol. + * + * @param[out] devPtr Device pointer associated with symbol + * @param[in] symbol Device symbol address + * @return #tangSuccess, #tangErrorInvalidValue + **/ +tangError_t TANGRT_API_PUBLIC tangGetSymbolAddress(void** devPtr, + const void* symbol); + +/** + * @brief Finds the size of the object associated with a TANG symbol. + * + * @param[out] size Size of object associated with symbol + * @param[in] symbol Device symbol address + * @return #tangSuccess, #tangErrorInvalidValue + **/ +tangError_t TANGRT_API_PUBLIC tangGetSymbolSize(size_t* size, + const void* symbol); + +// doxygen end Memory +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Stream Management + * @{ + */ + +/** + * @brief Create an asynchronous stream. + * + * @param[in, out] stream Valid pointer to tangStream_t. This function writes + * the memory with the newly created stream. + * @return #tangSuccess, #tangErrorInvalidValue + * + * Create a new asynchronous stream. @p stream returns an opaque handle that + * can be used to reference the newly created stream in subsequent tangStream* + * commands. The stream is allocated on the heap and will remain allocated even + * if the handle goes out-of-scope. To release the memory used by the stream, + * applicaiton must call tangStreamDestroy. + * + * @return #tangSuccess, #tangErrorInvalidValue + * + * @see tangStreamCreateWithFlags, tangStreamCreateWithPriority, + * tangStreamSynchronize, tangStreamWaitEvent, tangStreamDestroy + */ +tangError_t TANGRT_API_PUBLIC tangStreamCreate(tangStream_t* stream); + +/** + * @brief communicate to c2c. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] cmd is that command packets + * @param[in ] cmd_count is that command count + * @param[in ] device_addr is that hbm addr + * @return #tangSuccess, #tangErrorInvalidValue + * + * send command to c2c module, it can be used by ptlink. + * + */ +tangError_t TANGRT_API_PUBLIC tangStreamC2Ctransfers(tangStream_t stream, + uint32_t* cmd, + uint32_t cmd_count, + uint64_t device_addr, + uint32_t mem_size); + +/** + * @brief Create an asynchronous stream. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] flags to control stream creation. + * @return #tangSuccess, #tangErrorInvalidValue + * + * Create a new asynchronous stream. @p stream returns an opaque handle that + * can be used to reference the newly created stream in subsequent tangStream* + * commands. The stream is allocated on the heap and will remain allocated even + * if the handle goes out-of-scope. To release the memory used by the stream, + * applicaiton must call tangStreamDestroy. Flags controls behavior of the + * stream. See #tangStreamDefault, #tangStreamNonBlocking. + * + * @see tangStreamCreate, tangStreamCreateWithPriority, tangStreamSynchronize, + * tangStreamWaitEvent, tangStreamDestroy + */ +tangError_t TANGRT_API_PUBLIC tangStreamCreateWithFlags(tangStream_t* stream, + unsigned int flags __dparm(tangStreamDefault)); + +/** + * @brief Create an asynchronous stream with the specified priority. + * + * @param[in, out] stream Pointer to new stream + * @param[in ] flags to control stream creation. + * @param[in ] priority of the stream. Lower numbers represent higher + * priorities. + * @return #tangSuccess, #tangErrorInvalidValue + * + * Create a new asynchronous stream with the specified priority. @p stream + * returns an opaque handle that can be used to reference the newly created + * stream in subsequent tangStream* commands. The stream is allocated on the + * heap and will remain allocated even if the handle goes out-of-scope. To + * release the memory used by the stream, applicaiton must call + * tangStreamDestroy. Flags controls behavior of the stream. See + * #tangStreamDefault, #tangStreamNonBlocking. + * + * @see tangStreamCreate, tangStreamSynchronize, tangStreamWaitEvent, + * tangStreamDestroy + */ +tangError_t TANGRT_API_PUBLIC tangStreamCreateWithPriority(tangStream_t* stream, + unsigned int flags __dparm(tangStreamDefault), + int priority __dparm(priority_normal)); + +/** + * @brief Returns numerical values that correspond to the least and greatest + * stream priority. + * + * @param[in, out] leastPriority pointer in which value corresponding to least + * priority is returned. + * @param[in, out] greatestPriority pointer in which value corresponding to + * greatest priority is returned. + * + * Returns in *leastPriority and *greatestPriority the numerical values that + * correspond to the least and greatest stream priority respectively. Stream + * priorities follow a convention where lower numbers imply greater priorities. + * The range of meaningful stream priorities is given by + * [*greatestPriority, *leastPriority]. If the user attempts to create a stream + * with a priority value that is outside the the meaningful range as specified + * by this API, the priority is automatically clamped to within the valid range. + */ +tangError_t tangDeviceGetStreamPriorityRange(int* leastPriority, + int* greatestPriority); + +/** + * @brief Query the priority of a stream. + * + * @param[in] hStream stream to be queried + * @param[in,out] priority Pointer to an unsigned integer in which the stream's + * priority is returned + * @return #tangSuccess, #tangErrorInvalidValue + * + * Query the priority of a stream. The priority is returned in in priority. + * + * @see tangStreamCreateWithFlags + */ +tangError_t TANGRT_API_PUBLIC tangStreamGetPriority(tangStream_t hStream, + int* priority); + +tangError_t TANGRT_API_PUBLIC tangStreamGetPriority_ptsz(tangStream_t stream, + int* priority); + +/** + * @brief Destroys the specified stream. + * + * @param[in, out] stream Valid pointer to tangStream_t. This function writes + * the memory with the newly created stream. + * @return #tangSuccess #tangErrorInvalidHandle + * + * Destroys the specified stream. + * + * If commands are still executing on the specified stream, some may complete + * execution before the queue is deleted. + * + * The queue may be destroyed while some commands are still inflight, or may + * wait for all commands queued to the stream before destroying it. + * + * @see tangStreamCreate, tangStreamCreateWithFlags, + * tangStreamCreateWithPriority, tangStreamQuery, tangStreamWaitEvent, + * tangStreamSynchronize + */ +tangError_t TANGRT_API_PUBLIC tangStreamDestroy(tangStream_t stream); + +/** + * @brief Wait for all commands in stream to complete. + * + * @param[in] stream stream identifier. + * + * @return #tangSuccess, #tangErrorInvalidHandle + * + * This command is host-synchronous : the host will block until the specified + * stream is empty. + * + * This command follows standard null-stream semantics. Specifically, + * specifying the null stream will cause the command to wait for other streams + * on the same device to complete all pending operations. + * + * This command honors the tangDeviceLaunchBlocking flag, which controls whether + * the wait is active or blocking. + * + * @see tangStreamCreate, tangStreamCreateWithFlags, + * tangStreamCreateWithPriority, tangStreamWaitEvent, tangStreamDestroy + * + */ +tangError_t TANGRT_API_PUBLIC tangStreamSynchronize(tangStream_t stream); + +tangError_t TANGRT_API_PUBLIC tangStreamSynchronize_ptsz(tangStream_t stream); + +/** + * @brief Check if a stream has completed all its commands + * + * @param stream + * @return tangError_t + * tangSuccess + * tangErrorNotReady + */ +tangError_t TANGRT_API_PUBLIC tangStreamQuery(tangStream_t stream); +tangError_t TANGRT_API_PUBLIC tangStreamQuery_ptsz(tangStream_t stream); + +tangError_t TANGRT_API_PUBLIC +tangStreamBeginCapture(tangStream_t stream, tangStreamCaptureMode mode); + +tangError_t TANGRT_API_PUBLIC +tangStreamBeginCapture_ptsz(tangStream_t stream, + tangStreamCaptureMode mode); + +tangError_t TANGRT_API_PUBLIC tangStreamEndCapture(tangStream_t stream, + tangGraph_t* pGraph); + +tangError_t TANGRT_API_PUBLIC tangStreamEndCapture_ptsz(tangStream_t stream, + tangGraph_t* pGraph); + +tangError_t TANGRT_API_PUBLIC +tangStreamIsCapturing(tangStream_t stream, + tangStreamCaptureStatus* pStatus); + +tangError_t TANGRT_API_PUBLIC +tangStreamIsCapturing_ptsz(tangStream_t stream, + tangStreamCaptureStatus* pStatus); + +tangError_t TANGRT_API_PUBLIC +tangStreamGetCaptureInfo(tangStream_t hStream, + tangStreamCaptureStatus* pStatus, + unsigned long long* pId __dparm(0), + tangGraph_t* pGraph __dparm(0), + const tangGraphNode_t** deps __dparm(0), + size_t* numDeps __dparm(0)); + +tangError_t TANGRT_API_PUBLIC +tangStreamGetCaptureInfo_ptsz(tangStream_t hStream, + tangStreamCaptureStatus* pStatus, + unsigned long long* pId __dparm(0), + tangGraph_t* pGraph __dparm(0), + const tangGraphNode_t** deps __dparm(0), + size_t* numDeps __dparm(0)); + +tangError_t TANGRT_API_PUBLIC +tangThreadExchangeStreamCaptureMode(tangStreamCaptureMode* mode); + +tangError_t TANGRT_API_PUBLIC tangGraphInstantiate(tangGraphExec_t* pGraphExec, + tangGraph_t graph, + void*, + void*, + unsigned long long); + +tangError_t TANGRT_API_PUBLIC tangGraphLaunch(tangGraphExec_t graphExec, + tangStream_t stream); + +tangError_t TANGRT_API_PUBLIC tangGraphLaunch_ptsz(tangGraphExec_t graphExec, + tangStream_t stream); +tangError_t TANGRT_API_PUBLIC +tangGraphInstantiateWithFlags(tangGraphExec_t* pGraphExec, + tangGraph_t graph, + unsigned long long flags); + +tangError_t TANGRT_API_PUBLIC tangGraphDestroy(tangGraph_t graph); +tangError_t TANGRT_API_PUBLIC tangGraphExecDestroy(tangGraphExec_t graphExec); + +tangError_t TANGRT_API_PUBLIC tangGraphGetInfo(tangGraph_t graph, + tangGraphInfo* pInfo); + +tangError_t TANGRT_API_PUBLIC tangGraphCreate(tangGraph_t* pGraph, + unsigned int flags); + +tangError_t TANGRT_API_PUBLIC +tangGraphAddHostNode(tangGraphNode_t* pGraphNode, + tangGraph_t graph, + const tangGraphNode_t* dependencies, + size_t numDependencies, + const tangHostNodeParams* nodeParams); + +tangError_t TANGRT_API_PUBLIC +tangGraphAddKernelNode(tangGraphNode_t* pGraphNode, + tangGraph_t graph, + const tangGraphNode_t* dependencies, + size_t numDependencies, + const tangKernelNodeParams* nodeParams); + +/** + * @brief Make the specified compute stream wait for an event + * + * @param[in] stream stream to make wait. + * @param[in] event event to wait on + * @param[in] flag control operation + * + * @return #tangSuccess, #tangErrorInvalidHandle + * + * This function inserts a wait operation into the specified stream. + * All future work submitted to @p stream will wait until @p event reports + * completion before beginning execution. + * + * This function only waits for commands in the current stream to complete. + * Notably,, this function does not impliciy wait for commands in the default + * stream to complete, even if the specified stream is created with + * tangStreamNonBlocking = 0. + * + * @see tangStreamCreate, tangStreamCreateWithFlags, + * tangStreamCreateWithPriority, tangStreamSynchronize, tangStreamDestroy + */ +tangError_t TANGRT_API_PUBLIC tangStreamWaitEvent(tangStream_t stream, + tangEvent_t event, + unsigned int flag __dparm(0)); + +tangError_t TANGRT_API_PUBLIC +tangStreamWaitEvent_ptsz(tangStream_t stream, + tangEvent_t event, + unsigned int flag __dparm(0)); +/** + * @brief Return flags associated with this stream. + * + * @param[in] stream stream to be queried + * @param[in,out] flag Pointer to an unsigned integer in which the stream's + * flags are returned + * @return #tangSuccess, #tangErrorInvalidValue, #tangErrorInvalidHandle + * + * @returns #tangSuccess #tangErrorInvalidValue #tangErrorInvalidHandle + * + * Return flags associated with this stream in *@p flag. + * + * @see tangStreamCreateWithFlags + */ +tangError_t TANGRT_API_PUBLIC tangStreamGetFlags(tangStream_t stream, + unsigned int* flags); + +tangError_t TANGRT_API_PUBLIC tangStreamGetFlags_ptsz(tangStream_t stream, + unsigned int* flags); + +tangError_t TANGRT_API_PUBLIC tangStreamGetId(tangStream_t stream, + int* pId); + +tangError_t TANGRT_API_PUBLIC tangStreamGetId_ptsz(tangStream_t stream, + int* pId); + +typedef void (*tangStreamCallback_t)(tangStream_t stream, + tangError_t status, + void* userData); + +tangError_t TANGRT_API_PUBLIC +tangStreamAddCallback(tangStream_t stream, + tangStreamCallback_t callback, + void* userData, + unsigned int flags); + +tangError_t TANGRT_API_PUBLIC +tangStreamAddCallback_ptsz(tangStream_t stream, + tangStreamCallback_t callback, + void* userData, + unsigned int flags); + +tangError_t TANGRT_API_PUBLIC tangLaunchHostFunc(tangStream_t stream, + tangHostFn_t fn, + void* userData); + +tangError_t TANGRT_API_PUBLIC tangLaunchHostFunc_ptsz(tangStream_t stream, + tangHostFn_t fn, + void* userData); + +tangError_t TANGRT_API_PUBLIC tangProfilerStart(); +tangError_t TANGRT_API_PUBLIC tangProfilerStop(); +// end doxygen Stream +/** + * @} + */ + +tangError_t TANGRT_API_PUBLIC tangEngineCollAssign(int devId, + int collType, + uint64_t devAddr, + int length, + tangStream_t stream); + + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Event Management + * @{ + */ + +/** + * @brief Create an event with the specified flags + * + * @param[in,out] event Returns the newly created event. + * @param[in] flag Flag to control event behavior. Valid values are + #tangEventDefault, #tangEventBlockingSync, #tangEventDisableTiming, + #tangEventInterprocess + + * #tangEventDefault : Default flag. The event will use active synchronization + and will support timing. Blocking synchronization provides lowest possible + latency at the expense of dedicating a CPU to poll on the event. + * #tangEventBlockingSync : The event will use blocking synchronization : if + tangEventSynchronize is called on this event, the thread will block until the + event completes. This can increase latency for the synchroniation but can + result in lower power and more resources for other CPU threads. + * #tangEventDisableTiming : Disable recording of timing information. On ROCM + platform, timing information is always recorded and this flag has no + performance benefit. + + * @warning tangEventInterprocess support is under development. Use of this + flag will return an error. + * + * @returns #tangSuccess, #tangErrorNotInitialized, #tangErrorInvalidValue, + #tangErrorLaunchFailure, #tangErrorOutOfMemory + * + * @see tangEventCreate, tangEventSynchronize, tangEventDestroy, + tangEventElapsedTime + */ +tangError_t TANGRT_API_PUBLIC tangEventCreateWithFlags(tangEvent_t* event, + unsigned flag); + +/** + * Create an event + * + * @param[in,out] event Returns the newly created event. + * + * @returns #tangSuccess, #tangErrorNotInitialized, #tangErrorInvalidValue, + * #tangErrorLaunchFailure, #tangErrorOutOfMemory + * + * @see tangEventCreateWithFlags, tangEventRecord, tangEventQuery, + * tangEventSynchronize, tangEventDestroy, tangEventElapsedTime + */ +tangError_t TANGRT_API_PUBLIC tangEventCreate(tangEvent_t* event); + +/** + * @brief Record an event in the specified stream. + * + * @param[in] event event to record. + * @param[in] stream stream in which to record event. + * @returns #tangSuccess, #tangErrorInvalidValue, #tangErrorNotInitialized, + * #tangErrorInvalidHandle, #tangErrorLaunchFailure + * + * tangEventQuery() or tangEventSynchronize() must be used to determine when the + * event transitions from "recording" (after tangEventRecord() is called) to + * "recorded" (when timestamps are set, if requested). + * + * Events which are recorded in a non-NULL stream will transition to + * from recording to "recorded" state when they reach the head of + * the specified stream, after all previous + * commands in that stream have completed executing. + * + * If tangEventRecord() has been previously called on this event, then this call + * will overwrite any existing state in event. + * + * If this function is called on an event that is currently being recorded, + * results are undefined + * - either outstanding recording may save state into the event, and the order + * is not guaranteed. + * + * @see tangEventCreate, tangEventCreateWithFlags, tangEventQuery, + * tangEventSynchronize, tangEventDestroy, tangEventElapsedTime + * + */ +tangError_t TANGRT_API_PUBLIC +tangEventRecord(tangEvent_t event, tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangEventRecord_ptsz(tangEvent_t event, tangStream_t stream __dparm(nullptr)); + +tangError_t TANGRT_API_PUBLIC +tangEventRecordWithFlags(tangEvent_t event, + tangStream_t stream __dparm(nullptr), + unsigned int flags __dparm(0)); + +tangError_t TANGRT_API_PUBLIC +tangEventRecordWithFlags_ptsz(tangEvent_t event, + tangStream_t stream __dparm(nullptr), + unsigned int flags __dparm(0)); +/** + * @brief Destroy the specified event. + * + * @param[in] event Event to destroy. + * @returns #tangSuccess, #tangErrorNotInitialized, #tangErrorInvalidValue, + * #tangErrorLaunchFailure + * + * Releases memory associated with the event. If the event is recording but + * has not completed recording when tangEventDestroy() is called, the function + * will return immediately and the completion_future resources will be released + * later, when the tangDevice is synchronized. + * + * @see tangEventCreate, tangEventCreateWithFlags, tangEventQuery, + * tangEventSynchronize, tangEventRecord, tangEventElapsedTime + * + * @returns #tangSuccess + */ +tangError_t TANGRT_API_PUBLIC tangEventDestroy(tangEvent_t event); + +/** + * @brief Wait for an event to complete. + * + * This function will block until the event is ready, waiting for all previous + * work in the stream specified when event was recorded with tangEventRecord(). + * + * If tangEventRecord() has not been called on @p event, this function returns + * immediately. + * + * TODO-hcc - This function needs to support tangEventBlockingSync parameter. + * + * @param[in] event Event on which to wait. + * @returns #tangSuccess, #tangErrorInvalidValue, #tangErrorNotInitialized, + * #tangErrorInvalidHandle, #tangErrorLaunchFailure + * + * @see tangEventCreate, tangEventCreateWithFlags, tangEventQuery, + * tangEventDestroy, tangEventRecord, tangEventElapsedTime + */ +tangError_t TANGRT_API_PUBLIC tangEventSynchronize(tangEvent_t event); +tangError_t TANGRT_API_PUBLIC tangEventSynchronizeWithFlags(tangEvent_t event, + unsigned int flags); + +/** + * @brief Return the elapsed time between two events. + * + * @param[out] ms : Return time between start and stop in ms. + * @param[in] start : Start event. + * @param[in] stop : Stop event. + * @returns #tangSuccess, #tangErrorInvalidValue, #tangErrorNotReady, + * #tangErrorInvalidHandle, #tangErrorNotInitialized, #tangErrorLaunchFailure + * + * Computes the elapsed time between two events. Time is computed in ms, with + * a resolution of approximately 1 us. + * + * Events which are recorded in a NULL stream will block until all commands + * on all other streams complete execution, and then record the timestamp. + * + * Events which are recorded in a non-NULL stream will record their timestamp + * when they reach the head of the specified stream, after all previous + * commands in that stream have completed executing. Thus the time that + * the event recorded may be significantly after the host calls + * tangEventRecord(). + * + * If tangEventRecord() has not been called on either event, then + * #tangErrorInvalidHandle is returned. If tangEventRecord() has been called on + * both events, but the timestamp has not yet been recorded on one or both + * events (that is, tangEventQuery() would return #tangErrorNotReady on at least + * one of the events), then #tangErrorNotReady is returned. + * + * @see tangEventCreate, tangEventCreateWithFlags, tangEventQuery, + * tangEventDestroy, tangEventRecord, tangEventSynchronize + */ +tangError_t TANGRT_API_PUBLIC tangEventElapsedTime(float* ms, + tangEvent_t start, + tangEvent_t stop); + +/** + * @brief Query event status + * + * @param[in] event Event to query. + * @returns #tangSuccess, #tangErrorNotReady, #tangErrorInvalidHandle, + * #tangErrorInvalidValue, #tangErrorNotInitialized, #tangErrorLaunchFailure + * + * Query the status of the specified event. This function will return + * #tangErrorNotReady if all commands in the appropriate stream (specified to + * tangEventRecord()) have completed. If that work has not completed, or if + * tangEventRecord() was not called on the event, then #tangSuccess is returned. + * + * @see tangEventCreate, tangEventCreateWithFlags, tangEventRecord, + * tangEventDestroy, tangEventSynchronize, tangEventElapsedTime + */ +tangError_t TANGRT_API_PUBLIC tangEventQuery(tangEvent_t event); + +tangError_t TANGRT_API_PUBLIC tangEventQueryTimestamp(tangEvent_t event, + tangEventTimestamp* ts); +// end doxygen Event +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Excution Control + * @{ + */ + +/** + * @brief Find out attributes for a given function. + * + * @param [out] attr + * @param [in] func + * + * @returns tangSuccess, tangErrorInvalidValue, tangErrorInvalidDeviceFunction + * + * NOTE: runtime only has tangFuncGetAttributes API, has no tangFuncGetAttribute + * API. while user mode driver only has taFuncGetAttribute API, has no + * taFuncGetAttributes API. + */ + +tangError_t TANGRT_API_PUBLIC tangFuncGetAttributes(tangFuncAttributes* attr, + const void* func); + +tangError_t TANGRT_API_PUBLIC tangGetFuncBySymbol(tangFunction_t *hFunc, + const void *symbol); + +tangError_t TANGRT_API_PUBLIC +tangPointerGetAttributes(struct tangPointerAttributes* attributes, + const void* ptr); + +/** + * @brief Set attribute for a specific function + * + * @param [in] func; + * @param [in] attr; + * @param [in] value; + * + * @returns #tangSuccess, #tangErrorInvalidDeviceFunction, + * #tangErrorInvalidValue + * + * Note: PT devices do not support shared cache banking, and the hint is + * ignored. + * + * NOTE: runtime tangFuncSetAttribute API only supports two types of + * tangFuncAttribute. while user mode driver taFuncSetAttribute API supports + * more types of TAfunction_attribute. + * + */ +tangError_t TANGRT_API_PUBLIC tangFuncSetAttribute(const void* func, + tangFuncAttribute attr, + int value); + +/** + * @brief Set Cache configuration for a specific function + * + * @param [in] func; + * @param [in] config; + * + * @returns #tangSuccess, #tangErrorNotInitialized + * Note: PT devices do not support reconfigurable cache. This hint is ignored. + * + */ +tangError_t TANGRT_API_PUBLIC tangFuncSetCacheConfig(const void* func, + tangFuncCache config); + +/** + * @brief Set shared memory configuation for a specific function + * + * @param [in] func + * @param [in] config + * + * @returns #tangSuccess, #tangErrorInvalidValue, + * #tangErrorInvalidDeviceFunction + * + * Note: PT devices do not support shared cache banking, and the hint is + * ignored. + * + */ +tangError_t TANGRT_API_PUBLIC +tangFuncSetSharedMemConfig(const void* func, tangSharedMemConfig config); + +/** + * @brief Converts a double argument to be executed on a device. + * + * @param[in][out] d Double to convert. + * @returns #tangSuccess, #tangErrorInvalidValue + * + * Note: PT2 devices do not support double both on hardware and software + * simulation. + * + */ +TANGRT_API_PUBLIC tangError_t tangSetDoubleForDevice(double* d); + +/** + * @brief Converts a double argument after execution on a device. + * + * @param[in][out] d Double to convert. + * @returns #tangSuccess, #tangErrorInvalidValue + * + * Note: PT2 devices do not support double both on hardware and software + * simulation. + * + */ +TANGRT_API_PUBLIC tangError_t tangSetDoubleForHost(double* d); + +// doxygen end Execution Control +/** + * @} + */ + +/** + *------------------------------------------------------------------------------------------------- + *------------------------------------------------------------------------------------------------- + * @defgroup Occupancy + * @{ + * + */ + +/** + * @brief Returns occupancy for a device function. + * + * @param[out] numBlocks Returned occupancy + * @param[in] func Kernel function for which occupancy is calculated + * @param[in] blockSize Block size the kernel is intended to be launched with + * @param[in] dynamicSMemSize Per-block dynamic shared memory usage intended, in + * bytes + * @returns #tangSuccess, #tangErrorInvalidValue + * + */ +TANGRT_API_PUBLIC tangError_t +tangOccupancyMaxActiveBlocksPerMultiprocessor(int* numBlocks, + const void* func, + int blockSize, + size_t dynamicSMemSize); + +/** + * @brief Returns occupancy for a device function with the specified flags. + * + * @param[out] numBlocks Returned occupancy + * @param[in] func Kernel function for which occupancy is calculated + * @param[in] blockSize Block size the kernel is intended to be launched with + * @param[in] dynamicSMemSize Per-block dynamic shared memory usage intended, + * in bytes + * @param[in] flags Requested behavior for the occupancy calculator + * @returns #tangSuccess, #tangErrorInvalidValue + * + */ +TANGRT_API_PUBLIC tangError_t +tangOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(int* numBlocks, + const void* func, + int blockSize, + size_t dynamicSMemSize, + unsigned int flags); + +TANGRT_API_PUBLIC tangError_t tangIpcGetMemHandle(tangIpcMemHandle_t* pHandle, + void* devPtr); + +TANGRT_API_PUBLIC tangError_t tangIpcOpenMemHandle(void** devPtr, + tangIpcMemHandle_t handle, + unsigned int flags); + +TANGRT_API_PUBLIC tangError_t tangIpcCloseMemHandle(void* devPtr); + +TANGRT_API_PUBLIC tangError_t +tangIpcGetEventHandle(tangIpcEventHandle_t* pHandle, tangEvent_t event); + +TANGRT_API_PUBLIC tangError_t +tangIpcOpenEventHandle(tangEvent_t* phEvent, tangIpcEventHandle_t handle); + +// private api +TANGRT_API_PUBLIC tangError_t tangGetExportTable(void** pExportedTable, + void* args); + +// end doxygen Occupancy +/** + * @} + */ + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#ifdef __cplusplus +template +inline tangError_t tangMalloc(T** pptr, size_t sizeBytes) { + return ::tangMalloc((void**)pptr, sizeBytes); +} +template +inline tangError_t tangMallocHost(T** pptr, size_t sizeBytes) { + return ::tangMallocHost((void**)pptr, sizeBytes); +} +template +inline tangError_t tangMallocAsync(T** pptr, size_t sizeBytes, tangStream_t stream) { + return ::tangMallocAsync((void**)pptr, sizeBytes, stream); +} +template +inline tangError_t tangMallocAsync_ptsz(T** pptr, size_t sizeBytes, tangStream_t stream) { + return ::tangMallocAsync_ptsz((void**)pptr, sizeBytes, stream); +} +template +inline tangError_t tangHostAlloc(T** pptr, + size_t sizeBytes, + unsigned int flags) { + return ::tangHostAlloc((void**)pptr, sizeBytes, flags); +} +template +inline tangError_t tangHostGetDevicePointer(T** pDevice, + void* pHost, + unsigned int flags) { + return ::tangHostGetDevicePointer((void**)pDevice, pHost, flags); +} +template +inline tangError_t tangIpcOpenMemHandle(T** pDevPtr, + tangIpcMemHandle_t handle, + unsigned int flags) { + return ::tangIpcOpenMemHandle((void**)pDevPtr, handle, flags); +} +#endif // __cplusplus + +#if defined(__TANGRT_API_PER_THREAD_DEFAULT_STREAM) +#define tangMemset __TANGRT_API_PTDS(tangMemset) +#define tangMemsetAsync __TANGRT_API_PTSZ(tangMemsetAsync) +#define tangMemcpy __TANGRT_API_PTDS(tangMemcpy) +#define tangMemcpyAsync __TANGRT_API_PTSZ(tangMemcpyAsync) +#define tangMallocAsync __TANGRT_API_PTSZ(tangMallocAsync) +#define tangFreeAsync __TANGRT_API_PTSZ(tangFreeAsync) +#define tangStreamSynchronize __TANGRT_API_PTSZ(tangStreamSynchronize) +#define tangStreamQuery __TANGRT_API_PTSZ(tangStreamQuery) +#define tangStreamWaitEvent __TANGRT_API_PTSZ(tangStreamWaitEvent) +#define tangStreamC2Ctransfers __TANGRT_API_PTSZ(tangStreamC2Ctransfers) +#define tangStreamGetFlags __TANGRT_API_PTSZ(tangStreamGetFlags) +#define tangStreamGetId __TANGRT_API_PTSZ(tangStreamGetId) +#define tangStreamGetPriority __TANGRT_API_PTSZ(tangStreamGetPriority) +#define tangStreamAddCallback __TANGRT_API_PTSZ(tangStreamAddCallback) +#define tangStreamBeginCapture __TANGRT_API_PTSZ(tangStreamBeginCapture) +#define tangStreamEndCapture __TANGRT_API_PTSZ(tangStreamEndCapture) +#define tangStreamIsCapturing __TANGRT_API_PTSZ(tangStreamIsCapturing) +#define tangStreamGetCaptureInfo __TANGRT_API_PTSZ(tangStreamGetCaptureInfo) +#define tangGraphLaunch __TANGRT_API_PTSZ(tangGraphLaunch) +#define tangLaunchHostFunc __TANGRT_API_PTSZ(tangLaunchHostFunc) +#define tangEventRecord __TANGRT_API_PTSZ(tangEventRecord) +#define tangEventRecordWithFlags __TANGRT_API_PTSZ(tangEventRecordWithFlags) +#define tangMemcpyFromSymbol __TANGRT_API_PTDS(tangMemcpyFromSymbol) +#define tangMemcpyFromSymbolAsync __TANGRT_API_PTSZ(tangMemcpyFromSymbolAsync) +#define tangMemcpyToSymbol __TANGRT_API_PTDS(tangMemcpyToSymbol) +#define tangMemcpyToSymbolAsync __TANGRT_API_PTSZ(tangMemcpyToSymbolAsync) +//#define tangDeviceCanAccessPeer __TANGRT_API_PTSZ(tangDeviceCanAccessPeer) +//#define tangDeviceEnablePeerAccess +//__TANGRT_API_PTSZ(tangDeviceEnablePeerAccess) #define +//tangDeviceDisablePeerAccess __TANGRT_API_PTSZ(tangDeviceDisablePeerAccess) +#endif //! __TANGRT_API_PER_THREAD_DEFAULT_STREAM +// end doxygen Events +/** + * @} + */ + +#ifdef __cplusplus + +// template is only available in cxx. +// And these template APIs may cause ambiguous. +// If you don't want to use these api, please define TANGRT_DISABLE_SYMBOL_TEMPLATE_API +// before #include or #include +#if defined(__cplusplus) && !defined(TANGRT_DISABLE_SYMBOL_TEMPLATE_API) + +template +inline tangError_t tangGetSymbolAddress(void** devPtr, const T& symbol) { + return ::tangGetSymbolAddress((void**)devPtr, (const void*)&symbol); +} + +template +inline tangError_t tangGetSymbolSize(size_t* size, const T& symbol) { + return ::tangGetSymbolSize(size, (const void*)&symbol); +} + +template +inline tangError_t tangMemcpyToSymbol( + const T& symbol, + const void* src, + size_t count, + size_t offset = 0, + enum tangMemcpyKind kind = tangMemcpyHostToDevice) { + return ::tangMemcpyToSymbol((const void*)&symbol, src, count, offset, kind); +} + +template +inline tangError_t tangMemcpyToSymbolAsync( + const T& symbol, + const void* src, + size_t count, + size_t offset = 0, + enum tangMemcpyKind kind = tangMemcpyHostToDevice, + tangStream_t stream = 0) { + return ::tangMemcpyToSymbolAsync((const void*)&symbol, + src, + count, + offset, + kind, + stream); +} + +template +inline tangError_t tangMemcpyFromSymbol( + void* dst, + const T& symbol, + size_t count, + size_t offset = 0, + enum tangMemcpyKind kind = tangMemcpyDeviceToHost) { + return ::tangMemcpyFromSymbol(dst, (const void*)&symbol, count, offset, kind); +} + +template +inline tangError_t tangMemcpyFromSymbolAsync( + void* dst, + const T& symbol, + size_t count, + size_t offset = 0, + enum tangMemcpyKind kind = tangMemcpyDeviceToHost, + tangStream_t stream = 0) { + return ::tangMemcpyFromSymbolAsync(dst, + (const void*)&symbol, + count, + offset, + kind, + stream); +} +#endif //!< TANGRT_SYMBOL_FORCE_CXX_API + +#endif // __cplusplus + +#endif //! _TANG_RUNTIME_API_H_ diff --git a/third_party/sunrise/backend/include/tapti/tapti.h b/third_party/sunrise/backend/include/tapti/tapti.h new file mode 100755 index 000000000..923cc2b57 --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti.h @@ -0,0 +1,11 @@ +#ifndef _TAPTI_HPP_ +#define _TAPTI_HPP_ + +#include "tang_runtime.h" +#include "tang.h" +#include "tapti_activity.h" +#include "tapti_version.h" +#include "tapti_callbacks.h" +#include "tapti_result.h" + +#endif // _TAPTI_HPP_ diff --git a/third_party/sunrise/backend/include/tapti/tapti_activity.h b/third_party/sunrise/backend/include/tapti/tapti_activity.h new file mode 100755 index 000000000..f61765c7a --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_activity.h @@ -0,0 +1,956 @@ +#ifndef _TAPTI_ACTIVITY_HPP_ +#define _TAPTI_ACTIVITY_HPP_ + +#include +#include + +#include "tapti_callbacks.h" + +/** + * \brief The kind of a memory copy, indicating the source and + * destination targets of the copy. + * + * Each kind represents the source and destination targets of a memory + * copy. Targets are host, device, and array. + */ +typedef enum { + /** + * The memory copy kind is not known. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_UNKNOWN = 0, + + /** + * A host to device memory copy. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_HTOD = 1, + + /** + * A device to host memory copy. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_DTOH = 2, + + /** + * A device to device memory copy on the same device. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_DTOD = 3, + + /** + * A host to host memory copy. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_HTOH = 4, + + /** + * A peer to peer memory copy across different devices. + */ + TAPTI_ACTIVITY_MEMCPY_KIND_PTOP = 5, + + TAPTI_ACTIVITY_MEMCPY_KIND_FORCE_INT = 0x7fffffff +} TApti_ActivityMemcpyKind; + +typedef enum { + TAPTI_EXTERNAL_CORRELATION_KIND_INVALID = 0, + TAPTI_EXTERNAL_CORRELATION_KIND_UNKNOWN = 1, + TAPTI_EXTERNAL_CORRELATION_KIND_OPENACC = 2, + TAPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0 = 3, + TAPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1 = 4, + TAPTI_EXTERNAL_CORRELATION_KIND_CUSTOM2 = 5, + TAPTI_EXTERNAL_CORRELATION_KIND_SIZE, + TAPTI_EXTERNAL_CORRELATION_KIND_FORCE_INT = 0x7fffffff +} TApti_ExternalCorrelationKind; + +#define ACTIVITY_RECORD_ALIGNMENT 8 +#define PACKED_ALIGNMENT __attribute__ ((__packed__)) __attribute__ ((aligned (ACTIVITY_RECORD_ALIGNMENT))) + +/** + * \brief The kinds of activity records. + * + * Each activity record kind represents information about a GPU or an + * activity occurring on a CPU or GPU. Each kind is associated with a + * activity record structure that holds the information associated + * with the kind. + * \see TApti_Activity + * \see TApti_ActivityAPI + * \see TApti_ActivityExternalCorrelation + * \see TApti_ActivityKernel + * \see TApti_ActivityMemcpy + * \see TApti_ActivityMemcpyPtoP + * \see TApti_ActivityMemset + */ + typedef enum { + /** + * The activity record is invalid. + */ + TAPTI_ACTIVITY_KIND_INVALID = 0, + + /** + * A host<->host, host<->device, or device<->device memory copy. The + * corresponding activity record structure is \ref + */ + TAPTI_ACTIVITY_KIND_MEMCPY = 1, + + /** + * A memory set executing on the GPU. The corresponding activity + * record structure is \ref TApti_ActivityMemset. + */ + TAPTI_ACTIVITY_KIND_MEMSET = 2, + + /** + * A kernel executing on the GPU. This activity kind may significantly change + * the overall performance characteristics of the application because all + * kernel executions are serialized on the GPU. Other activity kind for kernel + * TAPTI_ACTIVITY_KIND_CONCURRENT_KERNEL doesn't break kernel concurrency. + * The corresponding activity record structure is \ref TApti_ActivityKernel. + */ + TAPTI_ACTIVITY_KIND_KERNEL = 3, + + /** + * A TANG driver API function execution. The corresponding activity + * record structure is \ref TApti_ActivityAPI. + */ + TAPTI_ACTIVITY_KIND_DRIVER = 4, + + /** + * A TANG runtime API function execution. The corresponding activity + * record structure is \ref TApti_ActivityAPI. + */ + TAPTI_ACTIVITY_KIND_RUNTIME = 5, + + /** + * Records for correlation of different programming APIs. The + * corresponding activity record structure is \ref + * TApti_ActivityExternalCorrelation. + */ + TAPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION = 6, + + TAPTI_ACTIVITY_KIND_COUNT, + + TAPTI_ACTIVITY_KIND_FORCE_INT = 0x7fffffff +} TApti_ActivityKind; + +/** + * \brief The kinds of memory accessed by a memory operation/copy. + * + * Each kind represents the type of the memory + * accessed by a memory operation/copy. + */ +typedef enum { + /** + * The memory kind is unknown. + */ + TAPTI_ACTIVITY_MEMORY_KIND_UNKNOWN = 0, + + /** + * The memory is pageable. + */ + TAPTI_ACTIVITY_MEMORY_KIND_PAGEABLE = 1, + + /** + * The memory is pinned. + */ + TAPTI_ACTIVITY_MEMORY_KIND_PINNED = 2, + + /** + * The memory is on the device. + */ + TAPTI_ACTIVITY_MEMORY_KIND_DEVICE = 3, + + /** + * The memory is an array. + */ + TAPTI_ACTIVITY_MEMORY_KIND_ARRAY = 4, + + /** + * The memory is managed + */ + TAPTI_ACTIVITY_MEMORY_KIND_MANAGED = 5, + + /** + * The memory is device static + */ + TAPTI_ACTIVITY_MEMORY_KIND_DEVICE_STATIC = 6, + + /** + * The memory is managed static + */ + TAPTI_ACTIVITY_MEMORY_KIND_MANAGED_STATIC = 7, + + TAPTI_ACTIVITY_MEMORY_KIND_FORCE_INT = 0x7fffffff +} TApti_ActivityMemoryKind; + + +/** + * \brief The base activity record. + * + * The activity API uses a TApti_Activity as a generic representation + * for any activity. The 'kind' field is used to determine the + * specific activity kind, and from that the TAPTI_Activity object can + * be cast to the specific activity record type appropriate for that kind. + * + * Note that all activity record types are padded and aligned to + * ensure that each member of the record is naturally aligned. + * + * \see TApti_ActivityKind + */ +typedef struct PACKED_ALIGNMENT { + /** + * The kind of this activity. + */ + TApti_ActivityKind kind; +} TApti_Activity; + +/** + * \brief The activity record for correlation with external records + * + * This activity record correlates native TANG records (e.g. TANG Driver API, + * kernels, memcpys, ...) with records from external APIs such as OpenACC. + * (TAPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION). + * + * \see TApti_ActivityKind + */ +typedef struct PACKED_ALIGNMENT { + TApti_ActivityKind kind; + + /** + * The kind of external API this record correlated to. + */ + TApti_ExternalCorrelationKind externalKind; + + /** + * The correlation ID of the associated non-TANG API record. + * The exact field in the associated external record depends + * on that record's activity kind (\see externalKind). + */ + uint64_t externalId; + + /** + * The correlation ID of the associated TANG driver or runtime API record. + */ + uint32_t correlationId; + /** + * Undefined. Reserved for internal use. + */ + uint32_t reserved; +} TApti_ActivityExternalCorrelation; + +/** + * \brief The activity record for a driver or runtime API invocation. + * + * This activity record represents an invocation of a driver or + * runtime API (TAPTI_ACTIVITY_KIND_DRIVER and + * TAPTI_ACTIVITY_KIND_RUNTIME). + */ + typedef struct PACKED_ALIGNMENT { + TApti_ActivityKind kind; + /** + * The ID of the driver or runtime function. + */ + TApti_CallbackId cbid; + /** + * The start timestamp for the function, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the function. + */ + uint64_t start; + + /** + * The end timestamp for the function, in ns. A value of 0 for both + * the start and end timestamps indicates that timestamp information + * could not be collected for the function. + */ + uint64_t end; + + /** + * The ID of the process where the driver or runtime TANG function + * is executing. + */ + uint32_t processId; + + /** + * The ID of the thread where the driver or runtime TANG function is + * executing. + */ + uint32_t threadId; + + /** + * The correlation ID of the driver or runtime TANG function. Each + * function invocation is assigned a unique correlation ID that is + * identical to the correlation ID in the memcpy, memset, or kernel + * activity record that is associated with this function. + */ + uint32_t correlationId; + + /** + * The return value for the function. For a TANG driver function + * with will be a TAresult value, and for a TANG runtime function + * this will be a tangError_t value. + */ + uint32_t returnValue; +} TApti_ActivityAPI; + +/** + * \brief The activity record for kernel. (deprecated) + * + * This activity record represents a kernel execution + * (TAPTI_ACTIVITY_KIND_KERNEL and + * TAPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) but is no longer generated + * by TAPTI. Kernel activities are now reported using the + * TApti_ActivityKernel9 activity record. + */ + typedef struct PACKED_ALIGNMENT { + /** + * The activity record kind, must be TAPTI_ACTIVITY_KIND_KERNEL + * or TAPTI_ACTIVITY_KIND_CONCURRENT_KERNEL. + */ + TApti_ActivityKind kind; + + /** + * The cache configuration requested by the kernel. The value is one + * of the CUfunc_cache enumeration values from cuda.h. + */ + uint8_t cacheConfigRequested; + + /** + * The cache configuration used for the kernel. The value is one of + * the CUfunc_cache enumeration values from cuda.h. + */ + uint8_t cacheConfigExecuted; + + /** + * The number of registers required for each thread executing the + * kernel. + */ + uint16_t registersPerThread; + + /** + * The start timestamp for the kernel execution, in ns. A value of 0 + * for both the start and end timestamps indicates that timestamp + * information could not be collected for the kernel. + */ + uint64_t start; + + /** + * The end timestamp for the kernel execution, in ns. A value of 0 + * for both the start and end timestamps indicates that timestamp + * information could not be collected for the kernel. + */ + uint64_t end; + + /** + * The ID of the device where the kernel is executing. + */ + uint32_t deviceId; + + /** + * The ID of the context where the kernel is executing. + */ + uint32_t contextId; + + /** + * The ID of the stream where the kernel is executing. + */ + uint32_t streamId; + + /** + * The X-dimension grid size for the kernel. + */ + int32_t gridX; + + /** + * The Y-dimension grid size for the kernel. + */ + int32_t gridY; + + /** + * The Z-dimension grid size for the kernel. + */ + int32_t gridZ; + + /** + * The X-dimension block size for the kernel. + */ + int32_t blockX; + + /** + * The Y-dimension block size for the kernel. + */ + int32_t blockY; + + /** + * The Z-dimension grid size for the kernel. + */ + int32_t blockZ; + + /** + * The static shared memory allocated for the kernel, in bytes. + */ + int32_t staticSharedMemory; + + /** + * The dynamic shared memory reserved for the kernel, in bytes. + */ + int32_t dynamicSharedMemory; + + /** + * The amount of local memory reserved for each thread, in bytes. + */ + uint32_t localMemoryPerThread; + + /** + * The total amount of local memory reserved for the kernel, in + * bytes. + */ + uint32_t localMemoryTotal; + + /** + * The correlation ID of the kernel. Each kernel execution is + * assigned a unique correlation ID that is identical to the + * correlation ID in the driver API activity record that launched + * the kernel. + */ + uint32_t correlationId; + + /** + * The runtime correlation ID of the kernel. Each kernel execution + * is assigned a unique runtime correlation ID that is identical to + * the correlation ID in the runtime API activity record that + * launched the kernel. + */ + uint32_t runtimeCorrelationId; + + /** + * Undefined. Reserved for internal use. + */ + uint32_t pad; + + /** + * The name of the kernel. This name is shared across all activity + * records representing the same kernel, and so should not be + * modified. + */ + const char *name; + + /** + * Undefined. Reserved for internal use. + */ + void *reserved0; +} TApti_ActivityKernel; + +/** + * \brief The activity record for memory copies. (deprecated) + * + * This activity record represents a memory copy + * (TAPTI_ACTIVITY_KIND_MEMCPY). + */ + typedef struct PACKED_ALIGNMENT { + /** + * The activity record kind, must be TAPTI_ACTIVITY_KIND_MEMCPY. + */ + TApti_ActivityKind kind; + + /** + * The kind of the memory copy, stored as a byte to reduce record + * size. \see TApti_ActivityMemcpyKind + */ + uint8_t copyKind; + + /** + * The source memory kind read by the memory copy, stored as a byte + * to reduce record size. \see TApti_ActivityMemoryKind + */ + uint8_t srcKind; + + /** + * The destination memory kind read by the memory copy, stored as a + * byte to reduce record size. \see TApti_ActivityMemoryKind + */ + uint8_t dstKind; + + /** + * The flags associated with the memory copy. \see TApti_ActivityFlag + */ + uint8_t flags; + + /** + * The number of bytes transferred by the memory copy. + */ + uint64_t bytes; + + /** + * The start timestamp for the memory copy, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory copy. + */ + uint64_t start; + + /** + * The end timestamp for the memory copy, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory copy. + */ + uint64_t end; + + /** + * The ID of the device where the memory copy is occurring. + */ + uint32_t deviceId; + + /** + * The ID of the context where the memory copy is occurring. + */ + uint32_t contextId; + + /** + * The ID of the stream where the memory copy is occurring. + */ + uint32_t streamId; + + /** + * The correlation ID of the memory copy. Each memory copy is + * assigned a unique correlation ID that is identical to the + * correlation ID in the driver API activity record that launched + * the memory copy. + */ + uint32_t correlationId; + + /** + * The runtime correlation ID of the memory copy. Each memory copy + * is assigned a unique runtime correlation ID that is identical to + * the correlation ID in the runtime API activity record that + * launched the memory copy. + */ + uint32_t runtimeCorrelationId; + +#ifdef TAptiLP64 + /** + * Undefined. Reserved for internal use. + */ + uint32_t pad; +#endif + + /** + * Undefined. Reserved for internal use. + */ + void *reserved0; +} TApti_ActivityMemcpy; + +/** + * \brief The activity record for peer-to-peer memory copies. + * + * This activity record represents a peer-to-peer memory copy + * (TAPTI_ACTIVITY_KIND_MEMCPY2) but is no longer generated + * by TAPTI. Peer-to-peer memory copy activities are now reported using the + * TApti_ActivityMemcpyPtoP2 activity record.. + */ + typedef struct PACKED_ALIGNMENT { + /** + * The activity record kind, must be TAPTI_ACTIVITY_KIND_MEMCPY2. + */ + TApti_ActivityKind kind; + + /** + * The kind of the memory copy, stored as a byte to reduce record + * size. \see TApti_ActivityMemcpyKind + */ + uint8_t copyKind; + + /** + * The source memory kind read by the memory copy, stored as a byte + * to reduce record size. \see TApti_ActivityMemoryKind + */ + uint8_t srcKind; + + /** + * The destination memory kind read by the memory copy, stored as a + * byte to reduce record size. \see TApti_ActivityMemoryKind + */ + uint8_t dstKind; + + /** + * The flags associated with the memory copy. \see + * TApti_ActivityFlag + */ + uint8_t flags; + + /** + * The number of bytes transferred by the memory copy. + */ + uint64_t bytes; + + /** + * The start timestamp for the memory copy, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory copy. + */ + uint64_t start; + + /** + * The end timestamp for the memory copy, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory copy. + */ + uint64_t end; + + /** + * The ID of the device where the memory copy is occurring. + */ + uint32_t deviceId; + + /** + * The ID of the context where the memory copy is occurring. + */ + uint32_t contextId; + + /** + * The ID of the stream where the memory copy is occurring. + */ + uint32_t streamId; + + /** + * The ID of the device where memory is being copied from. + */ + uint32_t srcDeviceId; + + /** + * The ID of the context owning the memory being copied from. + */ + uint32_t srcContextId; + + /** + * The ID of the device where memory is being copied to. + */ + uint32_t dstDeviceId; + + /** + * The ID of the context owning the memory being copied to. + */ + uint32_t dstContextId; + + /** + * The correlation ID of the memory copy. Each memory copy is + * assigned a unique correlation ID that is identical to the + * correlation ID in the driver and runtime API activity record that + * launched the memory copy. + */ + uint32_t correlationId; + +#ifndef TAptiLP64 + /** + * Undefined. Reserved for internal use. + */ + uint32_t pad; +#endif + + /** + * Undefined. Reserved for internal use. + */ + void *reserved0; +} TApti_ActivityMemcpyPtoP; + +/** + * \brief The activity record for memset. (deprecated) + * + * This activity record represents a memory set operation + * (TAPTI_ACTIVITY_KIND_MEMSET). + */ + typedef struct PACKED_ALIGNMENT { + /** + * The activity record kind, must be TAPTI_ACTIVITY_KIND_MEMSET. + */ + TApti_ActivityKind kind; + + /** + * The value being assigned to memory by the memory set. + */ + uint32_t value; + + /** + * The number of bytes being set by the memory set. + */ + uint64_t bytes; + + /** + * The start timestamp for the memory set, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory set. + */ + uint64_t start; + + /** + * The end timestamp for the memory set, in ns. A value of 0 for + * both the start and end timestamps indicates that timestamp + * information could not be collected for the memory set. + */ + uint64_t end; + + /** + * The ID of the device where the memory set is occurring. + */ + uint32_t deviceId; + + /** + * The ID of the context where the memory set is occurring. + */ + uint32_t contextId; + + /** + * The ID of the stream where the memory set is occurring. + */ + uint32_t streamId; + + /** + * The correlation ID of the memory set. Each memory set is assigned + * a unique correlation ID that is identical to the correlation ID + * in the driver API activity record that launched the memory set. + */ + uint32_t correlationId; + + /** + * The flags associated with the memset. \see TApti_ActivityFlag + */ + uint16_t flags; + + /** + * The memory kind of the memory set \see TApti_ActivityMemoryKind + */ + uint16_t memoryKind; + + +} TApti_ActivityMemset; + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#if defined(_MSC_VER) +#define TAPTI_DEPRECATED __declspec(deprecated) +#define TAPTI_API_EXPORT __declspec(dllexport) +#define TAPTI_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TAPTI_DEPRECATED __attribute__((deprecated)) +#define TAPTI_API_EXPORT __attribute__((visibility("default"))) +#define TAPTI_API_IMPORT __attribute__((visibility("default"))) +#else +#define TAPTI_DEPRECATED +#define TAPTI_API_EXPORT +#define TAPTI_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(tapti_shared_EXPORTS) +#define TAPTI_API TAPTI_API_EXPORT +#else +#define TAPTI_API TAPTI_API_IMPORT +#endif //! For user + +/** + * \brief Enable collection of a specific kind of activity record. + * + * Enable collection of a specific kind of activity record. Multiple + * kinds can be enabled by calling this function multiple times. By + * default all activity kinds are disabled for collection. + * + * \param kind The kind of activity record to collect + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_NOT_INITIALIZED + * \retval TAPTI_ERROR_NOT_COMPATIBLE if the activity kind cannot be enabled + * \retval TAPTI_ERROR_INVALID_KIND if the activity kind is not supported + */ +TAptiResult TAPTI_API taptiActivityEnable(TApti_ActivityKind kind); + +/** + * \brief Disable collection of a specific kind of activity record. + * + * Disable collection of a specific kind of activity record. Multiple + * kinds can be disabled by calling this function multiple times. By + * default all activity kinds are disabled for collection. + * + * \param kind The kind of activity record to stop collecting + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_NOT_INITIALIZED + * \retval TAPTI_ERROR_INVALID_KIND if the activity kind is not supported + */ +TAptiResult TAPTI_API taptiActivityDisable(TApti_ActivityKind kind); + +/** + * \brief Iterate over the activity records in a buffer. + * + * This is a helper function to iterate over the activity records in a + * buffer. A buffer of activity records is typically obtained by + * receiving a TApti_BuffersCallbackCompleteFunc callback. + * + * An example of typical usage: + * \code + * TApti_Activity *record = NULL; + * TAptiResult status = TAPTI_SUCCESS; + * do { + * status = taptiActivityGetNextRecord(buffer, validSize, &record); + * if(status == TAPTI_SUCCESS) { + * // Use record here... + * } + * else if (status == TAPTI_ERROR_MAX_LIMIT_REACHED) + * break; + * else { + * goto Error; + * } + * } while (1); + * \endcode + * + * \param buffer The buffer containing activity records + * \param record Inputs the previous record returned by + * taptiActivityGetNextRecord and returns the next activity record + * from the buffer. If input value is NULL, returns the first activity + * record in the buffer. Records of kind TAPTI_ACTIVITY_KIND_CONCURRENT_KERNEL + * may contain invalid (0) timestamps, indicating that no timing information could + * be collected for lack of device memory. + * \param validBufferSizeBytes The number of valid bytes in the buffer. + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_NOT_INITIALIZED + * \retval TAPTI_ERROR_MAX_LIMIT_REACHED if no more records in the buffer + * \retval TAPTI_ERROR_INVALID_PARAMETER if \p buffer is NULL. + */ +TAptiResult TAPTI_API taptiActivityGetNextRecord(uint8_t* buffer, size_t validBufferSizeBytes, TApti_Activity **record); + +/** + * \brief Push an external correlation id for the calling thread + * + * This function notifies TAPTI that the calling thread is entering an external API region. + * When a TAPTI activity API record is created while within an external API region and + * TAPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION is enabled, the activity API record will + * be preceeded by a TApti_ActivityExternalCorrelation record for each \ref TApti_ExternalCorrelationKind. + * + * \param kind The kind of external API activities should be correlated with. + * \param id External correlation id. + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_INVALID_PARAMETER The external API kind is invalid + */ +TAptiResult TAPTI_API taptiActivityPushExternalCorrelationId(TApti_ExternalCorrelationKind kind, uint64_t id); + +/** + * \brief Pop an external correlation id for the calling thread + * + * This function notifies TAPTI that the calling thread is leaving an external API region. + * + * \param kind The kind of external API activities should be correlated with. + * \param lastId If the function returns successful, contains the last external correlation id for this \p kind, can be NULL. + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_INVALID_PARAMETER The external API kind is invalid. + * \retval TAPTI_ERROR_QUEUE_EMPTY No external id is currently associated with \p kind. + */ +TAptiResult TAPTI_API taptiActivityPopExternalCorrelationId(TApti_ExternalCorrelationKind kind, uint64_t *lastId); + +/** + * \brief Request to deliver activity records via the buffer completion callback. + * + * This function returns the activity records associated with all contexts/streams + * (and the global buffers not associated with any stream) to the TAPTI client + * using the callback registered in taptiActivityRegisterCallbacks. + * + * This is a blocking call but it doesn't issue any TANG synchronization calls + * implicitly thus it's not guaranteed that all activities are completed on the + * underlying devices. Activity record is considered as completed if it has all + * the information filled up including the timestamps if any. It is the client's + * responsibility to issue necessary TANG synchronization calls before calling + * this function if all activity records with complete information are expected + * to be delivered. + * + * Behavior of the function based on the input flag: + * - ::For default flush i.e. when flag is set as 0, it returns all the + * activity buffers which have all the activity records completed, buffers need not + * to be full though. It doesn't return buffers which have one or more incomplete + * records. Default flush can be done at a regular interval in a separate thread. + * - ::For forced flush i.e. when flag TAPTI_ACTIVITY_FLAG_FLUSH_FORCED is passed + * to the function, it returns all the activity buffers including the ones which have + * one or more incomplete activity records. It's suggested for clients to do the + * force flush before the termination of the profiling session to allow remaining + * buffers to be delivered. In general, it can be done in the at-exit handler. + * + * Before calling this function, the buffer handling callback api must be activated + * by calling taptiActivityRegisterCallbacks. + * + * \param flag The flag can be set to indicate a forced flush. See TApti_ActivityFlag + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_NOT_INITIALIZED + * \retval TAPTI_ERROR_INVALID_OPERATION if not preceeded by a + * successful call to taptiActivityRegisterCallbacks + * \retval TAPTI_ERROR_UNKNOWN an internal error occurred + * + * \see taptiActivityFlushPeriod + */ +TAptiResult TAPTI_API taptiActivityFlushAll(void); + +/** + * \brief Function type for callback used by TAPTI to request an empty + * buffer for storing activity records. + * + * This callback function signals the TAPTI client that an activity + * buffer is needed by TAPTI. The activity buffer is used by TAPTI to + * store activity records. The callback function can decline the + * request by setting \p *buffer to NULL. In this case TAPTI may drop + * activity records. + * + * \param buffer Returns the new buffer. If set to NULL then no buffer + * is returned. + * \param size Returns the size of the returned buffer. + * \param maxNumRecords Returns the maximum number of records that + * should be placed in the buffer. If 0 then the buffer is filled with + * as many records as possible. If > 0 the buffer is filled with at + * most that many records before it is returned. + */ +typedef void (*TApti_BuffersCallbackRequestFunc)(uint8_t **buffer, size_t *size, size_t *maxNumRecords); + +/** + * \brief Function type for callback used by TAPTI to return a buffer + * of activity records. + * + * This callback function returns to the TAPTI client a buffer + * containing activity records. The buffer contains \p validSize + * bytes of activity records which should be read using + * taptiActivityGetNextRecord. The number of dropped records can be + * read using taptiActivityGetNumDroppedRecords. After this call TAPTI + * relinquished ownership of the buffer and will not use it + * anymore. The client may return the buffer to TAPTI using the + * TApti_BuffersCallbackRequestFunc callback. + + * \param buffer The activity record buffer. + * \param size The total size of the buffer in bytes as set in + * TApti_BuffersCallbackRequestFunc. + * \param validSize The number of valid bytes in the buffer. + */ +typedef void (*TApti_BuffersCallbackCompleteFunc)(uint8_t* buffer, size_t size, size_t validSize); + +/** + * \brief Registers callback functions with TAPTI for activity buffer + * handling. + * + * This function registers two callback functions to be used in asynchronous + * buffer handling. If registered, activity record buffers are handled using + * asynchronous requested/completed callbacks from TAPTI. + * + * Registering these callbacks prevents the client from using TAPTI's + * blocking enqueue/dequeue functions. + * + * \param funcBufferRequested callback which is invoked when an empty + * buffer is requested by TAPTI + * \param funcBufferCompleted callback which is invoked when a buffer + * containing activity records is available from TAPTI + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_INVALID_PARAMETER if either \p + * funcBufferRequested or \p funcBufferCompleted is NULL + */ +TAptiResult TAPTI_API taptiActivityRegisterCallbacks(TApti_BuffersCallbackRequestFunc funcBufferRequested, + TApti_BuffersCallbackCompleteFunc funcBufferCompleted); + +TAptiResult TAPTI_API taptiActivityPostProcess(void); + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif // _TAPTI_ACTIVITY_HPP_ + diff --git a/third_party/sunrise/backend/include/tapti/tapti_callbacks.h b/third_party/sunrise/backend/include/tapti/tapti_callbacks.h new file mode 100755 index 000000000..bf9e7061d --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_callbacks.h @@ -0,0 +1,109 @@ +#ifndef __TAPTI_CALLBACKS_HPP__ +#define __TAPTI_CALLBACKS_HPP__ + +#include +#include "tapti_result.h" + +/** + * \brief Callback domains. + * + * Callback domains. Each domain represents callback points for a + * group of related API functions or TANG driver activity. + */ +typedef enum { + /** + * Invalid domain. + */ + TAPTI_CB_DOMAIN_INVALID = 0, + /** + * Domain containing callback points for all driver API functions. + */ + TAPTI_CB_DOMAIN_DRIVER_API = 1, + /** + * Domain containing callback points for all runtime API + * functions. + */ + TAPTI_CB_DOMAIN_RUNTIME_API = 2, + TAPTI_CB_DOMAIN_SIZE, + + TAPTI_CB_DOMAIN_FORCE_INT = 0x7fffffff +}TApti_CallbackDomain; + +/** + * \brief An ID for a driver API, runtime API, resource or + * synchronization callback. + * + * An ID for a driver API, runtime API, resource or synchronization + * callback. Within a driver API callback this should be interpreted + * as a tapti_driver_api_trace_cbid value. + * Within a runtime API callback this should be interpreted as a + * TAPTI_runtime_api_trace_cbid value. + * Within a resource API callback this should be interpreted as a + * ref TAPTI_CallbackIdResource value. + * Within a synchronize API callback this should be interpreted as a + * ref TAPTI_CallbackIdSync value. + */ +typedef uint32_t TApti_CallbackId; + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#if defined(_MSC_VER) +#define TAPTI_DEPRECATED __declspec(deprecated) +#define TAPTI_API_EXPORT __declspec(dllexport) +#define TAPTI_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TAPTI_DEPRECATED __attribute__((deprecated)) +#define TAPTI_API_EXPORT __attribute__((visibility("default"))) +#define TAPTI_API_IMPORT __attribute__((visibility("default"))) +#else +#define TAPTI_DEPRECATED +#define TAPTI_API_EXPORT +#define TAPTI_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(tapti_shared_EXPORTS) +#define TAPTI_API TAPTI_API_EXPORT +#else +#define TAPTI_API TAPTI_API_IMPORT +#endif //! For user + +/** + * \brief Get the name of a callback for a specific domain and callback ID. + * + * Returns a pointer to the name c_string in \p **name. + * + * \note \b Names are available only for the DRIVER and RUNTIME domains. + * + * \param domain The domain of the callback + * \param cbid The ID of the callback + * \param name Returns pointer to the name string on success, NULL otherwise + * + * \retval TAPTI_SUCCESS on success + * \retval TAPTI_ERROR_INVALID_PARAMETER if \p name is NULL, or if + * \p domain or \p cbid is invalid. + */ +TAptiResult TAPTI_API taptiGetCallbackName(TApti_CallbackDomain domain, + uint32_t cbid, + const char **name); +/** + * \brief Get the TAPTI timestamp. + * + * Returns a timestamp normalized to correspond with the start and end + * timestamps reported in the TAPTI activity records. The timestamp is + * reported in nanoseconds. + * + * \param timestamp Returns the TAPTI timestamp + * + * \retval TAPTI_SUCCESS + * \retval TAPTI_ERROR_INVALID_PARAMETER if \p timestamp is NULL + */ +TAptiResult TAPTI_API taptiGetTimestamp(uint64_t *timestamp); + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif // __TAPTI_CALLBACKS_HPP__ + diff --git a/third_party/sunrise/backend/include/tapti/tapti_driver_cbid.h b/third_party/sunrise/backend/include/tapti/tapti_driver_cbid.h new file mode 100755 index 000000000..075daf5c9 --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_driver_cbid.h @@ -0,0 +1,203 @@ +// ************************************************************************* +// Definitions of indices for API functions, unique across entire API +// ************************************************************************* + +// This file is generated. + +typedef enum tapti_driver_api_trace_cbid_enum { + TAPTI_DRIVER_TRACE_CBID_INVALID = 0, + TAPTI_DRIVER_TRACE_CBID_taRuntimeGetVersion = 1, + TAPTI_DRIVER_TRACE_CBID_taGetExportTable = 2, + TAPTI_DRIVER_TRACE_CBID_taDriverGetVersion = 3, + TAPTI_DRIVER_TRACE_CBID_taKernelDriverGetVersion = 4, + TAPTI_DRIVER_TRACE_CBID_taGetErrorString = 5, + TAPTI_DRIVER_TRACE_CBID_taGetErrorName = 6, + TAPTI_DRIVER_TRACE_CBID_taInit = 7, + TAPTI_DRIVER_TRACE_CBID_taDeviceGet = 8, + TAPTI_DRIVER_TRACE_CBID_taDeviceReset = 9, + TAPTI_DRIVER_TRACE_CBID_taDeviceSynchronize = 10, + TAPTI_DRIVER_TRACE_CBID___taDeviceSynchronizeCurrent = 11, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetByPCIBusId = 12, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetPCIBusId = 13, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetCount = 14, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetAttribute = 15, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetName = 16, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetUuid = 17, + TAPTI_DRIVER_TRACE_CBID_taDeviceTotalMem = 18, + TAPTI_DRIVER_TRACE_CBID_taMemGetInfo = 19, + TAPTI_DRIVER_TRACE_CBID_taDeviceMemGetInfo = 20, + TAPTI_DRIVER_TRACE_CBID_taCtxCreate = 21, + TAPTI_DRIVER_TRACE_CBID_taCtxDestroy = 22, + TAPTI_DRIVER_TRACE_CBID_taCtxSetCurrent = 23, + TAPTI_DRIVER_TRACE_CBID_taCtxQueryCurrent = 24, + TAPTI_DRIVER_TRACE_CBID_taCtxGetCurrent = 25, + TAPTI_DRIVER_TRACE_CBID_taCtxGetCurrentDevice = 26, + TAPTI_DRIVER_TRACE_CBID_taCtxGetDevice = 27, + TAPTI_DRIVER_TRACE_CBID_taCtxGetOrdinal = 28, + TAPTI_DRIVER_TRACE_CBID_taCtxPushCurrent = 29, + TAPTI_DRIVER_TRACE_CBID_taCtxPopCurrent = 30, + TAPTI_DRIVER_TRACE_CBID_taCtxSynchronize = 31, + TAPTI_DRIVER_TRACE_CBID_taDevicePrimaryCtxRetain = 32, + TAPTI_DRIVER_TRACE_CBID_taDevicePrimaryCtxRelease = 33, + TAPTI_DRIVER_TRACE_CBID_taDevicePrimaryCtxReset = 34, + TAPTI_DRIVER_TRACE_CBID_taDevicePrimaryCtxSetFlags = 35, + TAPTI_DRIVER_TRACE_CBID_taDevicePrimaryCtxGetState = 36, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetOrdinal = 37, + TAPTI_DRIVER_TRACE_CBID_taCtxGetFunction = 38, + TAPTI_DRIVER_TRACE_CBID_taCtxRegisterFunction = 39, + TAPTI_DRIVER_TRACE_CBID_taCtxGetVariable = 40, + TAPTI_DRIVER_TRACE_CBID_taCtxRegisterVariable = 41, + TAPTI_DRIVER_TRACE_CBID_taCtxGetLimit = 42, + TAPTI_DRIVER_TRACE_CBID_taCtxSetLimit = 43, + TAPTI_DRIVER_TRACE_CBID___taCtxQueryLimit = 44, + TAPTI_DRIVER_TRACE_CBID_taCtxGetBuiltInFunction = 45, + TAPTI_DRIVER_TRACE_CBID_taCtxRegisterBuiltInFunction = 46, + TAPTI_DRIVER_TRACE_CBID_taMemAlloc = 47, + TAPTI_DRIVER_TRACE_CBID_taMemFree = 48, + TAPTI_DRIVER_TRACE_CBID_taMemFreeAsync = 49, + TAPTI_DRIVER_TRACE_CBID_taMemFreeAsync_ptsz = 50, + TAPTI_DRIVER_TRACE_CBID_taMemAllocHost = 51, + TAPTI_DRIVER_TRACE_CBID_taMemHostAlloc = 52, + TAPTI_DRIVER_TRACE_CBID_taMemFreeHost = 53, + TAPTI_DRIVER_TRACE_CBID_taMemHostGetDevicePointer = 54, + TAPTI_DRIVER_TRACE_CBID_taMemHostGetFlags = 55, + TAPTI_DRIVER_TRACE_CBID_taMemHostRegister = 56, + TAPTI_DRIVER_TRACE_CBID_taMemHostUnregister = 57, + TAPTI_DRIVER_TRACE_CBID_taPointerGetAttribute = 58, + TAPTI_DRIVER_TRACE_CBID_taMemset = 59, + TAPTI_DRIVER_TRACE_CBID_taMemset_ptds = 60, + TAPTI_DRIVER_TRACE_CBID_taMemsetAsync = 61, + TAPTI_DRIVER_TRACE_CBID_taMemsetAsync_ptsz = 62, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2HAsync = 63, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2HAsync_ptsz = 64, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2DAsync = 65, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2DAsync_ptsz = 66, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2HAsync = 67, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2HAsync_ptsz = 68, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2DAsync = 69, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2DAsync_ptsz = 70, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2H = 71, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2H_ptds = 72, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2D = 73, + TAPTI_DRIVER_TRACE_CBID_taMemcpyH2D_ptds = 74, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2H = 75, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2H_ptds = 76, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2D = 77, + TAPTI_DRIVER_TRACE_CBID_taMemcpyD2D_ptds = 78, + TAPTI_DRIVER_TRACE_CBID_taStreamCreate = 79, + TAPTI_DRIVER_TRACE_CBID_taStreamCreateWithPriority = 80, + TAPTI_DRIVER_TRACE_CBID_taStreamGetPriority = 81, + TAPTI_DRIVER_TRACE_CBID_taStreamGetPriority_ptsz = 82, + TAPTI_DRIVER_TRACE_CBID_taStreamGetFlags = 83, + TAPTI_DRIVER_TRACE_CBID_taStreamGetFlags_ptsz = 84, + TAPTI_DRIVER_TRACE_CBID_taStreamGetId = 85, + TAPTI_DRIVER_TRACE_CBID_taStreamGetId_ptsz = 86, + TAPTI_DRIVER_TRACE_CBID_taStreamDestroy = 87, + TAPTI_DRIVER_TRACE_CBID_taStreamC2Ctransfers = 88, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetP2PAttribute = 89, + TAPTI_DRIVER_TRACE_CBID_taDeviceGetPeerPointer = 90, + TAPTI_DRIVER_TRACE_CBID_taDeviceCanAccessPeer = 91, + TAPTI_DRIVER_TRACE_CBID_taDeviceEnablePeerAccess = 92, + TAPTI_DRIVER_TRACE_CBID_taDeviceDisablePeerAccess = 93, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeer = 94, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeerAsync = 95, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeer_v2 = 96, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeer_v2_ptds = 97, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeerAsync_v2 = 98, + TAPTI_DRIVER_TRACE_CBID_taMemcpyPeerAsync_v2_ptsz = 99, + TAPTI_DRIVER_TRACE_CBID_taMemcpyFromPeerAsync = 100, + TAPTI_DRIVER_TRACE_CBID_taMemcpyFromPeerAsync_ptsz = 101, + TAPTI_DRIVER_TRACE_CBID_taMemcpyToPeerAsync = 102, + TAPTI_DRIVER_TRACE_CBID_taMemcpyToPeerAsync_ptsz = 103, + TAPTI_DRIVER_TRACE_CBID_taStreamWaitEvent = 104, + TAPTI_DRIVER_TRACE_CBID_taStreamWaitEvent_ptsz = 105, + TAPTI_DRIVER_TRACE_CBID_taStreamSynchronize = 106, + TAPTI_DRIVER_TRACE_CBID_taStreamSynchronize_ptsz = 107, + TAPTI_DRIVER_TRACE_CBID_taStreamQuery = 108, + TAPTI_DRIVER_TRACE_CBID_taStreamQuery_ptsz = 109, + TAPTI_DRIVER_TRACE_CBID_taEventCreate = 110, + TAPTI_DRIVER_TRACE_CBID_taEventDestroy = 111, + TAPTI_DRIVER_TRACE_CBID_taEventRecord = 112, + TAPTI_DRIVER_TRACE_CBID_taEventRecord_ptsz = 113, + TAPTI_DRIVER_TRACE_CBID_taEventRecordWithFlags = 114, + TAPTI_DRIVER_TRACE_CBID_taEventRecordWithFlags_ptsz = 115, + TAPTI_DRIVER_TRACE_CBID_taEventSynchronize = 116, + TAPTI_DRIVER_TRACE_CBID_taEventSynchronizeWithFlags = 117, + TAPTI_DRIVER_TRACE_CBID_taEventElapsedTime = 118, + TAPTI_DRIVER_TRACE_CBID_taEventQuery = 119, + TAPTI_DRIVER_TRACE_CBID_taEventQueryTimestamp = 120, + TAPTI_DRIVER_TRACE_CBID_taGetBuiltinModule = 121, + TAPTI_DRIVER_TRACE_CBID_taModuleLoad = 122, + TAPTI_DRIVER_TRACE_CBID_taModuleLoadData = 123, + TAPTI_DRIVER_TRACE_CBID_taModuleUnload = 124, + TAPTI_DRIVER_TRACE_CBID_taModuleLoadFatBinaryManaged = 125, + TAPTI_DRIVER_TRACE_CBID_taModuleUnloadManaged = 126, + TAPTI_DRIVER_TRACE_CBID_taModuleSymbolTypeGetName = 127, + TAPTI_DRIVER_TRACE_CBID_taModuleIterateSymbols = 128, + TAPTI_DRIVER_TRACE_CBID_taModuleGetFunction = 129, + TAPTI_DRIVER_TRACE_CBID_taVariableGetInfo = 130, + TAPTI_DRIVER_TRACE_CBID_taFuncGetAttribute = 131, + TAPTI_DRIVER_TRACE_CBID_taFuncGetModule = 132, + TAPTI_DRIVER_TRACE_CBID_taFunctionGetAddress = 133, + TAPTI_DRIVER_TRACE_CBID_taFunctionGetNumArgs = 134, + TAPTI_DRIVER_TRACE_CBID_taFunctionGetInfo = 135, + TAPTI_DRIVER_TRACE_CBID_taModuleGetVariable = 136, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromDevice = 137, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromDevice_ptds = 138, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromDeviceAsync = 139, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromDeviceAsync_ptsz = 140, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromHost = 141, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromHost_ptds = 142, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromHostAsync = 143, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyFromHostAsync_ptsz = 144, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToDevice = 145, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToDevice_ptds = 146, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToDeviceAsync = 147, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToDeviceAsync_ptsz = 148, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToHost = 149, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToHost_ptds = 150, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToHostAsync = 151, + TAPTI_DRIVER_TRACE_CBID_taVariableCopyToHostAsync_ptsz = 152, + TAPTI_DRIVER_TRACE_CBID_taEnqueueCommand = 153, + TAPTI_DRIVER_TRACE_CBID_taEnqueueCommand_ptsz = 154, + TAPTI_DRIVER_TRACE_CBID_taLaunchFunction = 155, + TAPTI_DRIVER_TRACE_CBID_taLaunchFunction_ptsz = 156, + TAPTI_DRIVER_TRACE_CBID_taLaunchKernel = 157, + TAPTI_DRIVER_TRACE_CBID_taLaunchKernel_ptsz = 158, + TAPTI_DRIVER_TRACE_CBID_taLaunchHostFunc = 159, + TAPTI_DRIVER_TRACE_CBID_taLaunchHostFunc_ptsz = 160, + TAPTI_DRIVER_TRACE_CBID_taLaunchHostFuncProxy = 161, + TAPTI_DRIVER_TRACE_CBID_taLaunchHostFuncProxy_ptsz = 162, + TAPTI_DRIVER_TRACE_CBID_taStreamAddCallback = 163, + TAPTI_DRIVER_TRACE_CBID_taStreamAddCallback_ptsz = 164, + TAPTI_DRIVER_TRACE_CBID_taOccupancyMaxActiveBlocksPerMultiprocessor = 165, + TAPTI_DRIVER_TRACE_CBID_taStreamBeginCapture = 166, + TAPTI_DRIVER_TRACE_CBID_taStreamBeginCapture_ptsz = 167, + TAPTI_DRIVER_TRACE_CBID_taThreadExchangeStreamCaptureMode = 168, + TAPTI_DRIVER_TRACE_CBID_taStreamEndCapture = 169, + TAPTI_DRIVER_TRACE_CBID_taStreamEndCapture_ptsz = 170, + TAPTI_DRIVER_TRACE_CBID_taStreamIsCapturing = 171, + TAPTI_DRIVER_TRACE_CBID_taStreamIsCapturing_ptsz = 172, + TAPTI_DRIVER_TRACE_CBID_taStreamGetCaptureInfo = 173, + TAPTI_DRIVER_TRACE_CBID_taStreamGetCaptureInfo_ptsz = 174, + TAPTI_DRIVER_TRACE_CBID_taGraphInstantiateWithFlags = 175, + TAPTI_DRIVER_TRACE_CBID_taGraphLaunch = 176, + TAPTI_DRIVER_TRACE_CBID_taGraphLaunch_ptsz = 177, + TAPTI_DRIVER_TRACE_CBID_taGraphDestroy = 178, + TAPTI_DRIVER_TRACE_CBID_taGraphExecDestroy = 179, + TAPTI_DRIVER_TRACE_CBID_taGraphGetInfo = 180, + TAPTI_DRIVER_TRACE_CBID_taGraphCreate = 181, + TAPTI_DRIVER_TRACE_CBID_taGraphAddHostNode = 182, + TAPTI_DRIVER_TRACE_CBID_taGraphAddKernelNode = 183, + TAPTI_DRIVER_TRACE_CBID_taProfilerStart = 184, + TAPTI_DRIVER_TRACE_CBID_taProfilerStop = 185, + TAPTI_DRIVER_TRACE_CBID_taIpcGetMemHandle = 186, + TAPTI_DRIVER_TRACE_CBID_taIpcOpenMemHandle = 187, + TAPTI_DRIVER_TRACE_CBID_taIpcCloseMemHandle = 188, + TAPTI_DRIVER_TRACE_CBID_taIpcGetEventHandle = 189, + TAPTI_DRIVER_TRACE_CBID_taIpcOpenEventHandle = 190, + TAPTI_DRIVER_TRACE_CBID_taMemAllocAsync = 191, + TAPTI_DRIVER_TRACE_CBID_taMemAllocAsync_ptsz = 192, + TAPTI_DRIVER_TRACE_CBID_FORCE_INT = 0x7fffffff, +} tapti_driver_api_trace_cbid; + diff --git a/third_party/sunrise/backend/include/tapti/tapti_result.h b/third_party/sunrise/backend/include/tapti/tapti_result.h new file mode 100755 index 000000000..645afddea --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_result.h @@ -0,0 +1,248 @@ +#ifndef _TAPTI_RESULT_HPP_ +#define _TAPTI_RESULT_HPP_ + +/** + * \brief TAPTI result codes. + * + * Error and result codes returned by TAPTI functions. + */ +typedef enum { + /** + * No error. + */ + TAPTI_SUCCESS = 0, + /** + * One or more of the parameters is invalid. + */ + TAPTI_ERROR_INVALID_PARAMETER = 1, + /** + * The device does not correspond to a valid TANG device. + */ + TAPTI_ERROR_INVALID_DEVICE = 2, + /** + * The context is NULL or not valid. + */ + TAPTI_ERROR_INVALID_CONTEXT = 3, + /** + * The event domain id is invalid. + */ + TAPTI_ERROR_INVALID_EVENT_DOMAIN_ID = 4, + /** + * The event id is invalid. + */ + TAPTI_ERROR_INVALID_EVENT_ID = 5, + /** + * The event name is invalid. + */ + TAPTI_ERROR_INVALID_EVENT_NAME = 6, + /** + * The current operation cannot be performed due to dependency on + * other factors. + */ + TAPTI_ERROR_INVALID_OPERATION = 7, + /** + * Unable to allocate enough memory to perform the requested + * operation. + */ + TAPTI_ERROR_OUT_OF_MEMORY = 8, + /** + * An error occurred on the performance monitoring hardware. + */ + TAPTI_ERROR_HARDWARE = 9, + /** + * The output buffer size is not sufficient to return all + * requested data. + */ + TAPTI_ERROR_PARAMETER_SIZE_NOT_SUFFICIENT = 10, + /** + * API is not implemented. + */ + TAPTI_ERROR_API_NOT_IMPLEMENTED = 11, + /** + * The maximum limit is reached. + */ + TAPTI_ERROR_MAX_LIMIT_REACHED = 12, + /** + * The object is not yet ready to perform the requested operation. + */ + TAPTI_ERROR_NOT_READY = 13, + /** + * The current operation is not compatible with the current state + * of the object + */ + TAPTI_ERROR_NOT_COMPATIBLE = 14, + /** + * TAPTI is unable to initialize its connection to the TANG + * driver. + */ + TAPTI_ERROR_NOT_INITIALIZED = 15, + /** + * The metric id is invalid. + */ + TAPTI_ERROR_INVALID_METRIC_ID = 16, + /** + * The metric name is invalid. + */ + TAPTI_ERROR_INVALID_METRIC_NAME = 17, + /** + * The queue is empty. + */ + TAPTI_ERROR_QUEUE_EMPTY = 18, + /** + * Invalid handle (internal?). + */ + TAPTI_ERROR_INVALID_HANDLE = 19, + /** + * Invalid stream. + */ + TAPTI_ERROR_INVALID_STREAM = 20, + /** + * Invalid kind. + */ + TAPTI_ERROR_INVALID_KIND = 21, + /** + * Invalid event value. + */ + TAPTI_ERROR_INVALID_EVENT_VALUE = 22, + /** + * TAPTI is disabled due to conflicts with other enabled profilers + */ + TAPTI_ERROR_DISABLED = 23, + /** + * Invalid module. + */ + TAPTI_ERROR_INVALID_MODULE = 24, + /** + * Invalid metric value. + */ + TAPTI_ERROR_INVALID_METRIC_VALUE = 25, + /** + * The performance monitoring hardware is in use by other client. + */ + TAPTI_ERROR_HARDWARE_BUSY = 26, + /** + * The attempted operation is not supported on the current + * system or device. + */ + TAPTI_ERROR_NOT_SUPPORTED = 27, + /** + * Unified memory profiling is not supported on the system. + * Potential reason could be unsupported OS or architecture. + */ + TAPTI_ERROR_UM_PROFILING_NOT_SUPPORTED = 28, + /** + * Unified memory profiling is not supported on the device + */ + TAPTI_ERROR_UM_PROFILING_NOT_SUPPORTED_ON_DEVICE = 29, + /** + * Unified memory profiling is not supported on a multi-GPU + * configuration without P2P support between any pair of devices + */ + TAPTI_ERROR_UM_PROFILING_NOT_SUPPORTED_ON_NON_P2P_DEVICES = 30, + /** + * Profiling on virtualized GPU is not supported. + */ + TAPTI_ERROR_VIRTUALIZED_DEVICE_NOT_SUPPORTED = 33, + /** + * User doesn't have sufficient privileges which are required to + * start the profiling session. + * One possible reason for this may be that the NVIDIA driver or your system + * administrator may have restricted access to the NVIDIA GPU performance counters. + * To learn how to resolve this issue and find more information, please visit + * https://developer.nvidia.com/TAPTI_ERROR_INSUFFICIENT_PRIVILEGES + */ + TAPTI_ERROR_INSUFFICIENT_PRIVILEGES = 35, + /** + * Legacy TAPTI Profiling API i.e. event API from the header TAPTI_events.h and + * metric API from the header TAPTI_metrics.h are not compatible with the + * Profiling API in the header TAPTI_profiler_target.h and Perfworks metrics API + * in the headers nvperf_host.h and nvperf_target.h. + */ + TAPTI_ERROR_OLD_PROFILER_API_INITIALIZED = 36, + /** + * Missing definition of the OpenACC API routine in the linked OpenACC library. + * + * One possible reason is that OpenACC library is linked statically in the + * user application, which might not have the definition of all the OpenACC + * API routines needed for the OpenACC profiling, as compiler might ignore + * definitions for the functions not used in the application. This issue + * can be mitigated by linking the OpenACC library dynamically. + */ + TAPTI_ERROR_OPENACC_UNDEFINED_ROUTINE = 37, + /** + * Legacy TAPTI Profiling API i.e. event API from the header TAPTI_events.h and + * metric API from the header TAPTI_metrics.h are not supported on devices with + * compute capability 7.5 and higher (i.e. Turing and later GPU architectures). + * These API will be deprecated in a future TANG release. These are replaced by + * Profiling API in the header TAPTI_profiler_target.h and Perfworks metrics API + * in the headers nvperf_host.h and nvperf_target.h. + */ + TAPTI_ERROR_LEGACY_PROFILER_NOT_SUPPORTED = 38, + /** + * TAPTI doesn't allow multiple callback subscribers. Only a single subscriber + * can be registered at a time. + * Same error code is used when application is launched using NVIDIA tools + * like nvprof, Visual Profiler, Nsight Systems, Nsight Compute, cuda-gdb and + * cuda-memcheck. + */ + TAPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED = 39, + /** + * Profiling on virtualized GPU is not allowed by hypervisor. + */ + TAPTI_ERROR_VIRTUALIZED_DEVICE_INSUFFICIENT_PRIVILEGES = 40, + /** + * Profiling and tracing are not allowed when confidential computing mode + * is enabled. + */ + TAPTI_ERROR_CONFIDENTIAL_COMPUTING_NOT_SUPPORTED = 41, + /** + * An unknown internal error has occurred. + */ + TAPTI_ERROR_UNKNOWN = 999, + TAPTI_ERROR_FORCE_INT = 0x7fffffff +} TAptiResult; + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#if defined(_MSC_VER) +#define TAPTI_DEPRECATED __declspec(deprecated) +#define TAPTI_API_EXPORT __declspec(dllexport) +#define TAPTI_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TAPTI_DEPRECATED __attribute__((deprecated)) +#define TAPTI_API_EXPORT __attribute__((visibility("default"))) +#define TAPTI_API_IMPORT __attribute__((visibility("default"))) +#else +#define TAPTI_DEPRECATED +#define TAPTI_API_EXPORT +#define TAPTI_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(tapti_shared_EXPORTS) +#define TAPTI_API TAPTI_API_EXPORT +#else +#define TAPTI_API TAPTI_API_IMPORT +#endif //! For user + +/** + * \brief Get the descriptive string for a TAptiResult. + * + * Return the descriptive string for a TAptiResult in \p *str. + * \note \b Thread-safety: this function is thread safe. + * + * \param result The result to get the string for + * \param str Returns the string + * + * \retval TAPTI_SUCCESS on success + * \retval TAPTI_ERROR_INVALID_PARAMETER if \p str is NULL or \p + * result is not a valid TAptiResult + */ +TAptiResult TAPTI_API taptiGetResultString(TAptiResult result, const char **str); + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif // _TAPTI_RESULT_HPP_ diff --git a/third_party/sunrise/backend/include/tapti/tapti_runtime_cbid.h b/third_party/sunrise/backend/include/tapti/tapti_runtime_cbid.h new file mode 100755 index 000000000..75cb5fa3b --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_runtime_cbid.h @@ -0,0 +1,147 @@ +// ************************************************************************* +// Definitions of indices for API functions, unique across entire API +// ************************************************************************* + +// This file is generated. + +typedef enum tapti_runtime_api_trace_cbid_enum { + TAPTI_RUNTIME_TRACE_CBID_INVALID = 0, + TAPTI_RUNTIME_TRACE_CBID_tangRuntimeGetVersion = 1, + TAPTI_RUNTIME_TRACE_CBID_tangDriverGetVersion = 2, + TAPTI_RUNTIME_TRACE_CBID_tangGetDeviceCount = 3, + TAPTI_RUNTIME_TRACE_CBID_tangGetDevice = 4, + TAPTI_RUNTIME_TRACE_CBID_tangSetDevice = 5, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceSynchronize = 6, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceReset = 7, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetByPCIBusId = 8, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetPCIBusId = 9, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceSetCacheConfig = 10, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetCacheConfig = 11, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceSetSharedMemConfig = 12, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetSharedMemConfig = 13, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetStreamPriorityRange = 14, + TAPTI_RUNTIME_TRACE_CBID_tangSetValidDevices = 15, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetLimit = 16, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceSetLimit = 17, + TAPTI_RUNTIME_TRACE_CBID_tangMalloc = 18, + TAPTI_RUNTIME_TRACE_CBID_tangMallocAsync = 19, + TAPTI_RUNTIME_TRACE_CBID_tangMallocAsync_ptsz = 20, + TAPTI_RUNTIME_TRACE_CBID_tangFree = 21, + TAPTI_RUNTIME_TRACE_CBID_tangFreeAsync = 22, + TAPTI_RUNTIME_TRACE_CBID_tangFreeAsync_ptsz = 23, + TAPTI_RUNTIME_TRACE_CBID_tangMallocHost = 24, + TAPTI_RUNTIME_TRACE_CBID_tangHostAlloc = 25, + TAPTI_RUNTIME_TRACE_CBID_tangHostGetDevicePointer = 26, + TAPTI_RUNTIME_TRACE_CBID_tangHostGetFlags = 27, + TAPTI_RUNTIME_TRACE_CBID_tangFreeHost = 28, + TAPTI_RUNTIME_TRACE_CBID_tangHostRegister = 29, + TAPTI_RUNTIME_TRACE_CBID_tangHostUnregister = 30, + TAPTI_RUNTIME_TRACE_CBID_tangMemGetInfo = 31, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpy = 32, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpy_ptds = 33, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyAsync = 34, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyAsync_ptsz = 35, + TAPTI_RUNTIME_TRACE_CBID_tangMemset = 36, + TAPTI_RUNTIME_TRACE_CBID_tangMemset_ptds = 37, + TAPTI_RUNTIME_TRACE_CBID_tangMemsetAsync = 38, + TAPTI_RUNTIME_TRACE_CBID_tangMemsetAsync_ptsz = 39, + TAPTI_RUNTIME_TRACE_CBID_tangGetSymbolAddress = 40, + TAPTI_RUNTIME_TRACE_CBID_tangGetSymbolSize = 41, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyFromSymbol = 42, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyFromSymbol_ptds = 43, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyFromSymbolAsync = 44, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyFromSymbolAsync_ptsz = 45, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyToSymbol = 46, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyToSymbol_ptds = 47, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyToSymbolAsync = 48, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyToSymbolAsync_ptsz = 49, + TAPTI_RUNTIME_TRACE_CBID_tangStreamCreate = 50, + TAPTI_RUNTIME_TRACE_CBID_tangStreamCreateWithFlags = 51, + TAPTI_RUNTIME_TRACE_CBID_tangStreamCreateWithPriority = 52, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetPriority = 53, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetPriority_ptsz = 54, + TAPTI_RUNTIME_TRACE_CBID_tangStreamAddCallback = 55, + TAPTI_RUNTIME_TRACE_CBID_tangStreamAddCallback_ptsz = 56, + TAPTI_RUNTIME_TRACE_CBID_tangLaunchHostFunc = 57, + TAPTI_RUNTIME_TRACE_CBID_tangLaunchHostFunc_ptsz = 58, + TAPTI_RUNTIME_TRACE_CBID_tangStreamDestroy = 59, + TAPTI_RUNTIME_TRACE_CBID_tangStreamSynchronize = 60, + TAPTI_RUNTIME_TRACE_CBID_tangStreamSynchronize_ptsz = 61, + TAPTI_RUNTIME_TRACE_CBID_tangStreamC2Ctransfers = 62, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetPeerPointer = 63, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetP2PAttribute = 64, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceCanAccessPeer = 65, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceEnablePeerAccess = 66, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceDisablePeerAccess = 67, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeer = 68, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeerAsync = 69, + TAPTI_RUNTIME_TRACE_CBID_tangEngineCollAssign = 70, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeer_v2 = 71, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeer_v2_ptds = 72, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeerAsync_v2 = 73, + TAPTI_RUNTIME_TRACE_CBID_tangMemcpyPeerAsync_v2_ptsz = 74, + TAPTI_RUNTIME_TRACE_CBID_tangStreamQuery = 75, + TAPTI_RUNTIME_TRACE_CBID_tangStreamQuery_ptsz = 76, + TAPTI_RUNTIME_TRACE_CBID_tangStreamWaitEvent = 77, + TAPTI_RUNTIME_TRACE_CBID_tangStreamWaitEvent_ptsz = 78, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetFlags = 79, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetFlags_ptsz = 80, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetId = 81, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetId_ptsz = 82, + TAPTI_RUNTIME_TRACE_CBID_tangEventCreateWithFlags = 83, + TAPTI_RUNTIME_TRACE_CBID_tangEventCreate = 84, + TAPTI_RUNTIME_TRACE_CBID_tangEventRecord = 85, + TAPTI_RUNTIME_TRACE_CBID_tangEventRecord_ptsz = 86, + TAPTI_RUNTIME_TRACE_CBID_tangEventRecordWithFlags = 87, + TAPTI_RUNTIME_TRACE_CBID_tangEventRecordWithFlags_ptsz = 88, + TAPTI_RUNTIME_TRACE_CBID_tangEventDestroy = 89, + TAPTI_RUNTIME_TRACE_CBID_tangEventSynchronize = 90, + TAPTI_RUNTIME_TRACE_CBID_tangEventSynchronizeWithFlags = 91, + TAPTI_RUNTIME_TRACE_CBID_tangEventElapsedTime = 92, + TAPTI_RUNTIME_TRACE_CBID_tangEventQuery = 93, + TAPTI_RUNTIME_TRACE_CBID_tangEventQueryTimestamp = 94, + TAPTI_RUNTIME_TRACE_CBID_tangDeviceGetAttribute = 95, + TAPTI_RUNTIME_TRACE_CBID_tangGetDeviceProperties = 96, + TAPTI_RUNTIME_TRACE_CBID_tangChooseDevice = 97, + TAPTI_RUNTIME_TRACE_CBID_tangFuncGetAttributes = 98, + TAPTI_RUNTIME_TRACE_CBID_tangFuncSetAttribute = 99, + TAPTI_RUNTIME_TRACE_CBID_tangFuncSetCacheConfig = 100, + TAPTI_RUNTIME_TRACE_CBID_tangFuncSetSharedMemConfig = 101, + TAPTI_RUNTIME_TRACE_CBID_tangSetDoubleForDevice = 102, + TAPTI_RUNTIME_TRACE_CBID_tangSetDoubleForHost = 103, + TAPTI_RUNTIME_TRACE_CBID_tangOccupancyMaxActiveBlocksPerMultiprocessor = 104, + TAPTI_RUNTIME_TRACE_CBID_tangOccupancyMaxActiveBlocksPerMultiprocessorWithFlags = 105, + TAPTI_RUNTIME_TRACE_CBID_tangPointerGetAttributes = 106, + TAPTI_RUNTIME_TRACE_CBID_tangStreamBeginCapture = 107, + TAPTI_RUNTIME_TRACE_CBID_tangStreamBeginCapture_ptsz = 108, + TAPTI_RUNTIME_TRACE_CBID_tangStreamEndCapture = 109, + TAPTI_RUNTIME_TRACE_CBID_tangStreamEndCapture_ptsz = 110, + TAPTI_RUNTIME_TRACE_CBID_tangStreamIsCapturing = 111, + TAPTI_RUNTIME_TRACE_CBID_tangStreamIsCapturing_ptsz = 112, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetCaptureInfo = 113, + TAPTI_RUNTIME_TRACE_CBID_tangStreamGetCaptureInfo_ptsz = 114, + TAPTI_RUNTIME_TRACE_CBID_tangGraphInstantiate = 115, + TAPTI_RUNTIME_TRACE_CBID_tangGraphLaunch = 116, + TAPTI_RUNTIME_TRACE_CBID_tangGraphLaunch_ptsz = 117, + TAPTI_RUNTIME_TRACE_CBID_tangGraphInstantiateWithFlags = 118, + TAPTI_RUNTIME_TRACE_CBID_tangGraphDestroy = 119, + TAPTI_RUNTIME_TRACE_CBID_tangGraphExecDestroy = 120, + TAPTI_RUNTIME_TRACE_CBID_tangGraphGetInfo = 121, + TAPTI_RUNTIME_TRACE_CBID_tangGraphCreate = 122, + TAPTI_RUNTIME_TRACE_CBID_tangGraphAddHostNode = 123, + TAPTI_RUNTIME_TRACE_CBID_tangGraphAddKernelNode = 124, + TAPTI_RUNTIME_TRACE_CBID_tangProfilerStart = 125, + TAPTI_RUNTIME_TRACE_CBID_tangProfilerStop = 126, + TAPTI_RUNTIME_TRACE_CBID_tangIpcGetMemHandle = 127, + TAPTI_RUNTIME_TRACE_CBID_tangIpcOpenMemHandle = 128, + TAPTI_RUNTIME_TRACE_CBID_tangIpcCloseMemHandle = 129, + TAPTI_RUNTIME_TRACE_CBID_tangIpcGetEventHandle = 130, + TAPTI_RUNTIME_TRACE_CBID_tangIpcOpenEventHandle = 131, + TAPTI_RUNTIME_TRACE_CBID_tangGetFuncBySymbol = 132, + TAPTI_RUNTIME_TRACE_CBID_tangGetExportTable = 133, + TAPTI_RUNTIME_TRACE_CBID_tangLaunchKernel = 134, + TAPTI_RUNTIME_TRACE_CBID_tangLaunchKernel_ptsz = 135, + TAPTI_RUNTIME_TRACE_CBID_tangThreadExchangeStreamCaptureMode = 136, + TAPTI_RUNTIME_TRACE_CBID_FORCE_INT = 0x7fffffff, +} tapti_runtime_api_trace_cbid; + diff --git a/third_party/sunrise/backend/include/tapti/tapti_version.h b/third_party/sunrise/backend/include/tapti/tapti_version.h new file mode 100755 index 000000000..f548f07ec --- /dev/null +++ b/third_party/sunrise/backend/include/tapti/tapti_version.h @@ -0,0 +1,39 @@ +#ifndef _TAPTI_VERSION_ +#define _TAPTI_VERSION_ + +#include +#include "tapti_result.h" + +#define TAPTI_API_VERSION 1 + +#ifdef __cplusplus +extern "C" { +#endif //! __cplusplus + +#if defined(_MSC_VER) +#define TAPTI_DEPRECATED __declspec(deprecated) +#define TAPTI_API_EXPORT __declspec(dllexport) +#define TAPTI_API_IMPORT __declspec(dllimport) +#elif defined(__GNUC__) || defined(__clang__) +#define TAPTI_DEPRECATED __attribute__((deprecated)) +#define TAPTI_API_EXPORT __attribute__((visibility("default"))) +#define TAPTI_API_IMPORT __attribute__((visibility("default"))) +#else +#define TAPTI_DEPRECATED +#define TAPTI_API_EXPORT +#define TAPTI_API_IMPORT +#endif //! UNKNOWN COMPILER + +#if defined(tapti_shared_EXPORTS) +#define TAPTI_API TAPTI_API_EXPORT +#else +#define TAPTI_API TAPTI_API_IMPORT +#endif //! For user + +TAptiResult TAPTI_API taptiGetVersion(uint32_t *version); + +#ifdef __cplusplus +} +#endif //! __cplusplus + +#endif // _TAPTI_VERSION_ diff --git a/third_party/sunrise/backend/spec/__init__.py b/third_party/sunrise/backend/spec/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/backend/spec/include/flagtree_spec.h b/third_party/sunrise/backend/spec/include/flagtree_spec.h new file mode 100644 index 000000000..aacb7fb6c --- /dev/null +++ b/third_party/sunrise/backend/spec/include/flagtree_spec.h @@ -0,0 +1,12 @@ +#ifndef SUNRISE_FLAGTREE_SPEC_H +#define SUNRISE_FLAGTREE_SPEC_H + +#include "triton/Dialect/TritonGPU/IR/sunrise_Dialect.h" +#include "triton/Dialect/TritonGPU/IR/sunrise_LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/sunrise_Coalesce.h" +#include "triton/Dialect/TritonGPU/Transforms/Pipeliner/sunrise_PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/sunrise_Prefetch.h" +#include "triton/Dialect/TritonGPU/Transforms/sunrise_RemoveLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/sunrise_Utility.h" + +#endif // SUNRISE_FLAGTREE_SPEC_H diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 000000000..e714d8f85 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1447 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + ]; +} + +def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; + + +class TritonGPU_Attr traits = [], + Dialect dialect = TritonGPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + }]; +} + +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); + + let description = [{ +Describes how blocks are distributed among the cooperate thread arrays (aka +CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group +cluster). CGAs were introduced in Hopper (sm90). + +The tensor is divided up into CTASplitNum pieces, which are distributed among +the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape +`tensor_shape / CTASplitNum`. + +Example 0: The tensor shape is [64, 128] and, there are two CTAs, each +processing half the tensor [64, 64]. Then CTAsPerCGA = [1, 2] and +CTASplitNum = [1, 2]. + +Example 1: The tensor shape is [64, 128] and, there are two CTAs, both +processing the complete tensor [64, 128]. This happens when multicast is +enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. + +Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The +CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are +different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, +CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = +[SplitM, SplitN] which means no multicast. + +Currently programs with multiple CTAs per CGA are an experimental feature in +Triton, not enabled by default. + +You can leave off the CTALayout properties in the textual IR and Triton will +fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In +addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to +[n-1,...,0] (it doesn't matter in this case). + }]; + + // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is + // [1...1]. The CTAOrder doesn't matter in this case. + // + // This is a little weird because if you write textual IR with a one order and + // then print it back out, you might get a different order. But it seems this + // is the best way to canonicalize an attribute in MLIR. + let builders = [ + AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, + "ArrayRef":$CTASplitNum, + "ArrayRef":$CTAOrder), [{ + if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { + SmallVector order; + for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) + order.push_back(i); + return $_get(context, CTAsPerCGA, CTASplitNum, order); + } + return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + }]>, + ]; + + let extraClassDeclaration = [{ + static CTALayoutAttr getDefault(MLIRContext *context, int rank) { + SmallVector CTAsPerCGA(rank, 1); + SmallVector CTASplitNum(rank, 1); + SmallVector CTAOrder; + for (int i = rank - 1; i >= 0; --i) + CTAOrder.push_back(i); + return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + } + unsigned getRank() const { + return getCTAOrder().size(); + } + }]; + + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + + +def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let description = [{ + Common trait for all TTGIR layouts. + }]; + let methods = [ + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + ]; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ + Common trait describing shared memory. + }]; + let methods = [ + InterfaceMethod<"Return the default alignment for the layout.", + "int32_t", + "getAlignment">, + ]; +} + +def SwizzledSharedEncodingAttr : + TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "swizzled_shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different cuda threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout + ); + + let builders = [ + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SwizzledSharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin MFMA ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + return mfmaEnc.composeSharedLayoutForOperand( + CTALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(), + typeWidthInBit, needTrans); + } + + // ---- begin WMMA ---- + if (mlir::isa(dotOpEnc.getParent())) { + if (dotOpEnc.getOpIdx() == 0) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; + int maxPhase = 16 / perPhase; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + // ---- begin SunriseMMA ---- + auto sunriseMmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + if(sunriseMmaEnc) { + // return get(context, 1, 1, 1, order, CTALayout); + // 4bit elemSize的情况需要在确认一下结果的正确性??? + // 确认一下不使用MMA的条件??? + int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + int vec = 1, perPhase = 1, maxPhase = 1; + + perPhase = 128 / (shapePerCTA[1] * typeWidthInBit / 8); // 32个bank一共128B,除以tensor一行的字节数 + perPhase = std::max(perPhase, 1); + if(opIdx == 0) { + vec = 4 * (32 / typeWidthInBit); // A一个warp对应的一行永远是4x4B + maxPhase = 8 / perPhase; + } else { + vec = 8; // B一个warp对应的一行永远是8个元素 + maxPhase = 4 * (32 / typeWidthInBit) / perPhase; + } + // context, vec, perPhase, maxPhase, order, CTALayout + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + auto blockEnc = mlir::dyn_cast(dotOpEnc.getParent()); + if (blockEnc) { + // 针对fp16类形使用FMA的情况,可生成fp16x2类形的fma指令 + auto sizePerThread = blockEnc.getSizePerThread(); + unsigned vec = sizePerThread[sizePerThread.size() - 1]; + return get(context, vec, 1, 1, order, CTALayout); + } + + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CTALayout); + + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { + return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CTALayout, typeWidthInBit, needTrans); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + // NVIDIA constructor! + // TODO(lezcano): We should totally get rid of all these constructors... + AttrBuilder<(ins "int":$opIdx, + "unsigned":$kWidth, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$bitwidth, + "bool":$needTrans), [{ + int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]]; + // Elems necessary to cover all the banks divided by the inner dimension + // This packs a few rows together for small K + int perPhase = std::max(1024 / (bitwidth * K), 1); + + int mmaStride = 8; + int vec = 4 * kWidth; + // needsTrans is equiv. to flipping the opIdx + if (needTrans) + std::swap(vec, mmaStride); + assert(opIdx == 0 || opIdx == 1); + int rank = order.size(); + int kDim = opIdx == 0 ? rank-1 : rank-2; + if (order[0] != kDim) + std::swap(vec, mmaStride); + // Count how many vec elements are needed to cover all the banks + int maxPhase = std::max(std::min(mmaStride, 1024 / (vec * bitwidth)), 1); + // Account for the row packing from perPhase: mmaStride / perPhase + maxPhase = std::max(maxPhase / perPhase, 1); + return get(context, vec, perPhase, maxPhase, order, CTALayout); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + unsigned getRank() const { return getCTAOrder().size(); } + int32_t getAlignment() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + }]; + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def NVMMASharedEncodingAttr : + TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "nvmma_shared"; + + let description = [{ + Represent blocked shared memory matching MMAv3/MMAv5 shared memory input. + This is meant to represent 2d tiled blocked layout. + The full layout representation is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout + When the memdesc has more than 2 dimensions the tiling is applied to 8 rows even if the first outer dimension is smaller than 8. + In this case `transposed` means that the contiguous dimension is the most outer dimension of the memdesc. + }]; + + + // fp4Padded: Indicates that this encoding represents a mixed-precision fp4 operand in MMAv5 scaled dot, which needs + // to be in the special padded layout as described in https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory + let parameters = ( + ins + "unsigned":$swizzlingByteWidth, + "bool":$transposed, + "unsigned":$elementBitWidth, + "bool":$fp4Padded, + "CTALayoutAttr":$CTALayout + ); + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool": $fp4Padded), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + int32_t swizzlingByteWidth = 0; + unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int packingFactor = fp4Padded ? 2 : 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * packingFactor * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + swizzlingByteWidth = 128; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + swizzlingByteWidth = 64; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + swizzlingByteWidth = 32; + } else { + swizzlingByteWidth = 0; + } + int flattenOutterDim = 1; + for (int i = 1; i < shapePerCTA.size(); i++) { + flattenOutterDim *= shapePerCTA[order[i]]; + } + if (shapePerCTA.size() < 2 || flattenOutterDim < 8) { + swizzlingByteWidth = 0; + } + bool transposed = order[0] == 0; + return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + unsigned getRank() const { return getCTAOrder().size(); } + int32_t getAlignment() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + int getPerPhase() const; + int getMaxPhase() const; + int getVec() const; + }]; + let hasCustomAssemblyFormat = 1; +} + +def AMDRotatingSharedEncodingAttr : + TritonGPU_Attr<"AMDRotatingSharedEncoding", "amd_rotating_shared_encoding", + [SharedEncodingTrait, LayoutEncodingTrait]> { + let mnemonic = "amd_rotating_shared"; + + let description = [{ +This shared encoding is similar to SwizzledSharedEncodingAttr, but instead of +repeating swizzling pattern every `maxPhase*perPhase` rows of the memory object, +called a block, this layout changes swizzling pattern `maxPhase` times, then +repeats the pattern. The name "rotating" comes from the fact that first tensor +element of each block is swizzled with different phase, which is equal to +current block number: 0, 1, 2.. maxPhase-1, 0, 1, 2 ... + +This layout is used to reduce bank conflicts in cases where shared memory writes +and reads are performed on layouts with different order. It's meant for hardware +without native shared memory tranpose support. + +Swizzling pattern affects only 2 fastest dimensions of a tensor. +In the following text these two dimensions are called row and column: +- row is a fastest dimension +- column is a second fastest dimension + +Elements in a row dimension are stored in memory contiguously. + +If a matrix of size [128x64] is stored in this shared layout with order [1, 0], +dim 1 (64) will be stored contiguously and called row, dim 0 (128) is will be +called column. If order of shared layout is [0, 1], dim 0 (128) is stored +contiguously becomes a row, dim 1 (64) becomes a column. + +Swizzling pattern is following: + +Let's consider an element with logical coordinates = (inRowId, inColId). +For simplicity, we do not vectorize memory in examples, +i.e. vec == 1 and layout swizzles inidividual elements. +For vec != 1 example, take a look at SwizzledSharedEncodingAttr documentation. + +Swizzled coordinates within memory object are (outRowId, outColId): + + outRowId = inRowId + phase = (inRowId / perPhase) % maxPhase + blockNo = (inRowId / (perPhase * maxPhase)) % maxPhase + combinedPhase = phase ^ blockNo + outColId = inColId ^ combinedPhase + +Actual offset in memory could be computed with following function: + +memmory_offset = (outColId + outRowId * num_of_element_in_row) * sizeof(element) + + +Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1): + + #shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1) + 2 [ 9, 8, 11, 10], // phase = 0 blockNo = 1 (xor with 1) + 3 [12, 13, 14, 15] // phase = 1 blockNo = 1 (xor with 0) + 4 [16, 17, 18, 19], // phase = 0 blockNo = 0 (xor with 0) + 5 [21, 20, 23, 22], // phase = 1 blockNo = 0 (xor with 1) + 6 [25, 24, 27, 26], // phase = 0 blockNo = 1 (xor with 1) + 7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0) + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 4, 5, 6, 7], // phase = 0 blockNo = 0 (xor with 0) + 2 [ 9, 8, 11, 10], // phase = 1 blockNo = 0 (xor with 1) + 3 [13, 12, 15, 14] // phase = 1 blockNo = 0 (xor with 1) + 4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1) + 5 [21, 20, 23, 22], // phase = 0 blockNo = 1 (xor with 1) + 6 [24, 25, 26, 27], // phase = 1 blockNo = 1 (xor with 0) + 7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0) + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + row elements + 0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0) + 1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1) + 2 [10, 11, 8, 9], // phase = 2 blockNo = 0 (xor with 2) + 3 [15, 14, 13, 12] // phase = 3 blockNo = 0 (xor with 3) + 4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1) + 5 [20, 21, 22, 23], // phase = 1 blockNo = 1 (xor with 0) + 6 [27, 26, 25, 24], // phase = 2 blockNo = 1 (xor with 3) + 7 [30, 31, 28, 29] // phase = 3 blockNo = 1 (xor with 2) + }]; + + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + unsigned getRank() const { return getCTAOrder().size(); } + int32_t getAlignment() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + }]; + let hasCustomAssemblyFormat = 1; +} + + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +We call each individual tile "rep". + }]; + + let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrder">, + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape); + }]>, + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$shape), + /*defaultImplementation=*/[{ + return toLinearEncoding($_self, shape).getElemsPerThread(shape); + }]>, + InterfaceMethod<"Convert to LinearLayout.", + "LinearLayout", + "toLinearLayout", + (ins "ArrayRef":$shape)>, + ]; +} + +class DistributedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + unsigned getRank() const { return getCTAOrder().size(); } + // Implemented in subclasses + SmallVector getRepOrder() const; + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + + LinearLayout toLinearLayout(ArrayRef shape) const; + }]; +} + +//===----------------------------------------------------------------------===// +// Linear Layout Encoding +//===----------------------------------------------------------------------===// + +def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", + "linear layout"> { + let cppAccessorType = "const LinearLayout &"; +} + +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { + let mnemonic = "linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = (ins LinearLayoutParam:$linearLayout); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + // Generic distributed encoding methods + unsigned getTotalElemsPerThread(ArrayRef shape) const; + SmallVector getElemsPerThread(ArrayRef shape) const; + + SmallVector getContig(const char *, SmallVector) const; + SmallVector getContigPerThread() const; + SmallVector getContigPerWarp() const; + SmallVector getOrder() const; + SmallVector getWarpOrder() const; + SmallVector getThreadOrder() const; + + + // Generalizes get{Warp,Thread,CTA}Order to linear layouts. + // Returns the order of the dimensions `dimName` of the layout. + // If more than dimension is of size one, it uses defaultOrder to determine + // the order of the dimensions of size one. + SmallVector orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const; + + // Generalizes getThreadsPerWarp, getWarpsPerCTA, getCTAsPerCGA to linear layouts. + // Returns the bases of the dimensions `dimName` of the layout. + // If skipBroadcast is false, we count a base zero + SmallVector basesPerDim(StringAttr dimName, + bool skipBroadcast = true) const; + SmallVector getThreadsPerWarp() const; + SmallVector getWarpsPerCTA() const; + + // [FIXME LL] Supports legacy behaviour. We should remove these functions + SmallVector getShapePerCTATile() const; + SmallVector getSizePerThread() const; + }]; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#ttg.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} + CTASplitNum = {2, 2} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread, + ArrayRefParameter<"unsigned">:$threadsPerWarp, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CTALayout is optional in the textual IR. If omitted, we infer it to be a + // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], + // CTAOrder=[n,n-1,...,0]). + "CTALayoutAttr":$CTALayout + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, std::max(1, shapePerCTA[i] / sizePerThread[i])); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, std::max(1, shape[i] / sizePerThread[i])); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// + +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `versionMajor` and `versionMinor` indicates the GPU architecture: + - 1.0: gfx908, i.e. CDNA1 + - 2.0: gfx90a: i.e. CDNA2 + - 3.0: gfx942: CDNA3 + - 4.0: gfx950: CDNA4 +- `warpsPerCTA` indicates the warp layout in the block. +- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. +- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + warp 0 warp 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> warp 0 warp 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + warp 1 warp 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] +}]; + + let parameters = ( + ins + "unsigned": $versionMajor, + "unsigned": $versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "unsigned":$MDim, + "unsigned":$NDim, + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + + // Returns a swizzled shared layout matching this MFMA layout for the + // dot operand at the given |operandIdx| with |operandShape|. + SwizzledSharedEncodingAttr composeSharedLayoutForOperand( + CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned vectorSize, + unsigned elemBitWidth, bool needTrans) const; + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An encoding for tensors that have been produced by WMMA matrix core instructions, +available on AMD Radeon GPUs of RDNA architectures. +- A `version` parameter specifies instruction version to lower in. The data + distribution within one warp is also depends on it. Following architectures are + supported: + - 1: gfx11 + - 2: gfx12 +- A `warpsPerCTA` parameter characterizes data distribution between warps. + An important limitation of WMMA for layout is a shape for tiles processed + by a single warp. It is [16, 16]. + This encoding assumes specific access to matrix elements by threads. + +Example: +Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2]. +Matrix elements represent which lane owns the element. Currently only wave32 mode +is supported. + +// ----------------------------------- version = 1 ----------------------------------- // + +Row | warp 0 warp 1 + |/-------------------^-------------------\ /-------------------^-------------------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +2 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +3 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +14 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + | warp 2 warp 3 +16 |/-------------------^-------------------\ /-------------------^-------------------\ +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +18 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +19 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +20 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +30 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2, isTransposed = false ------------------------ // + +Row | warp 0 warp 1 + |/--------^---------\ /---------^--------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +6 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +7 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +8 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +9 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +14 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + | + | warp 2 warp 3 + |/--------^---------\ /---------^--------\ +16 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +22 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +23 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +24 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +25 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +30 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ------------------------ version = 2, isTransposed = true ------------------------ // + + | warp 0 warp 1 + |/----------------^----------------\ /-------^-------\ +Col>| 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 ... 32 +Row | +0 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +1 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +14 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +15 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + | + | warp 2 warp 3 + |/----------------^----------------\ /-------^-------\ +16 |[0 0 0 0 0 0 0 0 16 ... 16] [0 0 0 ... 16] +17 |[1 1 1 1 1 1 1 1 17 ... 17] [1 1 1 ... 17] +.. | ... ... +30 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30] +31 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31] + }]; + + let parameters = ( + ins + "unsigned": $version, + "bool":$isTransposed, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CTALayoutAttr":$CTALayout + ); + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getElemsPerInstrForOperands() const; + SmallVector getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + unsigned getKWidthForOperands() const; + static SmallVector getMNKDimPerInstr(); + }]; +} + +def SunriseMmaEncodingAttr : DistributedEncoding<"SunriseMmaEncoding", "sunrise_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "sunrise_mma"; + + let description = [{ + TODO ... + }]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + // "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "SunriseMmaEncodingAttr::TMMAOutLayout":$outLayout, + "unsigned":$inputElemBitWidth, + "unsigned":$outputElemBitWidth + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + enum class TMMAOutLayout : unsigned { // 只用于mma中c, d的布局 + NotAvailable, + Row_2B, Col_2B, + ARow_4B_8x4, BRow_4B_4x8 + }; + + SmallVector getRepForOperand(ArrayRef operandShape, Type elemType, int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + SmallVector getInstrShapeForOperand(unsigned opIdx) const; + SmallVector getShapePerCTATileForOperand(unsigned opIdx) const; + SmallVector getShapePerCTATile() const; + }]; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int kWidth, + int opIdx) const; + SmallVector getRepOrderForOperand(int opIdx) const; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + "DistributedEncodingTrait":$parent + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + +# WGMMA Notes +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. + +The encoded tensor consists of operand A for possibly multiple wgmma instructions. +For each wgmma, each warp in a warp group feeds a single "warp matrix" +Each warp matrix consists of 2x2 "quads". +Each thread holds several elements in each quad. Right before a wgmma, +the sum of bitwidth of +the elements in each quad should add up to 32. + +These values are stored unrolled in `elements`. +The ordering of dimensions is as follows by convention: +batch (only 1 batch for Hopper currently) +matM (m-index of the "warp matrix") +matK (k-index of the "warp matrix") +quadK (k-index of the "quad" in the core matrix) +quadM (m-index of the "quad" in the core matrix) +vecIdx (index of the element in the quad; this is always along the k-dim) + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) + return $_get(context, opIdx, parent, 0); + // For MMAV2 and V3 + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration; +} + +def TTG_SharedMemorySpace : AttrDef { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory. + }]; +} + +#endif diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_Dialect.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_Dialect.h new file mode 100644 index 000000000..000a24884 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_Dialect.h @@ -0,0 +1,10 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define SUNRISE_TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#define FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_functions +#define FLAGTREE_SPEC_Dialect_TritonGPU_IR_Dialect_cpp + +#define FLAGTREE_SPEC_BackendMmaEncodingAttr \ + ::mlir::triton::gpu::SunriseMmaEncodingAttr + +#endif // SUNRISE_TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_LinearLayoutConversions.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_LinearLayoutConversions.h new file mode 100644 index 000000000..7c09f3703 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/IR/sunrise_LinearLayoutConversions.h @@ -0,0 +1,10 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H +#define SUNRISE_TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H + +#define FLAGTREE_SPEC_Triton_Dialect_TritonGPU_IR_sunrise_LinearLayoutConversion + +#define FLAGTREE_SPEC_LinearLayoutConversions_toLinearLayout +#define FLAGTREE_SPEC_LinearLayoutConversions_sunrisemmaDotOperandToLinearLayout +#define FLAGTREE_SPEC_LinearLayoutConversions_SunriseMmaEncodingAttr_toLinearLayout + +#endif//SUNRISE_TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/Pipeliner/sunrise_PipeliningUtility.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/Pipeliner/sunrise_PipeliningUtility.h new file mode 100644 index 000000000..1c8db1e0f --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/Pipeliner/sunrise_PipeliningUtility.h @@ -0,0 +1,8 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELININGUTILITY_H_ +#define SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELININGUTILITY_H_ + +#define FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility + +#define FLAGTREE_SPEC_Dialect_TritonGPU_Transforms_PipeliningUtility_predicateOp + +#endif // SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELININGUTILITY_H_ diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Coalesce.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Coalesce.h new file mode 100644 index 000000000..f2729f6a3 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Coalesce.h @@ -0,0 +1,6 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCE_H +#define SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCE_H + +#define FLAGTREE_SPEC_Triton_Dialect_TritonGPU_Transforms_Sunrise_Coalesce + +#endif//SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCE_H diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Prefetch.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Prefetch.h new file mode 100644 index 000000000..c38994ee0 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Prefetch.h @@ -0,0 +1,6 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORM_PREFETCH_H_ +#define SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORM_PREFETCH_H_ + +#define FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Prefetch + +#endif//SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORM_PREFETCH_H_ diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_RemoveLayoutConversions.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_RemoveLayoutConversions.h new file mode 100644 index 000000000..e4a62d3e5 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_RemoveLayoutConversions.h @@ -0,0 +1,8 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_REMOVELAYOUTCONVERSIONS_H +#define SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_REMOVELAYOUTCONVERSIONS_H + +#define FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_RemoveLayoutConversion + +#define FLAGTREE_SPEC_LayoutPropagation_propagateToUsers + +#endif//SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_REMOVELAYOUTCONVERSIONS_H diff --git a/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Utility.h b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Utility.h new file mode 100644 index 000000000..ddb7729b4 --- /dev/null +++ b/third_party/sunrise/backend/spec/include/triton/Dialect/TritonGPU/Transforms/sunrise_Utility.h @@ -0,0 +1,8 @@ +#ifndef SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H +#define SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H + +#define FLAGTREE_SPEC_triton_Dialect_TritonGPU_Transforms_sunrise_Utility + +#define FLAGTREE_SPEC_getNumElementsPerThread + +#endif//SUNRISE_TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H diff --git a/third_party/sunrise/backend/spec/lib/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/CMakeLists.txt new file mode 100644 index 000000000..629c08af6 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) diff --git a/third_party/sunrise/backend/spec/lib/Conversion/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..6fda78900 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +#add_subdirectory(TritonToTritonGPU) +#add_subdirectory(TritonGPUToLLVM) diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 000000000..92de0a913 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,49 @@ +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +namespace { +struct AllocateSharedMemory + : public mlir::triton::gpu::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + ModuleAllocation allocation(mod); + + mod.walk([&](FunctionOpInterface funcOp) { + auto *funcAllocation = allocation.getFuncData(funcOp); + funcOp.walk([&](Operation *op) { + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + return WalkResult::skip(); + }); + mod->setAttr("ttg.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); + } +}; +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp new file mode 100644 index 000000000..c90aebf28 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp @@ -0,0 +1,200 @@ +#include "mlir/IR/BuiltinOps.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUALLOCATEWARPGROUPS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Given a `ttg.warp_specialize` with a certain number of existing warps, pad it +// with extra warps until it has the same number of full warp groups as the +// largest partitioning. This ensures that all threads can be present to +// surrender registers. +static void padToMaxWarpGroups(WarpSpecializeOp op, int numExtraWarpGroups) { + int numExtraWarps = op.getTotalPartitionWarps(); + int warpsToAdd = numExtraWarpGroups * 4 - numExtraWarps; + assert(warpsToAdd >= 0); + + // Fill it with powers of 2. + SmallVector paddingPartitionSizes; + while (warpsToAdd > 0) { + int paddingSize = llvm::NextPowerOf2(warpsToAdd) / 2; + paddingPartitionSizes.push_back(paddingSize); + warpsToAdd -= paddingSize; + } + + auto partitions = cast( + op.getPartitionOpHolder().front().front()); + OperationState state(partitions.getLoc(), partitions.getOperationName()); + for (Region *region : partitions.getRegions()) + state.addRegion()->takeBody(*region); + + SmallVector partitionNumWarps(op.getPartitionNumWarps()); + for (int paddingSize : paddingPartitionSizes) { + partitionNumWarps.push_back(paddingSize); + + Block &body = state.addRegion()->emplaceBlock(); + for (Value capture : op.getExplicitCaptures()) + body.addArgument(capture.getType(), capture.getLoc()); + OpBuilder b(op.getContext()); + b.setInsertionPointToStart(&body); + b.create(op.getLoc()); + } + op.setPartitionNumWarps(partitionNumWarps); + + // Set the requested registers to low for the padded partitions that do + // nothing. + if (auto reqRegs = op.getRequestedRegisters()) { + SmallVector newReqRegs(*reqRegs); + newReqRegs.append(paddingPartitionSizes.size(), 16); + op.setRequestedRegisters(newReqRegs); + } + + OpBuilder b(partitions); + b.create(state); + partitions.erase(); +} + +namespace { +struct AllocateWarpGroups + : public mlir::triton::gpu::impl::TritonGPUAllocateWarpGroupsBase< + AllocateWarpGroups> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First determine the maximum number of extra warps. + int maxExtraWarps = 0; + mod.walk([&](WarpSpecializeOp op) { + maxExtraWarps = std::max(maxExtraWarps, op.getTotalPartitionWarps()); + }); + + // Round this up to the nearest warpgroup (multiple of 4) and then pad each + // `ttg.warp_specialize` to the nearest warpgroup. + int numExtraWarpGroups = llvm::divideCeil(maxExtraWarps, 4); + mod.walk([&](WarpSpecializeOp op) { + padToMaxWarpGroups(op, numExtraWarpGroups); + }); + + // Determine the maximum number of registers per thread. This may have + // been set by the user. + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int baseNumWarps = lookupNumWarps(mod); + int maxnreg; + if (auto maxnregAttr = + mod->getAttrOfType(AttrMaxRegistersName)) { + maxnreg = maxnregAttr.getInt(); + } else { + // Assume the user wants to use all 64K registers. + maxnreg = (64 * 1024) / (baseNumWarps + numExtraWarpGroups * 4) / + threadsPerWarp; + maxnreg = maxnreg / 8 * 8; + } + + struct WarpGroupInfo { + SmallVector partitions; + int maxRequestedRegs = 0; + unsigned numWarps = 0; + }; + struct WarpGroupPartition { + int startId; + Region *partition; + int32_t estRegs; + int numWarps; + }; + + // Compute the total number of warps required at any given time. + mod.walk([&](WarpSpecializeOp op) { + ArrayRef arr = op.getPartitionNumWarps(); + + // Allocate the start IDs such that the largest warpgroups have lower + // starting warp IDs. + // FIXME: Handle aligning warp group IDs to 4 for TMEM. + SmallVector> idxAndSize; + for (auto [i, size] : llvm::enumerate(arr)) + idxAndSize.emplace_back(i, size); + llvm::sort(idxAndSize, + [&](auto lhs, auto rhs) { return lhs.second > rhs.second; }); + + SmallVector startIds(arr.size()); + int startId = baseNumWarps; + for (auto [i, size] : idxAndSize) { + startIds[i] = startId; + startId += size; + } + op.setWarpGroupStartIds(startIds); + + // Require that an estimate has been set and that we have even warpgroups. + auto regsAttr = op.getRequestedRegisters(); + if (!regsAttr || op.getTotalPartitionWarps() % 4 != 0) + return; + + // Group the partitions into warpgroups. + SmallVector orderedPartitions; + for (auto [startId, partition, estRegs, numWarps] : + llvm::zip(startIds, op.getPartitionRegions(), *regsAttr, arr)) + orderedPartitions.push_back({startId, partition, estRegs, numWarps}); + llvm::sort(orderedPartitions, + [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId; }); + + // Iterate over the partitions and assign them to warp groups. Determine + // the maximum number of requested registers per warp group. + SmallVector warpGroups; + for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) { + if (startId % 4 == 0) { + warpGroups.push_back(WarpGroupInfo{}); + } + warpGroups.back().partitions.push_back(partition); + // Round up the nearest multiple of 8. + int estRegsCeil8 = llvm::divideCeil(estRegs, 8) * 8; + warpGroups.back().maxRequestedRegs = + std::max(warpGroups.back().maxRequestedRegs, estRegsCeil8); + warpGroups.back().numWarps += numWarps; + } + + // Compute the register deficit over the partition warp groups. + int registerBudget = maxnreg * baseNumWarps * threadsPerWarp; + for (const WarpGroupInfo &wg : warpGroups) { + assert(wg.numWarps % 4 == 0); + registerBudget += + (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp; + } + if (registerBudget <= 0) + return; + + // Determine the number of extra registers that we can distribute to the + // default warp group. + int leftover = registerBudget / (baseNumWarps * threadsPerWarp); + // Round down to the nearest multiple of 8. + leftover = leftover / 8 * 8; + if (leftover < 24) + return; // too few registers + + // Generate setmaxnreg in each partition according to its warp group. + SmallVector maxnregsPerPartition(1 + arr.size()); + for (const WarpGroupInfo &wg : warpGroups) { + for (Region *region : wg.partitions) { + maxnregsPerPartition[1 + region->getRegionNumber()] = + wg.maxRequestedRegs; + } + } + // Set the register usage for the default warp group. + maxnregsPerPartition.front() = leftover; + op.setActualRegisters(maxnregsPerPartition); + + // Set the initial max number of registers. This is needed for PTXAS to + // cooperate. + mod->setAttr(AttrMaxRegistersName, + Builder(op.getContext()).getI32IntegerAttr(maxnreg)); + }); + + Builder b(&getContext()); + mod->setAttr("ttg.total-num-warps", + b.getI32IntegerAttr(baseNumWarps + numExtraWarpGroups * 4)); + } +}; +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 000000000..1a5e0809b --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,103 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = b.or_( + condition, + b.icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + b.barrier(); + } + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + ConversionPatternRewriter &rewriter) const { + + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + rewriter.create(loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + rewriter.setInsertionPointToStart(thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..f6bc60136 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,39 @@ +add_triton_library(FlagTree_sunrise_TritonGPUToLLVM + DotOpToLLVM/FMA.cpp + DotOpToLLVM/FMADotUtility.cpp + AllocateSharedMemory.cpp + AllocateWarpGroups.cpp + AssertOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + FuncOpToLLVM.cpp + GatherOpToLLVM.cpp + GlobalScratchMemoryAllocation.cpp + HistogramOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + MemoryOpToLLVM.cpp + PrintOpToLLVM.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + SPMDOpToLLVM.cpp + TypeConverter.cpp + Utility.cpp + ViewOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000..6ad3e7978 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,162 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = b.insert_val(packedResultsTy, packedResults, + it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset") || + !callOp->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + } else { + auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp); + promotedOperands.push_back(base); + } + + auto opOffsetAttr = callOp->getAttrOfType( + "ttg.global_scratch_memory_offset"); + Value opOffsetVal; + if (opOffsetAttr) { + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + opOffsetVal = b.i32_val(opOffset); + } + + promotedOperands.push_back(LLVM::getGlobalScratchPtr( + loc, rewriter, targetInfo, caller, opOffsetVal)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..473f79170 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,463 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton::gpu; + +struct ConvertLayoutOpUsingLinearLayoutsConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + + // Set benefit to 2 so that this pattern applies before other convert-layout + // conversions. TODO(jlebar): Eventually we want this to be the only pattern. + explicit ConvertLayoutOpUsingLinearLayoutsConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); + LinearLayout srcLayout = + toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + + assert(to_vector(conversion.getInDimNames()) == + to_vector(conversion.getOutDimNames())); + auto dims = conversion.getInDimNames(); + if (llvm::is_contained(dims, kBlock)) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, kWarp)) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kLane)) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory + if (auto decomposedCvt = + getWarpLayoutConvertDecomposition(srcTy, dstTy)) { + transferWithinWarp(op, *decomposedCvt, adaptor, rewriter); + return success(); + } + // TODO: Since data is only transferred within a warp over shared memory, + // we should use `bar.warp.sync` instead of `barrier`, which will improve + // latency when warps issue barriers on different cycles. + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, kRegister)) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, conversion, adaptor, rewriter); + } else { + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } + + LogicalResult + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + StringAttr kRegister = str_attr("register"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(conversion.getInDimSize(kRegister)); + for (int i = 0; i < outVals.size(); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult transferWithinBlock(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + assert(cvtNeedsSharedMemory(srcTy, dstTy)); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(!inVals.empty()); + + // We munge the input values by converting i (n<8) elements to i8 and + // pointers to i64. This is necessary because TargetInfo::loadDShared and + // storeDShared can't handle vectors of pointers or sub-byte elements. + auto elemTy = srcTy.getElementType(); + auto isSubByteInt = + elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8; + auto isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isSubByteInt) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + // Munge input values + for (const auto &it : llvm::enumerate(inVals)) { + if (isSubByteInt) { + inVals[it.index()] = b.zext(llvmElemTy, it.value()); + } else if (isPtr) { + inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value()); + } + } + + // Pretty sure this is the identity function ATM + // It'd be better to simply call `quotient({kBlock})` and + // remove kBlock from transferWithinBlockImpl + auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout); + auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout); + SmallVector outVals = + transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock, + dstLayoutWithinBlock, adaptor, rewriter); + + // Unmunge output values + for (const auto &it : llvm::enumerate(outVals)) { + if (isSubByteInt) { + outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value()); + } else if (isPtr) { + outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value()); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + // Use warp shuffles to implement a layout conversion where data only needs to + // be moved within warps. + void transferWithinWarp(ConvertLayoutOp op, + DecomposedWarpConversion decomposed, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + SmallVector + transferWithinBlockImpl(ArrayRef inVals, ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kOffset = str_attr("offset"); + StringAttr kIteration = str_attr("iteration"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + + auto scratchConfig = + getScratchConfigForCvt(op.getSrc().getType(), op.getType()); + auto tensorShapePerCTA = convertType(getShapePerCTA( + op.getSrc().getType().getEncoding(), op.getType().getShape())); + // Input dims: [offset, iteration, block] + // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape + LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion( + ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order); + + // Layout for the store from registers to shared memory. + // + // Note: If two threads in the same warp write to the same shmem offset, the + // hardware resolves that without a stall or a bank conflict. Therefore we + // don't need to avoid duplicate writes. + // Input dims: [reg, lane, warp] + // Output dims: [offset, iteration] + bool isStMatrix = targetInfo.canUseStMatrix( + op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); + LinearLayout shmemStoreLayout = + isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(), + /*swizzleByteSize=*/0) + : srcLayout.invertAndCompose(sharedLayout); + + const int shmemAllocatedNumElems = + getNumScratchElements(scratchConfig.paddedRepShape); + assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems); + + // Layout for the load from shmem to registers. + LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); + + // Check that the `register` fully determines the `iteration`. That is, + // each thread does exactly the same reads and writes to shmem on each + // iteration, just with different input/output registers. + assert( + shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + assert( + shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + // iteration -> registers + SmallVector> inRegsForIter = + collectRegsForIter(ctx, shmemStoreLayout); + SmallVector> outRegsForIter = + collectRegsForIter(ctx, shmemLoadLayout); + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto sharedPtrTy = smemBase.getType(); + Type elemTy = inVals[0].getType(); + auto outSize = shmemLoadLayout.getInDimSize(kRegister); + auto iterations = sharedLayout.getInDimSize(kIteration); + assert(scratchConfig.inVec * iterations <= inVals.size()); + assert(scratchConfig.outVec * iterations <= outSize); + + // Check only one dimension has been padded. + // This means the difference between the padded shape and the original shape + // should only be in one dimension, specifically in + // `scratchConfig.order[0]`. + auto rank = scratchConfig.repShape.size(); + for (auto i = 0; i < rank; i++) { + if (i == scratchConfig.order[0]) { + continue; + } + assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]); + } + auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]]; + auto paddedSize = + scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride; + + // Linear layout function is split in two parts below: + // + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // offset = regBase xor regIdx + // + // It is the same hack as what we've done in the emitIndices function to get + // around performance issues on AMD GPUs + auto getVecAddr = [&](LinearLayout &layout, Value ®Base, + int regSlice) -> Value { + auto regIdx = layout + .apply({{kRegister, regSlice}, + {kLane, 0}, + {kWarp, 0}, + {kBlock, 0}})[0] + .second; + Value offset = b.xor_(regBase, b.i32_val(regIdx)); + if (paddedSize > 0) { + assert(llvm::isPowerOf2_32(paddedStride)); + assert(llvm::isPowerOf2_32(paddedSize)); + auto rshiftVal = llvm::Log2_32(paddedStride); + auto lshiftVal = llvm::Log2_32(paddedSize); + offset = b.add( + b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), + offset); + } + auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset, + LLVM::GEPNoWrapFlags::inbounds); + return vecAddr; + }; + + auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})[0] + .second; + auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})[0] + .second; + // register idx -> Value + llvm::MapVector outVals; + for (int i = 0; i < iterations; i++) { + if (i != 0) + b.barrier(); + + auto &inRegs = inRegsForIter[i]; + auto &outRegs = outRegsForIter[i]; + + // When using `stmatrix`, we can store `inVec` elements even if they are + // not contiguous + auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut() + : scratchConfig.inVec; + for (int j = 0; j < inVals.size() / iterations; j += inVec) { + auto inRegSlice = inRegs[j]; + Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice); + SmallVector inValsVec; + for (int k = 0; k < inVec; k++) + inValsVec.push_back(inVals[inRegSlice + k]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + if (isStMatrix) { + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } else { + targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec, + /*pred=*/b.true_val()); + } + } + + b.barrier(); + + for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { + auto outRegSlice = outRegs[j]; + auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice); + Value valsVec = + targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt, + vec_ty(elemTy, scratchConfig.outVec), + /*pred=*/b.true_val()); + for (Value v : unpackLLVector(loc, valsVec, rewriter)) + outVals[outRegSlice++] = v; + } + } + + SmallVector outValsVec; + for (size_t i = 0; i < outVals.size(); i++) + outValsVec.push_back(outVals[i]); + return outValsVec; + } + + // Determine which registers are read/written in which iteration of the shmem + // transfer specified by `layout`. + SmallVector /*registers*/> + collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const { + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kIteration = str_attr("iteration"); + + // The choice of iteration should be determined only by the register. That + // is, it should be correct to split the register dimension into iterations. + assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration}); + SmallVector> ret(sublayout.getOutDimSize(kIteration)); + for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) { + auto idx = sublayout.apply({{kRegister, reg}}); + ret[idx.begin()->second].push_back(reg); + } + return ret; + } +}; + +} // namespace + +void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp( + ConvertLayoutOp op, DecomposedWarpConversion decomposed, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + auto [P1, Cp, P2inv, reducedP1, reducedP2] = std::move(decomposed); + + // Grab the source elements and prepare the outputs of just the shuffles. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector shflOuts(Cp.getInDimSize(kRegister)); + + Value laneId = getLaneId(rewriter, loc); + + // Emit one shuffle per destination register. + for (int i : llvm::seq(shflOuts.size())) { + // 'Cp' maps a (dst_lane, dst_reg) -> (src_lane, src_reg), and we know that + // for a register, it does not map to different registers in the same lane. + // At the same time, for each register, P1 returns the source value index + // to provide as the shuffle value. + auto out = applyLinearLayout(loc, rewriter, P1, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + assert(out.size() == 1); + Value srcRegIdx = out.front().second; + // The size of the input lane dimension is the number of selects to emit. + // TODO(jeff): For dtypes smaller than i32, we can use byte permutes and + // shuffle multiple values at a time. + Value shflSrc = b.undef(srcValues.front().getType()); + for (int j : llvm::seq(reducedP1.getInDimSize(kLane))) { + int32_t check = + reducedP1.apply({{kLane, j}, {kRegister, i}}).front().second; + shflSrc = b.select(b.icmp_eq(srcRegIdx, b.i32_val(check)), + srcValues[check], shflSrc); + } + + out = applyLinearLayout(loc, rewriter, Cp, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + assert(out.size() == 1); + Value shflIdx = out.front().second; + shflOuts[i] = targetInfo.shuffleIdx(rewriter, loc, shflSrc, shflIdx); + } + + // Finally, we just need to apply P2 to the shflOuts to permute the registers + // into their final form. Use the same trick to reduce the number of emitted + // selects. + SmallVector results(shflOuts.size()); + for (int i : llvm::seq(results.size())) { + Value result = b.undef(srcValues.front().getType()); + + auto out = applyLinearLayout(loc, rewriter, P2inv, + {{kLane, laneId}, {kRegister, b.i32_val(i)}}); + Value resultIdx = out.front().second; + for (int j : llvm::seq(reducedP2.getInDimSize(kLane))) { + int32_t check = + reducedP2.apply({{kLane, j}, {kRegister, i}}).front().second; + result = b.select(b.icmp_eq(resultIdx, b.i32_val(check)), shflOuts[check], + result); + } + results[i] = result; + } + + Value result = + packLLElements(loc, getTypeConverter(), results, rewriter, op.getType()); + rewriter.replaceOp(op, result); +} + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..0d6a0cad3 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace ::mlir::triton::gpu; + +namespace { +class GenericFMAVectorMultiplier : public FMAVectorMultiplier { + OpBuilder &builder; + Location loc; + +public: + GenericFMAVectorMultiplier(OpBuilder &builder, Location loc) + : builder(builder), loc(loc) {} + + Value multiplyVectors(ArrayRef a, ArrayRef b, + Value c) override { + auto K = a.size(); + assert(b.size() == K); + Value accum = c; + for (auto [aElem, bElem] : llvm::zip(a, b)) + accum = builder.create(loc, aElem, bElem, accum); + return accum; + } +}; + +} // namespace + +LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + GenericFMAVectorMultiplier multiplier(rewriter, loc); + return parametricConvertFMADot(op, adaptor, typeConverter, rewriter, + multiplier); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp new file mode 100644 index 000000000..fa2c81472 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp @@ -0,0 +1,170 @@ +#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; + +namespace { + +/// OperandValueKey structure represents compile time part +/// of spatial coordinates of a value in a tensor. +/// +/// Every Value spatial coordinates(i.e. [batch;nonK;k]) in tensor can be +/// defined as: +/// +/// batch = (bRepIdx * CTABSize + bIdx) + (laneBCoord + warpBCoord) +/// nonK = (nonKRepIdx * CTANKSize + nonKIdx) + (laneNonKCoord + warpNonKCoord) +/// k = kIdx +/// +/// Where: +/// CTABSize, CTANKSize: constants; +/// laneBCoord, warpBCoord, laneNonKCoord, warpNonKCoord: runtime components; +/// bRepIdx, nonKRepIdx, bIdx, nonKIdx, kIdx: compile time components. +struct OperandValueKey { + unsigned bRepIdx, nonKRepIdx; + unsigned bIdx, nonKIdx, kIdx; + + bool operator==(const OperandValueKey &other) const { + return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx && + bIdx == other.bIdx && nonKIdx == other.nonKIdx && + kIdx == other.kIdx); + } +}; + +} // namespace + +template <> struct std::hash { + std::size_t operator()(const OperandValueKey &k) const { + return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx, + k.kIdx); + } +}; + +namespace { + +using ValueTableFMA = std::unordered_map; + +ValueTableFMA getValueTableFromStructFMA( + Value val, ArrayRef perRepShape, ArrayRef repetitions, + unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter, + Location loc, ArrayRef inRepOrder, ArrayRef repOrder) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + assert(perRepShape.size() == 3); + auto numElemsRep = product(perRepShape); + assert(elems.size() == numElemsRep * product(repetitions)); + assert(kDim == 1 || kDim == 2); + assert(nonKDim == 1 || nonKDim == 2); + const unsigned bDim = 0; + + for (unsigned idx = 0; idx < elems.size(); ++idx) { + auto inRepLinearIdx = idx % numElemsRep; + auto repLinearIdx = idx / numElemsRep; + auto inRepSpatialIdx = + mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder); + auto repSpatialIdx = + mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder); + OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim], + inRepSpatialIdx[0], inRepSpatialIdx[nonKDim], + inRepSpatialIdx[kDim]}; + res[key] = elems[idx]; + } + return res; +} + +} // namespace + +namespace mlir::triton::gpu { + +LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FMAVectorMultiplier &multiplier) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto dTensorTy = cast(D.getType()); + + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + // TODO process A and B operand separately + auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder()); + auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder()); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getContigPerThread(dTensorTy); + auto numElemsPerThread = product(sizePerThread); + SmallVector shapePerCTATile; + for (auto [reg, thread, warp] : + llvm::zip(sizePerThread, dLayout.getThreadsPerWarp(), + dLayout.getWarpsPerCTA())) { + shapePerCTATile.push_back(reg * thread * warp); + } + shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(shapePerCTATile)); + sizePerThread = expandMatrixShapeWithBatch(ArrayRef(sizePerThread)); + + unsigned K = aShapePerCTA[2]; + + unsigned threadTileShape[3]; + unsigned repetitions[3]; + for (int i = 0; i < 3; ++i) { + repetitions[i] = + ceil(dShapePerCTA[i], static_cast(shapePerCTATile[i])); + } + + auto has = getValueTableFromStructFMA( + llA, {sizePerThread[0], sizePerThread[1], K}, + {repetitions[0], repetitions[1], 1}, + /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder); + auto hbs = getValueTableFromStructFMA( + llB, {sizePerThread[0], K, sizePerThread[2]}, + {repetitions[0], 1, repetitions[2]}, + /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder); + + SmallVector acc = cc; + + for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep) + for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep) + for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep) + for (unsigned b = 0; b < sizePerThread[0]; ++b) + for (unsigned m = 0; m < sizePerThread[1]; ++m) + for (unsigned n = 0; n < sizePerThread[2]; ++n) { + SmallVector multiDimAccumIdx = {b, m, n}; + unsigned linearInRepIdx = + LLVM::linearize(multiDimAccumIdx, sizePerThread, inRepOrder); + SmallVector multiDimRepIdx = {bRep, mRep, nRep}; + unsigned linearRepIdx = + LLVM::linearize(multiDimRepIdx, repetitions, repOrder); + unsigned linearAccumIdx = + linearInRepIdx + linearRepIdx * numElemsPerThread; + + SmallVector aOpVector; + SmallVector bOpVector; + + for (unsigned k = 0; k < K; ++k) { + aOpVector.push_back(has.at({bRep, mRep, b, m, k})); + bOpVector.push_back(hbs.at({bRep, nRep, b, n, k})); + } + + acc[linearAccumIdx] = multiplier.multiplyVectors( + aOpVector, bOpVector, acc[linearAccumIdx]); + } + + auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..fb78360ad --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,665 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); + } + return numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = b.gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = b.gep(resultPtrTy, resultElemTy, adaptor.getPtr(), + adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create( + loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {rewriter.create( + loc, elemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = b.undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = b.insert_element(packed, operands[j + k][i], b.i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; + unsigned numElemsPerReg = + std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = + rewriter + .create( + loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, LLVM::TailCallKind::None, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + int structIdx = 0; + for (int i = 0; i < op->getNumResults(); i++) { + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = b.extract_val(asmResults, structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(b.extract_element(val, b.i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back(subOperands); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(b.undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], + rewriter, op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = rewriter.create(loc, maskAttr); + return {b.and_(operands[0][0], maskConst)}; + } + + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = rewriter.create(loc, lhsIsNan, rhsIsNan); + auto nonNanRes = rewriter.create(loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = rewriter.create(loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = rewriter.create(loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..7ece98f87 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,213 @@ +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +// NOTE: [Additional Function Arguments] +// To support use of shared memory and global scratch memory inside of a +// function, the caller allocates a single large block of the relevant memory +// and calls the function with these extra arguments at the end. +// Specifically, the last argument is the global scratch memory allocation and +// the second to last is the shared memory allocation. +// +// For the kernel function itself, the shared memory base is a global symbol +// so no additional function argument is required but global scratch memory +// allocation is still passed in as the last argument. Though here the scratch +// memory is shared between all programs, so a linear offset based on the +// program id is required to get the local scratch base. + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + // Push back two new arguments that indicate the current pointer to shared + // memory and global scratch memory. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto sharedPtrTy = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); + auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1); + + // 1. Modify the function type to add the new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + bool isKernel = triton::isKernel(funcOp); + if (isKernel) { + for (auto i : llvm::seq(amendedInputTy.size())) { + if (isa(amendedInputTy[i])) { + funcOp.setArgAttr(i, "tt.nv_tma_desc", + mlir::IntegerAttr::get(i32_ty, 1)); + } + } + } + if (!isKernel) { + amendedInputTy.push_back(sharedPtrTy); + } + amendedInputTy.push_back(globalPtrTy); + auto amendedFuncTy = + FunctionType::get(ctx, amendedInputTy, funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + if (auto argAttrs = funcOp.getAllArgAttrs()) { + llvm::SmallVector amendedArgAttrs(argAttrs.begin(), + argAttrs.end()); + while (amendedArgAttrs.size() < amendedInputTy.size()) { + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + } + amendedAttrs.push_back( + rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(), + rewriter.getArrayAttr(amendedArgAttrs))); + } + + // 3. Add the new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + if (!isKernel) { + region.addArgument(sharedPtrTy, loc); + } + region.addArgument(globalPtrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. + static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = triton::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { + continue; + } + + for (const auto &attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); + + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(), + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(), + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(), + mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); + + FailureOr maybeNewFuncOp = + mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, + *getTypeConverter()); + if (failed(maybeNewFuncOp)) { + return failure(); + } + + LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp; + + auto ctx = funcOp->getContext(); + + if (triton::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(), + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + + // Determine the actual number of required warps. + int numWarps = triton::gpu::lookupNumWarps(funcOp); + if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType( + "ttg.total-num-warps")) + numWarps = totalNumWarps.getInt(); + + // Set `nvvm.maxnreg` if it was specified on the module. + if (Attribute maxnregAttr = + funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName)) + newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr); + + // Set an attribute for reqntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(), + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + + rewriter.eraseOp(funcOp); + rewriter.eraseOp(amendedFuncOp); + + // Add attributes for by-value TMA descriptor args (nvidia) + handleByvalTmaDescArgs(newFuncOp); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp new file mode 100644 index 000000000..109a58389 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -0,0 +1,350 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class GatherOpConversion : public ConvertOpToLLVMPattern { +public: + GatherOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + +private: + // Codegen the gather by storing the source tensor into shared memory and then + // gathering directly from shared memory. + void emitGatherInShared(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + // Codegen a warp-local gather by shuffling elements across the warp and + // selecting from them. + void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; + + const TargetInfoBase &targetInfo; +}; + +LogicalResult +GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + GatherLoweringHelper helper(op); + // Specialize the lowering based on the source layout. Given that the cost of + // a warp shuffle is approximately half the cost of a roundtrip to shared + // memory with zero bank conflicts, we will need a more precise heuristic to + // choose between the two codegen paths and rely on the middle end to pick the + // right layout. + if (helper.isWarpLocal()) { + emitWarpLocalGather(op, adaptor, rewriter); + } else { + emitGatherInShared(op, adaptor, rewriter); + } + return success(); +} + +static Value convertIndexToI32(Location loc, Value index, + ConversionPatternRewriter &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned idxWidth = index.getType().getIntOrFloatBitWidth(); + // The LL index computations are performed with 32 bit integers. If the + // indices are something else, cast them to i32. + if (idxWidth > 32) { + index = b.trunc(i32_ty, index); + } else if (idxWidth < 32) { + // Negative indices don't make sense, so zero-extend. + index = b.zext(i32_ty, index); + } + return index; +} + +void GatherOpConversion::emitGatherInShared( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + + // Compute the src subtensor shape owned by this CTA. + SmallVector srcShapePerCTA = + convertType(triton::gpu::getShapePerCTA(srcType)); + + // Grab the src values in this thread. + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // Emit the indices of the src values owned by this thread. + SmallVector> srcIndices = + emitIndices(loc, rewriter, targetInfo, srcType.getEncoding(), + op.getSrc().getType(), /*withCTAOffset=*/true); + + // Store the src values owned by the thread into their respective location in + // the scratch memory. + assert(srcValues.size() == srcIndices.size()); + + // Get the base pointer to the scratch memory. + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // For each src element owned by the thread, index into the scratch memory and + // then store it. + Type elemType = getTypeConverter()->convertType(srcType.getElementType()); + for (auto [value, indices] : llvm::zip(srcValues, srcIndices)) { + // Convert the index at each dim into a single offset given the shape of the + // tensor. + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + // Emit the offset into the shared memory and then store the value. + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + b.store(value, ptr); + } + + // Synchronize the whole CTA. + b.barrier(); + + // Grab the index values owned by this thread. + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + // Apply the layout of the destination tensor to obtain the indices of the + // column to gather along, then for each column, replace the index along the + // gather axis with the appropriate index value. + // + // I = LL(pid) + // idx = indices[I] + // I_gather = [I[d] if d != axis else idx for d in range(len(I))] + // out[I] = src[I_gather] + RankedTensorType dstType = op.getType(); + SmallVector> dstIndices = + emitIndices(loc, rewriter, targetInfo, dstType.getEncoding(), dstType, + /*withCTAOffset=*/true); + + unsigned axis = op.getAxis(); + SmallVector results(dstIndices.size()); + for (auto [i, idx, indices] : llvm::enumerate(idxValues, dstIndices)) { + indices[axis] = convertIndexToI32(loc, idx, rewriter); + Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA); + Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset); + results[i] = b.load(elemType, ptr); + } + + Value packed = + packLLElements(loc, getTypeConverter(), results, rewriter, dstType); + rewriter.replaceOp(op, packed); +} + +// High-level description of the algorithm: +// +// `isWarpLocal` checks that it is possible to compute each output element +// without data movement across warps. +// +// If the gather dim is `dimN`, then this means +// +// ll^-1(dimN)[(block, warp)] == 0 +// +// for both source and index tensors: moving along the gather axis does not +// change the warp. Broadcasted layouts are not supported, so we know the +// layouts are permutation matrices. +// +// We can check this with `ll((block, warp))[dimN] == 0`. +// +// Let `gatherCol` be a tuple of all dimensions except the gather dimension. +// We also check that the gather columns line up the same way with respect to +// the warp between the source and index tensors with +// +// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol] +// +// This means that for all index columns, the corresponding column in the source +// tensor is owned by the same warp. +// +// We also check +// +// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol] +// +// This boils down to the fact that the algorithm essentially emits a series of +// index shuffles for each index value owned by each thread, and then a pile of +// selects to pick the right value. We need to figure out given an index value +// in a particular column, what are the source register values it could read +// from and who owns them. +// +// If this relationship did not hold, then the possible source registers for +// each index value varies with the thread, meaning the value operand provided +// to each shuffle index instruction would depend on the thread ID. This isn't a +// big deal. It just means would have to emit a pile of selects before each +// shuffle as well, to pick the right source register value. But we choose not +// to handle this. +// +// The codegen algorithm emits code: +// - Given the thread ID and a particular index tensor register, figure out +// which gather column it belongs to using a layout. +// - Using the index value itself as the value for `dimN`, use another layout to +// figure out which lane in the warp owns the desired value and which register +// in that lane it is. +// - For the gather column, figure out the source registers in that column, and +// for each of them, emit an index shuffle with the same computed lane ID. +// - Use the register component to select the right value from the shuffle +// results. +void GatherOpConversion::emitWarpLocalGather( + GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType srcType = op.getSrc().getType(); + RankedTensorType idxType = op.getIndices().getType(); + + // Layout dimension names. + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); + StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis())); + SmallVector allDims, otherDims; + for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) { + allDims.push_back(str_attr("dim" + Twine(dim))); + if (dim != op.getAxis()) { + otherDims.push_back(allDims.back()); + } + } + + // Compute the src and idx layouts. + LinearLayout srcLayout = + toLinearLayout(srcType.getShape(), srcType.getEncoding()); + LinearLayout idxLayout = + toLinearLayout(idxType.getShape(), idxType.getEncoding()); + + // Let `ll_src` be the source layout and `ll_idx` be the index layout. + // Let `src_col` be a tuple of dimensions except the gather dimension, + // representing a specific column in the source tensor. Likewise for + // `idx_col`. Let `src_idx` be the index into gather dimension in the source + // tensor. + // + // `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is + // the thread that contains the required element and `src_reg` is the register + // within that thread. + // + // Because `ll_src(block=0, warp=0, lane=0)[otherDims] == + // ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the + // index tensor) the thread will need to read from the same column in the + // source tensor. + // + // Thus, we can obtain + // + // (src_lane, src_reg) = (ll_src^-1)( + // ll_idx(black, warp, lane, idx_reg)[otherDims], + // idxValues[idx_reg] + // )[{"lane", "register"}] + // + // And the mapping will be the correct for each thread. + // + // Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for + // each `idx_reg` (the number of index shuffles is quadratic!) and + // `llvm.select` using `src_reg` to get the right one. `K` is the number of + // elements per column owned by a thread. + + // Invert the source layout. It doesn't matter whether it is fully invertible + // with respect to anything except the register input dimension, since we know + // those don't vary in ways that matter for codegen. + LinearLayout invSrcLayout = srcLayout.pseudoinvert(); + + // Sanity check: the warp must be invariant to the index because otherwise the + // gather would need to read across warps! + assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kBlock, kWarp}) && + "expected a warp-local gather"); + invSrcLayout = invSrcLayout.sublayout(allDims, {kLane, kRegister}); + + LinearLayout idxColLayout = + idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector idxValues = + unpackLLElements(loc, adaptor.getIndices(), rewriter); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = targetInfo.getClusterCTAId(rewriter, loc); + + unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister); + assert(srcRegsPerThread == srcValues.size()); + + // Given a index value, we need to know which sources register values it could + // index into. This is invariant to anything other than the register, which we + // checked already. Compute the full reverse map from + // + // idx_reg -> gather_column -> (src_reg0, src_reg1, ...) + // + LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister}); + // Remove zero bases in the gather dimension to make the function injective + // (for a given column) over the same codomain. + invertSrcRegMap = invertSrcRegMap.removeZeroBasesAlongDim(kGatherDim); + // We are left with only non-zero bases in the gather dimension, which means + // the number of registers per column is the size of the "gather dimension". + unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim); + // Get a map from idx_reg to the column it indexes into. + LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims); + // Now given `idx_reg`, we can compute the column it belongs to in both src + // and index tensors, then partially apply `invertSrcRegMap` with this to + // obtain a function that outputs the corresponding registers in the src + // tensor in the same column. + + // L(column, i) = L(column, 0) xor L(0, i) + LinearLayout invertSrcRegMapColPart = + invertSrcRegMap.sublayout(otherDims, {kRegister}); + LinearLayout invertSrcRegMapRest = + invertSrcRegMap.sublayout({kGatherDim}, {kRegister}); + + SmallVector results; + for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) { + SmallVector> column = + applyLinearLayout(loc, rewriter, idxColLayout, + {{kBlock, blockId}, + {kWarp, warpId}, + {kLane, laneId}, + {kRegister, b.i32_val(idxReg)}}); + assert(column.size() == otherDims.size()); + + // Combine the computed column with the data-dependent gather index. + column.emplace_back(kGatherDim, convertIndexToI32(loc, idxVal, rewriter)); + SmallVector> srcLaneAndReg = + applyLinearLayout(loc, rewriter, invSrcLayout, column); + + auto [srcLaneName, srcLane] = srcLaneAndReg.back(); + auto [srcRegName, srcReg] = srcLaneAndReg.front(); + assert(srcLaneName == kLane && srcRegName == kRegister); + + assert(!srcValues.empty() && "can't gather from an empty tensor"); + + // Figure out which src registers we need to index shuffle from. This is + // invariant to anything else. + SmallVector> normalizedColumn = + idxRegToCol.apply({{kRegister, idxReg}}); + int32_t srcBase = + invertSrcRegMapColPart.apply(normalizedColumn).front().second; + + Value result = b.undef(srcValues.front().getType()); + for (unsigned i = 0; i != numRegsPerColumn; ++i) { + int32_t rest = + invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second; + int32_t srcRegIdx = srcBase ^ rest; + + Value value = + targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane); + result = b.select(b.icmp_eq(b.i32_val(srcRegIdx), srcReg), value, result); + } + + results.push_back(result); + } + + rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results, + rewriter, op.getType())); +} + +} // namespace + +void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.insert(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp new file mode 100644 index 000000000..07299ea1c --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/GlobalScratchMemoryAllocation.cpp @@ -0,0 +1,103 @@ +#include "mlir/Analysis/Liveness.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; + +namespace mlir::triton::gpu { +#define GEN_PASS_DEF_TRITONGPUGLOBALSCRATCHALLOCATIONPASS +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton::gpu + +static int32_t roundUp(int32_t val, int32_t step) { + auto t = val + step - 1; + return t - (t % step); +} + +static void allocateGMem(Operation *parentOp, + llvm::SetVector &callStack) { + // Recursively visit any dependency functions + parentOp->walk([&](triton::CallOp call) { + auto callable = call.resolveCallable(); + if (!callable->hasAttr("ttg.global_scratch_memory_size")) { + auto inserted = callStack.insert(parentOp); + assert(inserted && "call cycle detected"); + allocateGMem(callable, callStack); + callStack.remove(parentOp); + } + }); + + MLIRContext *ctx = parentOp->getContext(); + OpBuilder builder(ctx); + int32_t offset = 0; + uint32_t largestAlignment = 1; + + // Dumb allocation that ignores liveness and makes no attempt to minimize + // padding + // TODO: Use a real algorithm + parentOp->walk([&](Operation *op) { + uint32_t nbytes = 0; + uint32_t align = 0; + if (auto alloc = dyn_cast(op)) { + nbytes = alloc.getNbytes(); + align = alloc.getAlignment(); + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto nbytes_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_size"); + auto align_attr = callable->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(nbytes_attr); + assert(align_attr); + + nbytes = nbytes_attr.getValue().getZExtValue(); + align = align_attr.getValue().getZExtValue(); + } + if (nbytes > 0) { + offset = roundUp(offset, align); + op->setAttr("ttg.global_scratch_memory_offset", + builder.getI32IntegerAttr(offset)); + offset += nbytes; + largestAlignment = std::max(largestAlignment, align); + } + }); + int32_t totalMemorySize = roundUp(offset, largestAlignment); + parentOp->setAttr("ttg.global_scratch_memory_size", + builder.getI32IntegerAttr(totalMemorySize)); + parentOp->setAttr("ttg.global_scratch_memory_alignment", + builder.getI32IntegerAttr(largestAlignment)); +} + +namespace { +class TritonGPUGlobalScratchAllocationPass + : public mlir::triton::gpu::impl::TritonGPUGlobalScratchAllocationPassBase< + TritonGPUGlobalScratchAllocationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + + bool seenKernel = false; + + SetVector callStack; + mod->walk([&](triton::FuncOp func) { + allocateGMem(func, callStack); + + if (func.getVisibility() == SymbolTable::Visibility::Public) { + assert(!seenKernel); + seenKernel = true; + auto size = + func->getAttrOfType("ttg.global_scratch_memory_size"); + auto align = func->getAttrOfType( + "ttg.global_scratch_memory_alignment"); + assert(size); + assert(align); + mod->setAttr("ttg.global_scratch_memory_size", size); + mod->setAttr("ttg.global_scratch_memory_alignment", align); + } + }); + assert(seenKernel); + } +}; +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 000000000..2abc63788 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,227 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + SmallVector &maskValues, int numBins, int numThreadPerWarp, + Value threadId, ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = b.i32_val(0); + int numBits = llvm::Log2_64(numBins); + int numBitsLaneId = llvm::Log2_64(numThreadPerWarp); + unsigned numElementsPerThreads = getTotalElemsPerThread(srcType); + unsigned numThreadWithUniqueData = getThreadsPerWarp(srcType)[0]; + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = b.and_(value, b.i32_val(1 << j)); + Value cmp = b.icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + // If not all threads have unique data, mask out the redundant ones. + if (numThreadWithUniqueData < numThreadPerWarp) { + mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + } + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = + b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero), + b.int_val(numThreadPerWarp, 0), fullMask); + mask = b.and_( + mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // save a ballot bit to capture the input mask + Value inputMaskBit = fullMask; + if (maskValues.size() > 0) { + inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), + maskValues[i]); + } + // mask out the values for which input mask is invalid + mask = b.and_(mask, inputMaskBit); + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + b.int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = b.and_(binMask, b.xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = rewriter.create( + loc, int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = b.trunc(i32_ty, bitCount); + warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + rewriter.create(loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector histogramValues; + unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA( + srcType.getEncoding(), srcType.getShape())[0]; + Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = + b.add(threadId, b.i32_val((i * numWarps * numThreadPerWarp))); + offset = b.urem(offset, b.i32_val(numBins)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + b.store(b.i32_val(0), sharedMemPtr); + } + b.barrier(); + Block *afterAtomics = nullptr; + // If some warps have replicated data we need to skip those warps when + // accumulating. + if (numWarpsWithUniqueData < numWarps) { + Block *currentBlock = rewriter.getInsertionBlock(); + afterAtomics = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *atomicBlock = rewriter.createBlock(afterAtomics); + rewriter.setInsertionPointToEnd(currentBlock); + Value cond = b.icmp_ult( + threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + rewriter.create(loc, cond, atomicBlock, afterAtomics); + rewriter.setInsertionPointToStart(atomicBlock); + } + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = b.add(b.mul(laneId, b.i32_val(warpLevelHistogram.size())), + b.i32_val(i)); + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + rewriter.create(loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + b.barrier(); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = b.load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + + Value llMask = adaptor.getMask(); + SmallVector maskValues; + if (llMask) + maskValues = unpackLLElements(loc, llMask, rewriter); + + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::lookupNumWarps(op); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp, + threadId, rewriter, targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..8060b4431 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = b.add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..31bcc9e2b --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,203 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// blocked -> shared. +// Swizzling in shared memory to avoid bank conflict. Normally used for +// A/B operands of dots. +void lowerDistributedToShared( + Location loc, Value src, Value dst, Value adaptorSrc, + const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair *const llvmOpCount = nullptr) { + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + auto elemTy = typeConverter->convertType(srcTy.getElementType()); + + auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); + storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter, + targetInfo, llvmOpCount); +} + +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase *targetInfo; + + GlobalScratchAllocOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(&targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto opOffsetAttr = op->getAttrOfType( + "ttg.global_scratch_memory_offset"); + assert(opOffsetAttr); + auto opOffset = opOffsetAttr.getValue().getZExtValue(); + + auto funcOp = op->getParentOfType(); + if (!funcOp) { + return failure(); + } + Value ptr = LLVM::getGlobalScratchPtr(loc, rewriter, *targetInfo, funcOp, + b.i32_val(opOffset)); + + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + + auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(), + loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + lowerDistributedToShared(loc, op.getSrc(), op.getResult(), + adaptor.getSrc(), smemObj, typeConverter, + rewriter, targetInfo); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); + } + +private: + LogicalResult + lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType()); + + SmallVector outVals = loadSharedToDistributed( + op, elemLlvmTy, smemObj, loc, rewriter, targetInfo); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value memDescVal = op.getDst(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + + std::pair llvmOpCount; + lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), + adaptor.getSrc(), smemObj, getTypeConverter(), + rewriter, targetInfo, &llvmOpCount); + + targetInfo.localStoreOpAnnotation(op, llvmOpCount.first, + llvmOpCount.second); + + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000..fc26d4ae3 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,244 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId(rewriter, loc, + op->getParentOfType(), axis); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, {}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + assert(op.getNumOperands() == op.getIsSigned().size()); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + bool isSigned = op.getIsSigned()[i] > 0; + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter, isSigned); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter, bool isSigned) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + + os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + auto isSignedOperands = + llvm::SmallVector(printfOperands.size(), isSigned); + if (i == 0) { + formatStrValue = llPrintf(formatStr, printfOperands, isSignedOperands, + rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands, isSignedOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt, + bool isSigned = false) const { + Type type = value.getType(); + // If the `value` is a pointer, just return %p. + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + (isSigned ? "lli" : "llu"); + else + return prefix + (isSigned ? "i" : "u"); + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, ArrayRef isSigned, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, + isSigned); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..a17526f10 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,391 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DistributedEncodingTrait; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isReduceWithinCTA() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchRepShape(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter, targetInfo); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); + } + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + b.barrier(); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumulate over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); + if (success) + return; + + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = + mlir::cast(helper.getSrcLayout()); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchRepShape(); + + // Lezcano: We should move all the shared memory logic to use LLs natively + auto srcShape = helper.getSrcShape(); + auto kLane = rewriter.getStringAttr("lane"); + auto [multiDimLaneId, isRepresentativeLane] = + delinearize(rewriter, loc, srcLayout, srcShape, kLane, laneId); + auto kWarp = rewriter.getStringAttr("warp"); + auto [multiDimWarpId, isRepresentativeWarp] = + delinearize(rewriter, loc, srcLayout, srcShape, kWarp, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value laneZero = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value write = + b.and_(b.and_(isRepresentativeLane, isRepresentativeWarp), laneZero); + + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], write); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto smemShape = helper.getScratchRepShape(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto mod = op->getParentOfType(); + int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numWarps = triton::gpu::lookupNumWarps(op); + int numThreads = numLanes * numWarps; + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = b.i32_val(numLanes); + Value laneId = b.urem(threadId, warpSize); + Value zero = b.i32_val(0); + + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = b.icmp_slt(threadId, b.i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = b.urem(laneId, b.i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + b.icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = b.and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = b.add(readOffset, b.i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), b.i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller than src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + b.urem(readIdx[smemIdx], b.i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = + b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + resultVals[j] = b.load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = b.load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 000000000..e3012d29d --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,163 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" + +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator insertionPoint, + ValueRange combineArgs) { + auto returnOp = combineBlock.getTerminator(); + rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint, + combineArgs); + + auto results = SmallVector(returnOp->getOperands()); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +inline SmallVector applyCombineOp(Location loc, + ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur, Value pred = {}) { + // Allows for passing an uninitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + + // Create a new copy of the combine block, and try to speculatively inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + + rewriter.cloneRegionBefore(combineOp, parent, + std::next(currentBlock->getIterator())); + Block &newCombine = *currentBlock->getNextNode(); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + auto isRegionSpeculatable = + std::all_of(newCombine.begin(), newCombine.end(), + [](auto &op) { return isSpeculatable(&op); }); + + if (!pred || isRegionSpeculatable) { + // Fast path, region has no side effects so we can unconditionally execute + return inlineCombineBlock(rewriter, newCombine, currentBlock, + rewriter.getInsertionPoint(), combineArgs); + } + + // Slow case, create an if to only execute region when pred is true + // #currentBlock + // if (pred) { + // #newCombine + // results = combineOp(cur, acc) + // yield results + // } else { + // yield undef + // } + // #thenBlock + Block *thenBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + auto returnOp = newCombine.getTerminator(); + auto results = SmallVector(returnOp->getOperands()); + + rewriter.setInsertionPointToEnd(currentBlock); + SmallVector thenBlockArgs; + thenBlockArgs.reserve(results.size()); + for (auto result : results) { + auto ty = result.getType(); + auto undef = rewriter.create(loc, ty); + thenBlockArgs.push_back(undef); + thenBlock->addArgument(ty, loc); + } + rewriter.create(loc, pred, &newCombine, combineArgs, + thenBlock, thenBlockArgs); + + // Split a block after the call. + rewriter.setInsertionPointToEnd(&newCombine); + rewriter.replaceOpWithNewOp(returnOp, thenBlock, results); + rewriter.setInsertionPointToStart(thenBlock); + return SmallVector(thenBlock->getArguments()); +} + +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + auto basePtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + indexToBase[indices[0]] = basePtr; + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = + b.gep(basePtr.getType(), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], b.i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..972fc5592 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(rewriter, op->getLoc(), + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 000000000..a89f9be8a --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,573 @@ +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::toLinearEncoding; + +// apply combine region to acc and cur and accumulate it into acc +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + Value mask = b.icmp_sge(laneIdAxis, b.i32_val(i)); + SmallVector tempAcc = + accumulate(helper, rewriter, shfl, acc, mask); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = b.select(mask, tempAcc[j], acc[j]); + } + } + srcValues[srcIndex] = std::move(acc); + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, Value isRepresentative, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = b.icmp_eq(laneId, b.i32_val(scanDim - 1)); + mask = b.and_(mask, isRepresentative); + Value index = + b.add(parallelLaneId, b.mul(warpId, b.i32_val(numParallelLane))); + index = b.add(index, b.i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = + b.gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + ArrayRef smemBases, + ArrayRef smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskNotFirstWarp = b.icmp_ne(warpId, b.i32_val(0)); + Value maskNotFirstLane = b.icmp_ne(laneIdAxis, b.i32_val(0)); + Value maskNotFirstThread = b.or_(maskNotFirstWarp, maskNotFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = + b.add(parallelLaneId, + b.i32_val(numParallelLane * (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = b.gep(smemBases[j].getType(), elemTy, smemBases[j], index); + partialReduce[j] = b.load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + Value mask = b.icmp_sge(warpId, b.i32_val(i + 1)); + accumulator.acc = + accumulate(helper, rewriter, accumulator.acc, partialReduce); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + b.select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); + } + } + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = b.select(maskNotFirstWarp, temp[i], val[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = + b.select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = b.select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = b.icmp_eq(warpId, b.i32_val(0)); + Value maskFirstLane = b.icmp_eq(laneIdAxis, b.i32_val(0)); + Value maskFirstThread = b.and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = + b.select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = accumulate(helper, rewriter, lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = b.select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], + laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = std::move(laneValue); + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter, targetInfo))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + std::tuple, Value> + getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId) const; + std::tuple, Value> + getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const; +}; + +std::tuple, Value> +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("lane"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + laneId); +} + +std::tuple, Value> +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + auto srcEncoding = helper.getEncoding(); + auto kWarp = rewriter.getStringAttr("warp"); + return delinearize(rewriter, loc, srcEncoding, helper.getShape(), kWarp, + warpId); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = srcEncoding.getThreadsPerWarp(); + auto warpsPerCTA = srcEncoding.getWarpsPerCTA(); + auto [multiDimLaneId, isRepresentativeLane] = + getMultiDimLaneId(rewriter, helper, laneId); + auto [multiDimWarpId, isRepresentativeWarp] = + getMultiDimWarpId(rewriter, helper, warpId); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = b.i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = linearize(rewriter, loc, multiDimLaneId, + threadsPerWarp, helper.getOrder()); + multiDimWarpId[axis] = b.i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, helper.getOrder()); + Value flatIdParallel = b.add( + laneIdParallel, + b.mul(warpIdParallel, b.i32_val(helper.getNonAxisNumThreadsPerWarp()))); + auto isRepresentative = b.and_(isRepresentativeLane, isRepresentativeWarp); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel, + isRepresentative); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter) { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + ScanLoweringHelper helper(op); + auto loc = helper.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + if (!helper.isSupported()) + return op.emitError("TODO: unsupported scan layout"); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = b.i32_val(iWarpSize); + Value warpId = b.udiv(threadId, warpSize); + Value laneId = b.urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel, isRepresentative] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = b.sub(b.i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = + getSmemBases(op, elems, rewriter, targetInfo); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, isRepresentative, + targetInfo); + b.barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = + std::get<0>(getMultiDimLaneId(rewriter, helper, laneId)); + multiDimLaneId[helper.getAxis()] = b.i32_val(scanDim - 1); + auto linearEncoding = helper.getEncoding(); + auto threadsPerWarp = linearEncoding.getThreadsPerWarp(); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + helper.getOrder()); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..f220ad317 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,77 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::MemDescType; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const TargetInfoBase &targetInfo, + const DataLayoutAnalysis *analysis) + : TritonGPUToLLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), targetInfo, + analysis) {} + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, const LowerToLLVMOptions &options, + const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, options, analysis) { + addConversion([ctx](triton::PointerType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); + }); + addConversion([ctx](TensorDescType type) -> std::optional { + return LLVM::LLVMPointerType::get(ctx, 0); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type, targetInfo); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type, targetInfo); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncTokenType(type); + }); + + convertFP8Type(); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + Type eltType = convertType(type.getElementType()); + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType( + MemDescType type, const TargetInfoBase &targetInfo) { + auto ctx = type.getContext(); + // base ptr + auto ptrType = LLVM::LLVMPointerType::get( + ctx, targetInfo.getAddressSpace(type.getMemorySpace())); + + if (isa( + type.getEncoding())) { + return ptrType; + } + + SmallVector types; + types.push_back(ptrType); + auto rank = type.getRank(); + // offsets + for (auto i = 0; i < rank; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncTokenType( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..4d8944a9c --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,1301 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_clz(unsigned x) { + unsigned long r; + _BitScanReverse(&r, x); + return static_cast(r ^ 31); +} + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +#endif + +// This reverts #5645, because it introduced increased register pressure in AMD +// backend. +// TODO: remove when new implementation performance reaches target level +namespace { + +LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, + LinearLayout regLayout, + triton::gpu::SharedEncodingTrait dstEnc, + int elemBitWidth) { + StringAttr kBlock = StringAttr::get(ctx, ("block")); + int rank = shape.size(); + + LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc); + auto sharedOrder = triton::gpu::getOrder(dstEnc, shape); + + // sharedLayout's in-dims are currently (offset, block). Reshape to + // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional + // shmem strides. (The offsetX's appear in minor-to-major order.) + auto sharedLegacy = cast(dstEnc); + SmallVector> multiDimSharedSize; + for (int i = 0; i < rank; i++) { + int dim = sharedOrder[i]; + int64_t size = std::max( + int64_t{1}, + shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]); + multiDimSharedSize.push_back( + {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size}); + } + multiDimSharedSize.push_back({kBlock, sharedLayout.getInDimSize(kBlock)}); + sharedLayout = sharedLayout.reshapeIns(multiDimSharedSize); + + // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, + // ..., offsetXN, block), where the offsetX's are in minor-to-major order. + return regLayout.invertAndCompose(sharedLayout); +} + +} // namespace + +namespace mlir { + +namespace triton::gpu { +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(layout.getNumInDims() == indices.size()); + for (auto [inDimName, idx] : indices) { + assert(layout.hasInDim(inDimName) && "Invalid inDimName"); + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + for (auto [inDimName, idx] : indices) { + if (auto constant = idx.getDefiningOp()) { + constantIns.push_back( + {inDimName, cast(constant.getValue()).getInt()}); + } else { + constantIns.push_back({inDimName, 0}); + } + } + SmallVector constantComponent = + llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); + + Value zero = b.i32_val(0); + SmallVector> outIndices; + for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { + if (constantComponent[i] == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, b.i32_val(constantComponent[i])}); + } + + for (auto [inDimName, idx] : indices) { + if (idx.getDefiningOp()) { + continue; + } + + int nBits = layout.getInDimSizeLog2(inDimName); + for (int i = 0; i < nBits; i++) { + Value bit = b.and_(idx, b.i32_val(1 << i)); + Value bit_is_zero = b.icmp_eq(bit, zero); + for (auto &[outDimName, outIdx] : outIndices) { + int32_t basis = layout.getBasis(inDimName, i, outDimName); + if (basis == 0) + continue; + outIdx = b.xor_(outIdx, b.select(bit_is_zero, zero, b.i32_val(basis))); + } + } + } + + return outIndices; +} + +std::optional getWarpGroupStartThreadId(Block *block) { + using namespace triton::gpu; + + // Look for an enclosing `ttg.warp_specialize` op. + while (block && block->getParentOp() && + !isa(block->getParentOp())) + block = block->getParentOp()->getBlock(); + if (!block || !block->getParentOp()) + return {}; + + auto partitions = cast(block->getParentOp()); + unsigned idx = block->getParent()->getRegionNumber(); + WarpSpecializeOp ws = partitions.getParentOp(); + std::optional> startIds = ws.getWarpGroupStartIds(); + assert(startIds && "cannot get warp group ID before warp group allocation"); + int32_t warpStartId = (*startIds)[idx]; + int threadsPerWarp = + TritonGPUDialect::getThreadsPerWarp(ws->getParentOfType()); + return warpStartId * threadsPerWarp; +} + +Value getThreadId(OpBuilder &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + tid = rewriter.create(loc, i32_ty, tid); + + // If this is being created inside a warp specialize op, compute the relative + // thread ID within the warp group. + if (std::optional startId = + getWarpGroupStartThreadId(rewriter.getInsertionBlock())) { + TritonLLVMOpBuilder b(loc, rewriter); + tid = rewriter.create(loc, tid, b.i32_val(*startId)); + } + + return tid; +} + +Value getLaneId(OpBuilder &rewriter, Location loc) { + TritonLLVMOpBuilder b(loc, rewriter); + Value tid = getThreadId(rewriter, loc); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + return b.urem(tid, b.i32_val(threadsPerWarp)); +} + +std::pair getLaneAndWarpId(OpBuilder &rewriter, Location loc) { + TritonLLVMOpBuilder b(loc, rewriter); + Value tid = getThreadId(rewriter, loc); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + Value warpSizeVal = b.i32_val(threadsPerWarp); + Value laneId = b.urem(tid, warpSizeVal); + + // If there is only one warp, the warp ID is always 0. + Operation *lookupPt = &rewriter.getInsertionBlock()->front(); + Value warpId; + if (triton::gpu::lookupNumWarps(lookupPt) == 1) + warpId = b.i32_val(0); + else + warpId = b.udiv(tid, warpSizeVal); + + return {laneId, warpId}; +} + +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + LinearLayout ll = triton::gpu::toLinearLayout(shape, layout); + + // TODO(jlebar): We could add strong typing if we wanted; for now this is + // "stringly typed". + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); + unsigned rank = shape.size(); + SmallVector> ret; + // Linear layout function is split in two parts below: + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // idxs = idxsBase xor idxsReg + // + // L(0, t, w, b) part is the same for all registers, + // so we hoist it out of the main register loop in the below. + // + // This approach produces code with lower register pressure and + // less computations, compared to fused L(r,t,w,b) method. + auto idxsBase = applyLinearLayout(loc, rewriter, ll, + {{kRegister, b.i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) { + auto idxsReg = + ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + SmallVector> idxs; + for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) { + auto dimName = idxBase.first; + assert(dimName == idxReg.first && + "dim names of block+warp+thread and register idx should be equal"); + auto idx = b.xor_(idxBase.second, b.i32_val(idxReg.second)); + idxs.emplace_back(dimName, idx); + } + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + ret.push_back(llvm::to_vector(llvm::make_second_range(idxs))); + } + + return ret; +} + +namespace { + +Value getSmemVecAddr(const LinearLayout ®Layout, + const LinearLayout ®ToSharedLayout, + const LinearLayout &invertAllocSharedLayout, + const SharedMemoryObject &smemObj, + triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + Value regId, Value laneId, Value warpId, Value blockId, + Location loc, RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto shape = sharedTy.getShape(); + auto allocShape = sharedTy.getAllocShape(); + auto rank = shape.size(); + auto sharedEnc = + cast(sharedTy.getEncoding()); + + auto smemBase = smemObj.getBase(); + auto smemOffsets = smemObj.getOffsets(); + auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter); + Value smemOffset; + // When loading or storing to shared memory, we consider two cases for + // performance reasons: + // + // 1. Non-swizzled shared memory. + // 2. Swizzled shared memory. + // + // Consider lowering `ttg.local_load %a`. In the first case, we can + // directly construct a linear layout using `%a`'s shape and shared memory + // encoding, irrespective of `%a`'s rank or whether it represents a slice of a + // larger tensor. + // + // The method does not apply for swizzled shared memory in some scenarios. + // Key properties of swizzling in Triton are: + // + // - Swizzling applies only to tensors with rank ≥ 2. + // - It is restricted to the last two dimensions of the tensor. + // - These last two dimensions are always treated as the most "minor." + // + // An important edge case arises when `%a` results from `%a = ttg.subview %b`, + // where `%b` is swizzled (and so is `%a`). In this case, constructing a + // layout and determining shared memory offsets using `%a`'s shape is + // incorrect. This is because swizzling depends on the original shape of `%b`, + // which differs from `%a`'s shape. As a result, some locations may fall + // outside `%a`'s contiguous view of memory. Specifically, an element `[i + // (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after swizzling, + // where `j'` lies outside `%a`'s shape but still within `%b`'s shape. + // + // We propose case 2 (see comments below), which provides a more general + // solution for all swizzled shared memory scenarios, including the edge case + // mentioned above. + if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1 + smemOffset = applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}})[0] + .second; + // This reverts #5645, because it introduced increased register pressure in + // AMD backend. + // TODO: remove when new implementation performance reaches target level + if (auto swizzledSharedEnc = + mlir::dyn_cast( + sharedEnc)) { + auto regToSharedLayout = + getRegToSharedLayout(ctx, shape, regLayout, swizzledSharedEnc, + elemLlvmTy.getIntOrFloatBitWidth()); + auto smemOrder = swizzledSharedEnc.getOrder(); + smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}})))); + // Reorder strides according to `order`. This way they match the + // multi-dimensional offsets in regToSharedLayout. + smemOffset = dot(rewriter, loc, smemOffsets, + applyPermutation(smemStrides, smemOrder)); + } + } else { // Case 2 -> rank-reduced swizzling + assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); + assert((isa(sharedEnc)) && + "NVMMA layout not supported for sliced tensors"); + // We define both tensor offsets and shared memory offsets: + // + // - Tensor offsets: Relative offsets within a given tensor. + // - Shared memory offsets: Absolute offsets within the shared memory. + // + // In Triton, the shared memory layout provides an invertible, one-to-one + // mapping between tensor offsets and shared memory offsets. The `base` + // field of any shared memory object represents both the shared memory + // offset and the tensor offset relative to the original tensor at + // allocation, prior to any subview operations. + // + // To determine the shared memory offsets for a specific register when + // dealing with swizzled and sliced tensors, the process involves: + // + // 1. Retrieving the original tensor's `invertAllocSharedLayout`, which + // maps the allocated tensor's offsets back to shared memory offsets. + // 2. Reconstructing the register's offsets in the allocated tensor by + // summing: + // - The shared memory offsets of the current view's base, and + // - The relative tensor offsets of the register. + // + // This approach ensures that "absolute" tensor offsets can be + // mapped to the correct shared memory addresses using + // `invertAllocSharedLayout`. + auto multiDimTensorOffsets = + llvm::to_vector(applyLinearLayout(loc, rewriter, regLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}})); + for (auto i = 0; i < rank; i++) { + multiDimTensorOffsets[i].second = + b.add(multiDimTensorOffsets[i].second, smemOffsets[i]); + } + smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout, + multiDimTensorOffsets)[0] + .second; + Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides); + smemOffset = b.sub(smemOffset, baseToAllocBaseDist); + } + auto ptrTy = smemBase.getType(); + auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset, + LLVM::GEPNoWrapFlags::inbounds); + return vecAddr; +} + +} // namespace + +bool emitTransferBetweenRegistersAndShared( + LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + std::function perVectorCallback) { + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + return emitTransferBetweenRegistersAndShared( + regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, + target, laneId, warpId, perVectorCallback); +} + +bool emitTransferBetweenRegistersAndShared( + LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Value laneId, Value warpId, + std::function perVectorCallback) { + MLIRContext *ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + + auto shape = sharedTy.getShape(); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, sharedTy.getEncoding()); + LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + // TODO(jlebar): We don't currently support loading from shared memory in a + // different CTA. We'd need to emit `mapa.shared::cluster` instructions. + for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock); + inBlock *= 2) { + auto idx = regToSharedLayout.apply( + {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}}); + // Intra-block offset must be 0 + int32_t offset = idx[0].second; + if (offset != 0) { + return false; + } + // Check if there's any cross CTA load. + int32_t outBlock = idx[1].second; + if (outBlock != inBlock) { + return false; + } + } + + // Determine how many consecutive registers map to consecutive shmem elements + // in out-dimension offsetN. This is our load instruction's vector width. + // + // It's OK if the vector width we choose here is wider than the hardware + // supports; LLVM will legalize it. + const int vecElems = + std::min(regToSharedLayout.getNumConsecutiveInOut(), + maxVecElems.value_or(std::numeric_limits::max())); + + auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); + + // For kernels with a single CTA, `allocSharedLayout.sublayout(S("block"), + // outDims) == 0`. We need to take out the "block" dimension in order to use + // `invert`. + // For kernels with multiple CTAs per CGA, + // `allocSharedLayout.sublayout(S("block"), outDims) != 0`. We do not need to + // take out the "block" dimension. + // Thus we use `pseudoinvert` instead of `invert` here for simplicity. + auto allocShape = sharedTy.getAllocShape(); + LinearLayout invertAllocSharedLayout = + triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()), + sharedTy.getEncoding()) + .pseudoinvert(); + + int numElems = regToSharedLayout.getInDimSize(kRegister); + auto vecTy = vec_ty(elemLlvmTy, vecElems); + SmallVector ret; + for (int i = 0; i < numElems / vecElems; i++) { + auto regId = b.i32_val(i * vecElems); + auto vecAddr = getSmemVecAddr( + regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj, + sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter); + perVectorCallback(vecTy, vecAddr); + } + return true; +} + +bool emitTransferBetweenRegistersAndShared( + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, + std::function perVectorCallback) { + auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), + registerTy.getEncoding()); + return emitTransferBetweenRegistersAndShared( + regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, + target, perVectorCallback); +} + +SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, + Type elemLlvmTy, + const SharedMemoryObject &smemObj, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target) { + auto srcTy = localLoadOp.getSrc().getType(); + auto dstTy = localLoadOp.getResult().getType(); + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector ret; + bool success = emitTransferBetweenRegistersAndShared( + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { + auto vecVal = b.load(vecTy, vecAddr); + target.localLoadOpAnnotation(localLoadOp, vecVal); + vecVal.setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + + for (int v = 0; v < vecTy.getNumElements(); v++) { + ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v))); + } + }); + if (!success) + llvm::report_fatal_error("Failed to emit transfer from shared to register"); + + return ret; +} + +void storeDistributedToShared(triton::gpu::MemDescType dstTy, + RankedTensorType srcTy, Type elemLlvmTy, + ArrayRef srcVals, + const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + std::pair *const llvmOpCount) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + bool success = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { + ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); + srcVals = srcVals.drop_front(vecTy.getNumElements()); + + Value vec = b.undef(vecTy); + for (int i = 0; i < vals.size(); i++) { + vec = b.insert_element(vec, vals[i], b.i32_val(i)); + } + b.store(vec, vecAddr) + .setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } + }); + + if (!success) + llvm::report_fatal_error("Failed to emit transfer from register to shared"); +} + +SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = b.extract_val(type, llvmStruct, i); + } + return results; +} + +Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, + ValueRange resultVals, RewriterBase &rewriter, Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + llvm::report_fatal_error( + "size mismatch when packing elements for LLVM struct"); + } + Value llvmStruct = rewriter.create(loc, structType); + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (auto [i, value] : llvm::enumerate(resultVals)) { + assert(value && "unexpected null value"); + if (value.getType() != elementTypes[i]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << value); + emitError(loc) << "invalid element type in packLLElements. Expected " + << elementTypes[i] << " but got " << value.getType(); + llvm::report_fatal_error( + "element type mismatch when packing elements for LLVM struct"); + } + llvmStruct = b.insert_val(structType, llvmStruct, value, i); + } + return llvmStruct; +} + +SmallVector unpackLLVector(Location loc, Value llvmVec, + RewriterBase &rewriter) { + assert(bool(llvmVec) && "cannot unpack null value"); + if (llvmVec.getType().isIntOrIndexOrFloat() || + isa(llvmVec.getType()) || + isa(llvmVec.getType())) + return {llvmVec}; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + SmallVector results; + for (int i = 0; i < cast(llvmVec.getType()).getNumElements(); + i++) { + results.push_back(b.extract_element(llvmVec, b.i32_val(i))); + } + return results; +} + +Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter) { + assert(vals.size() > 0); + auto vecType = vec_ty(vals[0].getType(), vals.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value vec = b.undef(vecType); + for (int i = 0; i < vals.size(); i++) { + vec = b.insert_element(vec, vals[i], b.i32_val(i)); + } + return vec; +} + +std::optional matchAtomicOp(RMWOp atomicOp) { + switch (atomicOp) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return LLVM::AtomicBinOp::max; + case RMWOp::MIN: + return LLVM::AtomicBinOp::min; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + return {}; + } +} + +std::optional getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return {}; + } +} + +bool isSimpleSharedMemoryAccess(ArrayRef shape, + ArrayRef allocShape, + triton::gpu::SharedEncodingTrait sharedEnc) { + auto rank = shape.size(); + auto swizzledLayout = + dyn_cast(sharedEnc); + auto nvmmaLayout = dyn_cast(sharedEnc); + bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) || + (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0); + return /*no swizzling*/ noSwizzling || + /*swizzling but same shape*/ shape == allocShape || + /*swizzling and rank-reduced and rank >= 2*/ + (shape == allocShape.take_back(rank) && rank >= 2); +} + +llvm::MapVector getAllFreeVarMasks(MLIRContext *ctx) { + // Mask where all elements are redundant + auto kReg = str_attr("reg"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + int32_t fullMask = -1; + llvm::MapVector ret; + for (auto dimName : {kReg, kLane, kWarp, kBlock}) { + ret[dimName] = fullMask; + } + return ret; +} + +llvm::MapVector getFreeVariableMasks(Type type) { + auto ctx = type.getContext(); + auto tensorTy = dyn_cast(type); + if (!tensorTy) { + return getAllFreeVarMasks(ctx); + } + auto ll = + triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding()); + return ll.getFreeVariableMasks(); +} + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type) { + MLIRContext *ctx = layout.getContext(); + auto shape = type.getShape(); + unsigned rank = shape.size(); + + auto ll = triton::gpu::toLinearLayout(shape, layout); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + SmallVector> offsets; + for (int i = 0; i < ll.getInDimSize(str_attr("register")); i++) { + auto idxs = ll.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + offsets.push_back( + llvm::to_vector_of(llvm::make_second_range(idxs))); + } + return offsets; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; + +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) { + auto i1ty = rewriter.getIntegerType(1); + return rewriter.create(loc, i1ty, + IntegerAttr::get(i1ty, v)); +} + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return rewriter.create(loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantBF16(Location loc, OpBuilder &rewriter, float v) { + APFloat apf(v); + bool ignored; + apf.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &ignored); + auto type = type::bf16Ty(rewriter.getContext()); + auto attr = FloatAttr::get(type, apf); + return rewriter.create(loc, type, attr); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return rewriter.create( + loc, type, APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args) { + auto op = builder.create(loc, funcOp, args); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args) { + auto op = builder.create(loc, types, args); + op.getProperties().setIntrin(builder.getStringAttr(intrinsic)); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType, + ArrayRef offsets) + : base(base), baseElemType(baseElemType), + offsets(offsets.begin(), offsets.end()) {} + +SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType, + int64_t rank, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + offsets.append(rank, b.i32_val(0)); +} + +SmallVector SharedMemoryObject::getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(offsets.begin(), offsets.end()); + return elems; +} + +SmallVector SharedMemoryObject::getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; +} + +SmallVector +SharedMemoryObject::getStrides(triton::gpu::MemDescType memDesc, Location loc, + RewriterBase &rewriter) const { + auto allocShape = memDesc.getAllocShape(); + auto allocShapePerCTA = + triton::gpu::getAllocationShapePerCTA(memDesc.getEncoding(), allocShape); + auto layoutOrder = triton::gpu::getOrder(memDesc); + auto allocStrides = SharedMemoryObject::getStridesForShape( + allocShapePerCTA, layoutOrder, loc, rewriter); + return SmallVector(allocStrides.end() - offsets.size(), + allocStrides.end()); +} + +Value SharedMemoryObject::getBaseBeforeSlice(int dim, Location loc, + RewriterBase &rewriter) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value cSwizzleOffset = getCSwizzleOffset(dim); + Value offset = b.sub(b.i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return b.gep(type, baseElemType, base, offset); +} + +SmallVector +SharedMemoryObject::getOrderForShape(ArrayRef shape, + ArrayRef layoutOrder) { + SmallVector order(shape.size()); + // Default minor-to-major order + std::iota(order.rbegin(), order.rend(), 0); + if (layoutOrder.size() > 0) { + // If a layout order is provided, we assume it specifies the order in + // which the dimensions are first accessed, and unspecified dimensions + // retain the minor-to-major order. For example, if order = [2, 1, 0] and + // layoutOrder = [0, 1], we need to shift `layoutOrder` + // by -1 (move them right). The resulting order will then be [1, 2, 0]. + int rankDiff = layoutOrder.size() - shape.size(); + auto minRank = std::min(shape.size(), layoutOrder.size()); + for (size_t i = 0; i < minRank; ++i) + order[i] = layoutOrder[i] - rankDiff; + } + assert(isPermutationOfIota(order) && "Invalid order"); + return order; +} + +SmallVector +SharedMemoryObject::getStridesForShape(ArrayRef shape, + ArrayRef layoutOrder, + Location loc, RewriterBase &rewriter) { + SmallVector strides(shape.size()); + auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder); + int64_t stride = 1; + auto b = TritonLLVMOpBuilder(loc, rewriter); + for (auto idx : order) { + strides[idx] = b.i32_val(stride); + stride *= shape[idx]; + } + return strides; +} + +Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = rewriter.create(loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = b.insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = b.extract_val(type, llvmStruct, i); + } + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*offsets=*/{elems.begin() + 1, elems.end()}}; +} + +Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 2); + } + + auto mod = funcOp->getParentOfType(); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); + assert(globalBase); + return rewriter.create(funcOp.getLoc(), globalBase); +} + +Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + const TargetInfoBase &targetInfo, + FunctionOpInterface funcOp, Value allocOffset = {}) { + // See NOTE: [Additional Function Arguments] + if (!isKernel(funcOp)) { + // Base for this function + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + if (!allocOffset) { + return gmemBase; + } + + auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1); + auto b = TritonLLVMOpBuilder(loc, rewriter); + return b.gep(ptrTy, i8_ty, gmemBase, allocOffset); + } + + // Base for entire kernel + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto allocSizeAttr = mod.getOperation()->getAttrOfType( + "ttg.global_scratch_memory_size"); + if (!allocSizeAttr) { + return gmemBase; + } + + Value gridIdx[3]; + Value gridDim[2]; + for (int k = 0; k < 3; ++k) { + gridIdx[k] = rewriter.create(loc, k); + } + for (int k = 0; k < 2; ++k) { + gridDim[k] = rewriter.create(loc, k); + } + + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value linearId = gridIdx[2]; + for (int k = 0; k < 2; ++k) { + linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k])); + } + auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + if (numCTAs > 1) { + linearId = b.mul(linearId, b.i32_val(numCTAs)); + linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc)); + } + + auto allocSize = allocSizeAttr.getValue().getZExtValue(); + + Value offset = b.mul(linearId, b.i32_val(allocSize)); + if (allocOffset) { + offset = b.add(offset, allocOffset); + } + + auto *ctx = rewriter.getContext(); + auto res = + b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset); + return res; +} + +Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + target.getSharedAddressSpace()); + auto func = op->template getParentOfType(); + if (!func) + func = cast(op); + + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value offVal = b.i32_val(offset); + Value base = + b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} + +// Extract the bits of `a` that are set in `mask` +Value pext_i32(RewriterBase &rewriter, Location loc, Value a, uint32_t mask) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(a.getType() == i32_ty && "a must be i32"); + // Handle width = 32 to avoid doing 1 << 32 + if (mask == 0xFFFFFFFF) + return a; + + // Implements the blocked algorithm from + // https://forums.developer.nvidia.com/t/pdep-and-pext-functionality-for-cuda/270973 + uint32_t mskConst = mask; + uint32_t extcnt = 0; + Value result = b.i32_val(0); + while (mskConst) { + uint32_t oldmsk = mskConst; + uint32_t bitgrplsb = mskConst & (-mskConst); + mskConst &= bitgrplsb + mskConst; + uint32_t bitgrp = mskConst ^ oldmsk; + uint32_t lsbpos = 31 - __builtin_clz(bitgrplsb); + // like popcount for a number 0..01..1..0 but portable + uint32_t grplen = __builtin_ctz(~(bitgrp >> lsbpos)); + uint32_t shift = lsbpos - extcnt; + extcnt += grplen; + result = + b.or_(result, b.lshr(b.and_(b.i32_val(bitgrp), a), b.i32_val(shift))); + } + return result; +} + +std::tuple, Value> +delinearize(RewriterBase &rewriter, Location loc, + triton::gpu::DistributedEncodingTrait layout, + ArrayRef shape, StringAttr dimName, Value linear) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ll = triton::gpu::toLinearLayout(shape, layout); + auto linearLayout = + triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll); + assert(ll.hasInDim(dimName)); + int32_t freeVarMask = ll.getFreeVariableMasks()[dimName]; + auto isRepresentative = b.true_val(); + if (freeVarMask != 0) { + isRepresentative = + b.icmp_eq(b.and_(b.i32_val(freeVarMask), linear), b.i32_val(0)); + // We remove the bits of linear that are set to one in freeVarMask + int32_t nonFreeVarMask = ~freeVarMask & (ll.getInDimSize(dimName) - 1); + linear = pext_i32(rewriter, loc, linear, nonFreeVarMask); + } + + auto orderDim = linearLayout.orderPerDim(dimName, linearLayout.getOrder()); + auto shapeDim = linearLayout.basesPerDim(dimName); + auto multiDim = delinearize(rewriter, loc, linear, shapeDim, orderDim); + + return std::make_tuple(std::move(multiDim), isRepresentative); +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = b.i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = b.i32_val(en.value()); + multiDim[en.index()] = b.urem(remained, dimSize); + remained = b.udiv(remained, dimSize); + } + return multiDim; +} + +SmallVector delinearize(unsigned linear, ArrayRef shape, + ArrayRef order) { + auto rank = shape.size(); + assert(order.size() == rank); + SmallVector multiDim(rank); + for (auto dim : order) { + multiDim[dim] = linear % shape[dim]; + linear /= shape[dim]; + } + assert(linear == 0); + return multiDim; +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto rank = multiDim.size(); + Value linear = b.i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.i32_val(dimShape); + linear = b.add(b.mul(linear, dimSize), dim); + } + } + return linear; +} + +size_t linearize(ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + size_t linear = 0; + for (unsigned dim : llvm::reverse(order)) + linear = linear * shape[dim] + multiDim[dim]; + return linear; +} + +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + LLVM::GlobalOp global; + { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = b.i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + b.gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + return stringStart; +} + +} // namespace LLVM + +Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value ret = b.i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = b.add(ret, b.mul(offset, stride)); + } + return ret; +} + +SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + assert(shape.size() == 2 || shape.size() == 3); + auto offsets = smemObj.getOffsets(); + auto rank = offsets.size(); + assert(rank == shape.size()); + if (rank == 3) + return smemObj; + offsets.insert(offsets.begin(), b.i32_val(0)); + auto expandedSmemObj = + SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), offsets); + return expandedSmemObj; +} + +// Isolated a single warp specialize op from above. +static void +makeWarpGroupsIsolatedFromAbove(triton::gpu::WarpSpecializeOp wsOp) { + SetVector captures; + getUsedValuesDefinedAbove(wsOp.getPartitionOpHolder(), captures); + for (Value capture : captures) { + wsOp->insertOperands(wsOp.getNumOperands(), capture); + for (Region *region : wsOp.getPartitionRegions()) { + BlockArgument arg = + region->addArgument(capture.getType(), capture.getLoc()); + replaceAllUsesInRegionWith(capture, arg, *region); + } + } +} + +void makeAllWarpGroupsIsolatedFromAbove(Operation *op) { + op->walk([](triton::gpu::WarpSpecializeOp wsOp) { + makeWarpGroupsIsolatedFromAbove(wsOp); + }); +} + +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..ec652fa71 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,522 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Tools/LayoutUtils.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = b.bitcast(constVal, intTy); + Value vec = b.undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); + constVal = vec; + } + auto llSrc = b.bitcast(constVal, srcType); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + // Lower FP8 constant to int8 constant since FP8 types are not supported on + // LLVM IR. + if (type::isFloat8(elemType)) + elemType = rewriter.getIntegerType(8); + auto constOp = rewriter.create(loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We can count how many contiguous + // registers belong to the same chunk then we merge the registers between + // two different chunks. + Location loc = op->getLoc(); + RankedTensorType dstTy = op.getType(); + auto ll = toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + int splitDim = dstTy.getRank() - 1; + auto kReg = mlir::StringAttr::get(dstTy.getContext(), "register"); + const auto &bases = ll.getBases(); + const auto ®s = bases.find(kReg)->second; + int numContiguousValues = 1; + bool found = false; + for (const auto ® : regs) { + if (reg[splitDim] == 1) { + found = true; + break; + } + numContiguousValues *= 2; + } + assert(found && "Join dimension is not distributed along registers."); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + joinedVals.resize(lhsVals.size() * 2); + for (int i = 0; i < lhsVals.size(); i += numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + joinedVals[2 * i + j] = lhsVals[i + j]; + joinedVals[2 * i + numContiguousValues + j] = rhsVals[i + j]; + } + } + auto typeConverter = getTypeConverter(); + Value ret = packLLElements(loc, typeConverter, joinedVals, rewriter, dstTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The layout distribute the last dimension along registers + // - The last dimension (the one we're splitting) has sizePerThread=2, + // threadPerWarp=1 and warpPerBlock=1. + // + // With these invariants, split is trivial: We can count how many contiguous + // registers belong to the same chunk then we separate the registers between + // two different chunks. + auto srcTy = cast(op.getSrc().getType()); + auto ll = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + int splitDim = srcTy.getRank() - 1; + auto kReg = mlir::StringAttr::get(srcTy.getContext(), "register"); + const auto &bases = ll.getBases(); + const auto ®s = bases.find(kReg)->second; + int numContiguousValues = 1; + bool found = false; + for (const auto ® : regs) { + if (reg[splitDim] == 1) { + found = true; + break; + } + numContiguousValues *= 2; + } + assert(found && "Split dimension is not distributed along registers."); + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2 * numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + outLhsVals.push_back(srcVals[i + j]); + outRhsVals.push_back(srcVals[i + numContiguousValues + j]); + } + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct MemDescTransOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.getBase(), srcSmemObj.getBaseElemType(), + /*offsets=*/applyPermutation(srcSmemObj.getOffsets(), op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescReshapeOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(MemDescReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector offsets = srcSmemObj.getOffsets(); + SmallVector srcShape; + for (int64_t d : op.getSrc().getType().getShape()) + srcShape.push_back(d); + SmallVector dstShape; + for (int64_t d : op.getType().getShape()) + dstShape.push_back(d); + Value linearOffset = LLVM::linearize(rewriter, loc, offsets, srcShape); + SmallVector delinearizedOffset = + LLVM::delinearize(rewriter, loc, linearOffset, dstShape); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.getBase(), srcSmemObj.getBaseElemType(), delinearizedOffset); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By construction, TransOp::inferReturnTypes ensures that the src encoding + // is the same as the dst encoding so that this op is a no-op. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } +}; + +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto *ctx = op->getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto srcTy = op.getSrc().getType(); + auto destTy = op.getResult().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + auto layoutOrder = getOrder(srcTy); + auto enc = srcTy.getEncoding(); + + // newBase = base + offset + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter); + SmallVector opOffsetVals = op.getOffsets(); + // We assume we always create a subview of the last dimensions + SmallVector opSmemStrides(smemStrides.end() - opOffsetVals.size(), + smemStrides.end()); + // Compute total offset + SmallVector offsetVals; + auto destRank = op.getResult().getType().getRank(); + auto rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i])); + } + + Value offset; + if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) { + // We are splitting the pipelining dimension which may not be a power of 2 + // so we can't use LinearLayouts + offset = dot(rewriter, loc, opOffsetVals, opSmemStrides); + } else { + auto dimNames = standardOutDimNames(ctx, opOffsetVals.size()); + SmallVector> logicalOffsets; + // This assumes the subviews are additive, in the sense that we can + // compute the offset of one and an add it to the offset of the previous + // one we computed. We check for this in the verifier. + for (int i = 0; i < rankReduced; i++) { + logicalOffsets.push_back({dimNames[i], b.i32_val(0)}); + } + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]}); + } + // The order gives us the honest-to-goodness layout rank + auto srcAllocShape = + srcTy.getAllocShape().take_back(getOrder(srcTy).size()); + auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert(); + offset = + applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second; + } + + auto base = smemObj.getBase(); + auto elemPtrTy = base.getType(); + smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset), + llvmElemTy, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; + +struct MemDescReinterpretOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite(MemDescReinterpretOp op, OpAdaptor adaptor, + ConversionPatternRewriter &b) const override { + Location loc = op.getLoc(); + MemDescType srcTy = op.getSrc().getType(); + MemDescType dstTy = op.getType(); + Type srcElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + Type dstElemTy = getTypeConverter()->convertType(dstTy.getElementType()); + + auto smemObj = + getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), srcElemTy, b); + SharedMemoryObject newObj(smemObj.getBase(), dstElemTy, dstTy.getRank(), + loc, b); + b.replaceOp(op, getStructFromSharedMemoryObject(loc, newObj, b)); + return success(); + } +}; +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add( + typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..bd9916a06 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(FlagTree_sunrise_TritonToTritonGPU + RelayoutTritonGPU.cpp + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + ProtonIR + TritonGPUIR +) diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp new file mode 100644 index 000000000..36d848e2d --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/RelayoutTritonGPU.cpp @@ -0,0 +1,130 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_RELAYOUTTRITONGPU +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +namespace ttng = triton::nvidia_gpu; + +// Given a tensor and its representation in tensor memory, determine its +// distributed layout. +RankedTensorType getTMEMTensorLayout(const TypeConverter *tc, + RankedTensorType type, MemDescType memdesc, + unsigned numWarps) { + Attribute encoding; + type = cast(tc->convertType(type)); + if (isa(memdesc.getEncoding())) { + encoding = LinearEncodingAttr::get( + type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps)); + } else { + auto tmemEnc = cast(memdesc.getEncoding()); + encoding = ttng::getTmemCompatibleLayout( + tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps); + } + return RankedTensorType::get(type.getShape(), type.getElementType(), + encoding); +} + +struct TMEMLoadOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType type = getTMEMTensorLayout( + typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op)); + rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); }); + Type resultType = getTypeConverter()->convertType(op.getType()); + rewriter.setInsertionPointAfter(op); + auto cvt = rewriter.create(op.getLoc(), resultType, + op.getResult()); + rewriter.replaceAllUsesExcept(op.getResult(), cvt, cvt); + return success(); + } +}; + +struct TMEMStoreOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType type = + getTMEMTensorLayout(typeConverter, op.getSrc().getType(), + op.getDst().getType(), lookupNumWarps(op)); + Value src = + rewriter.create(op.getLoc(), type, adaptor.getSrc()); + rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); }); + return success(); + } +}; + +struct TMEMAllocOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ttng::TMEMAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getSrc()) + return success(); + RankedTensorType type = getTMEMTensorLayout( + typeConverter, op.getSrc().getType(), op.getType(), lookupNumWarps(op)); + Value src = + rewriter.create(op.getLoc(), type, adaptor.getSrc()); + rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); }); + return success(); + } +}; + +class RelayoutTritonGPU + : public triton::impl::RelayoutTritonGPUBase { +public: + using RelayoutTritonGPUBase::RelayoutTritonGPUBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + int numWarps = lookupNumWarps(mod); + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs, /*enableSourceRemat=*/true); + TritonGPUConversionTarget target(*context, typeConverter); + target.addDynamicallyLegalDialect( + [&](Operation *op) { + return TritonGPUConversionTarget::isDynamicallyLegal(op, + typeConverter); + }); + + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + patterns.insert< + // clang-format off + GatherScatterOpPattern, + GatherScatterOpPattern, + TMEMLoadOpPattern, + TMEMStoreOpPattern, + TMEMAllocOpPattern + // clang-format on + >(typeConverter, context); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 000000000..3003579be --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,184 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs, + bool enableSourceRemat) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + if (enableSourceRemat) { + addSourceMaterialization([](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, tensorType, inputs) + .getResult(0); + }); + } + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return cast.getResult(); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); + addDynamicallyLegalOp([](triton::FuncOp funcOp) -> bool { + for (auto arg : funcOp.getArguments()) { + if (auto tensor = dyn_cast(arg.getType())) { + if (!tensor.getEncoding()) + return false; + } + } + return true; + }); +} + +bool TritonGPUConversionTarget::isDynamicallyLegal( + Operation *op, const TypeConverter &typeConverter) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; +} + +// This function returns the layout to use for gather/scatter indices. The +// `gather4` and `scatter4` TMA instructions require 4 consecutive indices. +// Thus, threads issuing these instructions must have all 4 index elements +// available. +static RankedTensorType getNewIndicesType(RankedTensorType type, + unsigned numThreads, + unsigned numWarps) { + assert(type.getRank() == 1); + auto enc = cast(type.getEncoding()); + + // Technically any layout where we have a pack of 4 neighbouring elements plus + // broadcasted over the warp dimension is okay but for now we just pick a + // layout. + std::array sizePerThread{1, 4}; + std::array threadsPerWarp = {numThreads, 1}; + std::array order = {1, 0}; + std::array warpsPerCta = {1, numWarps}; + + MLIRContext *ctx = type.getContext(); + auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/2); + auto parentEncoding = BlockedEncodingAttr::get( + ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, ctaLayout); + auto newEncoding = SliceEncodingAttr::get(ctx, /*dim=*/0, parentEncoding); + if (enc == newEncoding) + return {}; + + return RankedTensorType::get(type.getShape(), type.getElementType(), + newEncoding); +} + +// Function for converting any gather or scatter op that requires a specific +// index layout. This also handles converting result types if there are any. +static LogicalResult convertGatherScatterIndices(Operation *op, + OpOperand &indices, + ConversionPatternRewriter &b) { + auto type = cast(indices.get().getType()); + RankedTensorType newType = + getNewIndicesType(type, lookupThreadsPerWarp(b), lookupNumWarps(op)); + if (!newType) + return failure(); + Value index = b.create(op->getLoc(), newType, indices.get()); + indices.set(index); + return success(); +} + +LogicalResult impl::convertGatherScatterOp( + Operation *op, ValueRange operands, OpOperand &xOffsetsMutable, + const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + LogicalResult result = success(); + rewriter.modifyOpInPlace(op, [&] { + for (auto [operand, value] : llvm::zip(op->getOpOperands(), operands)) + operand.set(value); + for (OpResult result : op->getOpResults()) + result.setType(typeConverter.convertType(result.getType())); + result = convertGatherScatterIndices(op, xOffsetsMutable, rewriter); + }); + return result; +} diff --git a/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 000000000..def97ec70 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,821 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPU +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (isa(retShapedType)) { + assert(value && "expected a dense elements attribute"); + // This is a hack. We just want to add encoding. + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerThread = llvm::to_vector(argEncoding.getSizePerThread()); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = to_vector(argEncoding.getThreadsPerWarp()); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA()); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCTALayout); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = + RankedTensorType::get(origShape, origType.getElementType(), dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); + a = rewriter.create(a.getLoc(), dstType, a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); + b = rewriter.create(b.getLoc(), dstType, b); + } + c = rewriter.create(c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = llvm::to_vector(retEncoding.getSizePerThread()); + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout()); + auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), + newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CTALayoutAttr::get(getContext(), + append(defaultEnc.getCTAsPerCGA(), 1), + append(defaultEnc.getCTASplitNum(), 1), + prepend(defaultEnc.getCTAOrder(), rank - 1))); + srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + srcEnc); + src = rewriter.create(op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + TypeConverter::SignatureConversion result(op.getNumArguments()); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + // Convert just the entry block. The remaining unstructured control flow is + // converted by br patterns. + if (!newOp.getBody().empty()) + rewriter.applySignatureConversion(&newOp.getBody().front(), result, + converter); + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + // clang-format off + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonBroadcastPattern, + TritonCatPattern, + TritonJoinOpPattern, + TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonReducePattern, + GenericOpPattern, + TritonScanPattern, + GenericOpPattern, + GenericOpPattern, + TritonExpandDimsPattern, + TritonTransPattern, + TritonDotPattern, + GatherScatterOpPattern, + GatherScatterOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + // this assumes the right layout will be set later for dot scaled. + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + TritonFuncOpPattern + // clang-format on + >(typeConverter, context); +} +// Proton patterns +// NOTE: Because Proton's inputs are scalars and not tensors this conversion +// isn't strictly necessary however you could envision a case where we pass in +// tensors in for Triton object specific tracing operations in which case we +// would need to fill in the OpConversionPattern +void populateProtonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add>(typeConverter, + context); +} +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} + +class ConvertTritonToTritonGPU + : public triton::impl::ConvertTritonToTritonGPUBase< + ConvertTritonToTritonGPU> { +public: + using ConvertTritonToTritonGPUBase::ConvertTritonToTritonGPUBase; + + void runOnOperation() override { + if (target.getValue().empty()) { + mlir::emitError( + getOperation().getLoc(), + "'convert-triton-to-tritongpu' requires 'target' option to be set"); + return signalPassFailure(); + } + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs, enableSourceRemat); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + populateProtonPatterns(typeConverter, patterns); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + patterns.insert>(typeConverter, context); + + Builder b(&getContext()); + mod->setAttr(AttrNumWarpsName, b.getI32IntegerAttr(numWarps)); + mod->setAttr(AttrNumThreadsPerWarp, b.getI32IntegerAttr(threadsPerWarp)); + mod->setAttr(AttrNumCTAsName, b.getI32IntegerAttr(numCTAs)); + mod->setAttr(AttrTargetName, b.getStringAttr(this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Dialect/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..b2375fd85 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +#add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..78a9f57fe --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,22 @@ +set(LLVM_TARGET_DEFINITIONS Canonicalize.td) +mlir_tablegen(TritonCanonicalize.inc -gen-rewriters) +add_public_tablegen_target(TritonCanonicalizeIncGen) + +add_triton_library(FlagTree_sunrise_TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + OpInterfaces.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonCanonicalizeIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Canonicalize.td b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Canonicalize.td new file mode 100644 index 000000000..dc3771033 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Canonicalize.td @@ -0,0 +1,17 @@ +#ifndef TT_PATTERNS +#define TT_PATTERNS + +include "mlir/IR/PatternBase.td" +include "triton/Dialect/Triton/IR/TritonOps.td" + +// broadcast(splat(x)) -> splat(x) +def BroadcastSplatPattern : + Pat<(TT_BroadcastOp (TT_SplatOp $x)), + (TT_SplatOp $x)>; + +// broadcast(broadcast(x)) -> broadcast(x) +def BroadcastBroadcastPattern : + Pat<(TT_BroadcastOp (TT_BroadcastOp $x)), + (TT_BroadcastOp $x)>; + +#endif diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..31786c184 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,79 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +bool TritonInlinerInterface::isLegalToInline(Operation *call, + Operation *callable, + bool wouldBeCloned) const { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void TritonInlinerInterface::handleTerminator(Operation *op, + Block *newDest) const { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); +} + +/// Handle the given inlined terminator by replacing it with a new operation +/// as necessary. +void TritonInlinerInterface::handleTerminator(Operation *op, + ValueRange valuesToRepl) const { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); +} + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} + + diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/OpInterfaces.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/OpInterfaces.cpp new file mode 100644 index 000000000..cd8f79066 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/OpInterfaces.cpp @@ -0,0 +1,77 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace triton { +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op) { + TransposeOpInterface transposeOp = cast(op); + auto rank = cast(transposeOp.getSrc().getType()).getRank(); + auto order = transposeOp.getOrder(); + if (rank != order.size()) { + return op->emitError( + "order must have the same size as the rank of the operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return op->emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +// A DotOpInterface operation should have at least three operands. +// The first two operands should share a common dimension, and the result +// should have the dimensions of the two operands that are not shared. +// A DotOpInterface operation can be either 2d or 3d. +// In the 3d case, the first dimension of operands is the batch dimension. +LogicalResult verifyDotOpInterface(Operation *op) { + DotOpInterface dotOp = cast(op); + + if (dotOp->getNumOperands() < 3) + return dotOp->emitOpError("expected at least 3 operands"); + auto aTy = cast(dotOp->getOperand(0).getType()); + auto bTy = cast(dotOp->getOperand(1).getType()); + auto cTy = cast(dotOp->getOperand(2).getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + // Check if all 3d or all 2d + if (aShape.size() != 2 && aShape.size() != 3) + return dotOp->emitOpError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) + return dotOp->emitOpError("expected all operands to have the same rank"); + + // Check for valid A, B input shapes for dot + if (!dotOp.verifyDims()) + return dotOp->emitOpError( + "expected the last dimension of the first operand " + "to be equal to the second-to-last dimension of " + "the second operand"); + + // Check the batch dimension + if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0])) + return dotOp->emitOpError("expected the first dimension of the first " + "operand to be equal to the first dimension of " + "the result"); + // Check the output shape + if (!dotOp.verifyOutputDims()) + return dotOp->emitOpError( + "expected the output shape to be the concatenation of the last " + "dimension of the first operand and the last dimension of the " + "second "); + return success(); +} + +} // namespace impl +} // namespace triton +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Ops.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..060a4e609 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1368 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), + GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +#include "TritonCanonicalize.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = mask.getDefiningOp(); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + // If the source and result types are the same, we can return the source + // If their layout is different (even if structurally equivalent), we need + // to insert a convert_layout in between as otherwise ::fold complains + // We do this in CanonicalizeConvertFromTranspose + if (getSrc().getType() == getType()) { + return getSrc(); + } + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + // Eliminate splat constant transpose ops. + if (auto attr = + llvm::dyn_cast_if_present(adaptor.getSrc())) + return attr.reshape(getType()); + + return {}; +} + +LogicalResult TransOp::verify() { + auto order = getOrder(); + auto srcTy = cast(getSrc().getType()); + if (order.size() != srcTy.getShape().size()) { + return emitError("order must have the same size as the source tensor"); + } + if (!isPermutationOfIota(order)) { + return emitError("order must be a permutation of 0..n-1"); + } + SmallVector retShape = applyPermutation(srcTy.getShape(), order); + if (retShape != getType().getShape()) { + return emitError( + "result shape must match the permutation of the source shape"); + } + return success(); +} + +LogicalResult +TransOp::inferReturnTypes(MLIRContext *context, std::optional loc, + TransOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + + // type is the same as the input + auto argTy = cast(adaptor.getSrc().getType()); + auto shape = argTy.getShape(); + auto order = adaptor.getOrder(); + SmallVector retShape = applyPermutation(shape, order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferTransOpEncoding( + argEncoding, shape, order, retEncoding, loc))) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +bool DotOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + return aShape[aShape.size() - 1] == bShape[aShape.size() - 2]; +} + +//-- DotScaledOp -- +bool DotScaledOp::verifyDims() { + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + + auto aKdim = aShape[aShape.size() - 1]; + auto bKdim = bShape[aShape.size() - 2]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (this->getLhsKPack()) + aKdim *= 2; + } + if (this->getBElemType() == ScaleDotElemType::E2M1) { + if (this->getRhsKPack()) + bKdim *= 2; + } + + return aKdim == bKdim; +} + +bool DotScaledOp::verifyOutputDims() { + auto cShape = this->getC().getType().getShape(); + auto oMdim = cShape[cShape.size() - 2]; + auto oNdim = cShape[cShape.size() - 1]; + auto aShape = this->getA().getType().getShape(); + auto bShape = this->getB().getType().getShape(); + auto adim = aShape[aShape.size() - 2]; + auto bdim = bShape[bShape.size() - 1]; + if (this->getAElemType() == ScaleDotElemType::E2M1) { + if (!this->getLhsKPack()) + adim *= 2; + } + if (this->getBElemType() == ScaleDotElemType::E2M1) { + if (!this->getRhsKPack()) + bdim *= 2; + } + if (adim != oMdim || bdim != oNdim) + return false; + return true; +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start >= end) { + return this->emitOpError() << "start must be less than end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(std::optional loc, RankedTensorType argTy, + Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferReduceOpEncoding( + argEncoding, axis, retEncoding, loc))) { + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult +ReduceOp::inferReturnTypes(MLIRContext *context, std::optional loc, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (failed(inferReduceReturnShape(loc, argTy, retEltTy, axis, + inferredReturnTypes))) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +::mlir::Operation *ReduceOp::getSingleCombiner() { + if (getNumOperands() != 1 || getNumResults() != 1) + return nullptr; + Block *block = &(*getCombineOp().begin()); + Operation *yield = block->getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return nullptr; + if (reduceOp->getOperand(0) != block->getArgument(0) || + reduceOp->getOperand(1) != block->getArgument(1)) + return nullptr; + + return reduceOp; +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ScanOp::build(builder, state, inferredReturnTypes, operands, axis, reverse); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + if (!isa(value)) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferExpandDimsOpEncoding( + argEncoding, axis, retEncoding, loc))) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + Dialect &dialect = srcEnc.getDialect(); + auto inferLayoutInterface = cast(&dialect); + if (failed(inferLayoutInterface->inferExpandDimsOpEncoding( + srcEnc, op.getAxis(), newExpandEnc, op.getLoc()))) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- + +void ReshapeOp::build(OpBuilder &builder, OperationState &state, + ArrayRef shape, Value src, bool allowReorder) { + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + Attribute dstEnc; + if (srcEnc) { + auto result = cast(&srcEnc.getDialect()) + ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape, + dstEnc, state.location); + assert(succeeded(result)); + } + auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc); + build(builder, state, dstTy, src, allowReorder); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (op.getEfficientLayout()) + return failure(); + + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // reshape(reshape) -> reshape + if (auto parentReshape = dyn_cast(definingOp)) { + // Allow reorder if either reshape allowed it + const bool allowReorder = + (op.getAllowReorder() || parentReshape.getAllowReorder()); + rewriter.replaceOpWithNewOp(op, op.getType(), + parentReshape.getSrc(), allowReorder, + op.getEfficientLayout()); + return success(); + } + + // reshape(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType() && !getAllowReorder()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (!srcEnc || getAllowReorder()) { + return success(); + } + + // Check that we can infer the dst encoding from the src encoding + // and that the inferred dst encoding is the same as the given dst encoding + Attribute inferredDstEnc; + auto layoutInterface = + cast(&srcEnc.getDialect()); + auto result = layoutInterface->inferReshapeOpEncoding( + srcTy.getShape(), srcEnc, dstTy.getShape(), inferredDstEnc, getLoc()); + if (failed(result)) + return failure(); + return layoutInterface->verifyLayoutsAreEqual( + dstTy.getShape(), inferredDstEnc, dstEnc, getLoc()); +} + +//-- FpToFpOp -- + +// Fold FpToFpOp when the input operand is a constant zero. +OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) { + auto srcVal = getSrc(); + auto dstTy = getType(); + // Fold trivial cast + if (srcVal.getType() == dstTy) { + return srcVal; + } + + auto resElemType = cast(getElementTypeOrSelf(getType())); + const llvm::fltSemantics &semantic = resElemType.getFloatSemantics(); + + if (matchPattern(srcVal, m_PosZeroFloat())) { + llvm::APFloat posZero = + llvm::APFloat::getZero(semantic, /*negative=*/false); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, posZero); + return Builder(getContext()).getFloatAttr(resElemType, posZero); + } + + if (matchPattern(srcVal, m_NegZeroFloat())) { + llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true); + if (auto tensorTy = dyn_cast(dstTy)) + return DenseElementsAttr::get(tensorTy, negZero); + return Builder(getContext()).getFloatAttr(resElemType, negZero); + } + + return {}; +} + +LogicalResult FpToFpOp::verify() { + auto dstType = getType(); + auto srcType = getSrc().getType(); + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BitcastOp -- +LogicalResult BitcastOp::verify() { + // Bitcast only allows conversion between types with the same bit width. + Type dstType = getType(); + Type srcType = getSrc().getType(); + // Strip tensor shapes; SameOperandsAndResultShape guarantees shapes match. + if (auto dstTensorType = dyn_cast(dstType)) + dstType = dstTensorType.getElementType(); + if (auto srcTensorType = dyn_cast(srcType)) + srcType = srcTensorType.getElementType(); + bool dstIsPtr = isa(dstType); + bool srcIsPtr = isa(srcType); + if (dstIsPtr || srcIsPtr) { + // Bitcast supports pointer-to-pointer conversions but not + // pointer-to-scalar. + if (dstIsPtr && srcIsPtr) { + if (triton::getAddressSpace(dstType) != triton::getAddressSpace(srcType)) + return emitError( + "Cannot bitcast pointer between different address spaces"); + return success(); + } + return emitError("Cannot bitcast pointer to non-pointer type"); + } + unsigned dstBits = dstType.getIntOrFloatBitWidth(); + unsigned srcBits = srcType.getIntOrFloatBitWidth(); + if (dstBits != srcBits) { + return emitError("Cannot bitcast data-type of size ") + << srcBits << " to data-type of size " << dstBits; + } + return success(); +} + +//-- BroadcastOp -- +void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (int i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. " + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, pointerType.getAddressSpace()); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +//-- AddPtrOp -- +OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { + // addptr(ptr, 0) -> ptr + if (matchPattern(adaptor.getOffset(), m_Zero())) { + return getPtr(); + } + return {}; +} + +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape, + bool isSignedInteger) { + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = + TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); + return build(builder, state, descTy, base, shape, strides); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + call_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- + +void JoinOp::build(OpBuilder &builder, OperationState &state, Value lhs, + Value rhs) { + auto lhsTy = cast(lhs.getType()); + SmallVector retShape(lhsTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = lhsTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (failed(cast(&srcEnc.getDialect()) + ->inferDefaultJoinOpEncoding( + srcEnc, retEnc, lhsTy.getShape(), state.location))) { + llvm_unreachable("failed to infer join encoding"); + } + } + auto retTy = RankedTensorType::get(retShape, lhsTy.getElementType(), retEnc); + JoinOp::build(builder, state, retTy, lhs, rhs); +} + +LogicalResult JoinOp::verify() { + RankedTensorType srcTy = getLhs().getType(); + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + RankedTensorType retTy = getType(); + if (SmallVector(retTy.getShape()) != retShape) { + return emitOpError("result shape must be (") + << retShape << "), but got " << retTy.getShape(); + } + if (retTy.getElementType() != srcTy.getElementType()) { + return emitOpError("result element type must match the input element type"); + } + Attribute retEnc = retTy.getEncoding(); + if (!retEnc) { + if (srcTy.getEncoding()) { + return emitOpError("result encoding must be specified"); + } + return success(); + } + // There are multiple correct destination layout for a given source layout but + // there is only one correct source layout for a given destination layout. So + // we verify that the source layout match the destination layout. + Attribute srcEnc; + Location location = getLoc(); + if (cast(&retEnc.getDialect()) + ->inferSplitOpEncoding(retEnc, srcEnc, retShape, location) + .failed()) { + return failure(); + } + + if (cast(&srcEnc.getDialect()) + ->verifyLayoutsAreEqual(srcTy.getShape(), srcEnc, srcTy.getEncoding(), + {}) + .failed()) { + return emitOpError("incompatible join layout"); + } + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, + SplitOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + auto srcTy = cast(adaptor.getSrc().getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); +} + +Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); +} + +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. + inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + +// -- DescriptorGatherOp +LogicalResult +DescriptorGatherOp::verifyResultType(Operation *op, ShapedType resultType, + RankedTensorType indicesType) { + if (indicesType.getRank() != 1) + return op->emitOpError("x offsets must be a 1D tensor, but got ") + << indicesType; + if (resultType.getRank() != 2) + return op->emitOpError("result must be a 2D tensor, but got ") + << resultType; + + // The swizzling of TMA accesses matches that of the MMAv3 shared memory + // layouts. However, these have minimum size requirements. + // TODO: We can support smaller gather sizes by padding the `local_alloc` this + // lowers to to the nearest minimum tile size. + if (unsigned rows = resultType.getShape()[0]; rows < 8) { + return op->emitOpError("gather must have at least 8 rows, but got ") + << rows; + } + + Type dtype = resultType.getElementType(); + if (dtype.getIntOrFloatBitWidth() > 32) + return op->emitOpError("TMA dtype cannot be greater than 32 bits"); + + unsigned minCols = 32 / dtype.getIntOrFloatBitWidth() * 8; + if (unsigned cols = resultType.getShape()[1]; cols < minCols) { + return op->emitOpError("gather of ") + << dtype << " must have at least " << minCols << " columns, but got " + << cols; + } + + if (resultType.getShape()[0] != indicesType.getShape()[0]) { + return op->emitOpError("result tensor must have as many rows as indices (") + << indicesType.getShape()[0] << "), but got " << resultType; + } + + return success(); +} + +static LogicalResult verifyGatherScatterOp(Operation *op, + RankedTensorType blockType, + RankedTensorType resultType, + RankedTensorType indicesType) { + // Gather from `!tt.tensordesc>`. + if (blockType.getRank() != 2) { + return op->emitOpError("block must be a 2D tensor, but got ") << blockType; + } + if (blockType.getShape()[0] != 1) { + return op->emitOpError("block must have exactly 1 row, but got ") + << blockType; + } + + // With x offsets `tensor` into `tensor`. + if (failed(DescriptorGatherOp::verifyResultType(op, resultType, indicesType))) + return failure(); + + if (resultType.getShape()[1] != blockType.getShape()[1]) { + return op->emitOpError("result tensor number of columns must match block (") + << blockType.getShape()[1] << "), but got " << resultType; + } + if (resultType.getElementType() != blockType.getElementType()) { + return op->emitOpError("result tensor element type must match block (") + << blockType.getElementType() << "), but got " << resultType; + } + + return success(); +} + +LogicalResult DescriptorGatherOp::verify() { + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + getResult().getType(), getXOffsets().getType()); +} + +// -- DescriptorScatterOp -- +LogicalResult DescriptorScatterOp::verify() { + return verifyGatherScatterOp(*this, + getDesc().getType().getSignlessBlockType(), + getSrc().getType(), getXOffsets().getType()); +} + +// -- DescriptorLoadOp -- +static LogicalResult verifyDescriptorLoadStoreType(Operation *op, + TensorDescType desc, + RankedTensorType tensor) { + RankedTensorType block = desc.getSignlessBlockType(); + ArrayRef blockShape = block.getShape(); + ArrayRef tensorShape = tensor.getShape(); + if (blockShape.size() > tensorShape.size()) { + // Allow ranked reduced load if the leading dimensions are all 1s. + for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) { + if (blockShape[i] != 1) + return op->emitOpError( + "ranked reduce load only allowed for unit dimension leading dim."); + } + blockShape = blockShape.take_back(tensorShape.size()); + } + + if (blockShape == tensorShape && + block.getElementType() == tensor.getElementType()) + return success(); + return op->emitOpError("tensor descriptor block and tensor types must match"); +} + +LogicalResult DescriptorLoadOp::verify() { + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), getType()); +} + +// -- DescriptorStoreOp -- +LogicalResult DescriptorStoreOp::verify() { + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), + getSrc().getType()); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Traits.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..6c45e5a8d --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,217 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; + +LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { + auto tensorTypeA = dyn_cast(typeA); + auto tensorTypeB = dyn_cast(typeB); + if (!(bool(tensorTypeA) && bool(tensorTypeB))) + return typeA == typeB ? success() : failure(); + auto encodingA = tensorTypeA.getEncoding(); + auto encodingB = tensorTypeB.getEncoding(); + auto shapeA = tensorTypeA.getShape(); + auto shapeB = tensorTypeB.getShape(); + if (shapeA != shapeB) + return failure(); + if (tensorTypeA.getElementType() != tensorTypeB.getElementType()) + return failure(); + // If there's no encoding or the encodings are the same + if (encodingA == encodingB) + return success(); + + return cast(&encodingA.getDialect()) + ->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {}); +} + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + Dialect &dialect = layout.getDialect(); + auto verifyLayoutInterface = + dyn_cast(&dialect); + if (verifyLayoutInterface) { + return verifyLayoutInterface->verifyTensorLayout(layout, rankedTy, op, + makeErr); + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Types.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..de8925cbf --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,142 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Utility.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Utility.cpp new file mode 100644 index 000000000..af9f798ec --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/IR/Utility.cpp @@ -0,0 +1,119 @@ +#include "triton/Dialect/Triton/IR/Utility.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; + +Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask, + Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +static tt::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return tt::getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return tt::getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +tt::MakeTensorPtrOp tt::getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return tt::getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return tt::getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return tt::getMakeTensorPtrOp( + tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return tt::getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +Value tt::getLastInductionValue(OpBuilder &b, scf::ForOp loop) { + Location loc = loop.getLoc(); + // (ub - lb -1) // step * step + lb + Value diff = + b.create(loc, loop.getUpperBound(), loop.getLowerBound()); + diff = b.create( + loc, diff, b.create(loc, b.getI32IntegerAttr(1))); + Value ceilStep = b.create( + loc, b.create(loc, diff, loop.getStep()), loop.getStep()); + return b.create(loc, ceilStep, loop.getLowerBound()); +} + +bool tt::isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +bool tt::isHostSideDescriptor(Value v) { + auto arg = dyn_cast(v); + if (!arg) + return false; + auto funcOp = dyn_cast(arg.getOwner()->getParentOp()); + if (!funcOp) + return false; + return tt::isKernel(funcOp); +} diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp new file mode 100644 index 000000000..333c205b6 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp @@ -0,0 +1,50 @@ +#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace { + +struct RewriteArithSelectOp : mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Note we're replacing the select op with an if op because we are + // converting one value into many values. + auto newIf = rewriter.create( + op.getLoc(), mlir::TypeRange(adaptor.getTrueValue()), op.getCondition(), + true); + // We set the attributes from the op in case the op has any additional + // attributes + newIf->setAttrs(op->getAttrs()); + + { + mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(newIf.thenBlock()); + rewriter.create(op->getLoc(), adaptor.getTrueValue()); + rewriter.setInsertionPointToStart(newIf.elseBlock()); + rewriter.create(op->getLoc(), + adaptor.getFalseValue()); + } + + // Replace the old operation results + rewriter.replaceOpWithMultiple(op, {newIf->getResults()}); + + return mlir::success(); + } +}; + +} // namespace +namespace mlir::triton { + +void populateArithTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter, patterns.getContext()); +} + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..115189da8 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,27 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(FlagTree_sunrise_TritonTransforms + Combine.cpp + LoopAwareCSE.cpp + LoopInvariantCodeMotion.cpp + LoopPeeling.cpp + LoopUnroll.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + RewriteTensorDescriptorToPointer.cpp + ArithTypeConversion.cpp + FunctionTypeConversion.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + MLIRTransforms + MLIRSCFToControlFlow + TritonIR +) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..6fab87c8a --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,268 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONCOMBINEOPS +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +bool isZero(Value val) { + return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())); +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto loadOp = trueValue.getDefiningOp(); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto splatOp = mask.getDefiningOp(); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + + static SmallVector getEqualIndices(ArrayRef x, + ArrayRef y) { + SmallVector res; + for (int i = 0; i < x.size(); ++i) + if (x[i] == y[i]) + res.push_back(i); + return res; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = reduceOp.getOperand(0).getDefiningOp(); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp(); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp(); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp(); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp(); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = rewriter.create( + op->getLoc(), newAccType, + rewriter.create(op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +// When reducing a 1D tensor the order of elements of the tensor doesn't matter. +// Therefore we can relax the reshape to allow it to re-order elements. +class CombineReshapeReducePatterns : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + if (reshapeOp.getAllowReorder()) + return failure(); + if (reshapeOp.getType().getRank() != 1) + return failure(); + for (Operation *user : reshapeOp->getUsers()) { + if (!isa(user)) + return failure(); + } + rewriter.modifyOpInPlace(reshapeOp, + [&]() { reshapeOp.setAllowReorder(true); }); + return success(); + } +}; + +class RankedReduceDescriptorLoads : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp reshapeOp, + mlir::PatternRewriter &rewriter) const override { + auto loadDef = reshapeOp.getSrc().getDefiningOp(); + if (!loadDef || !loadDef->hasOneUse()) + return failure(); + int loadRank = loadDef.getType().getRank(); + int reshapeRank = reshapeOp.getType().getRank(); + if (!(reshapeRank < loadRank)) + return failure(); + ArrayRef loadShape = loadDef.getType().getShape(); + ArrayRef reshapeShape = reshapeOp.getType().getShape(); + for (int i = 0; i < loadRank - reshapeRank; ++i) { + // Only rank reduce unit dims. + if (loadShape[i] != 1) + return failure(); + } + if (loadShape.take_back(reshapeRank) != reshapeShape) + return failure(); + rewriter.modifyOpInPlace( + loadDef, [&]() { loadDef.getResult().setType(reshapeOp.getType()); }); + rewriter.replaceOp(reshapeOp, loadDef.getResult()); + return success(); + } +}; + +} // anonymous namespace + +class CombineOpsPass : public impl::TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.td b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000..e3588f587 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,47 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +defvar DefOverflow = ConstantEnumCase; +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), + [(Constraint> $idx0, $idx1)]>; + +#endif diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp new file mode 100644 index 000000000..0170463ce --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp @@ -0,0 +1,86 @@ +#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h" + +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir::triton { + +namespace { + +SmallVector flattenValues(ArrayRef values) { + SmallVector ret; + for (const auto &vs : values) { + llvm::append_range(ret, vs); + } + return ret; +} + +struct CallOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector resultReplacementGrouping; + llvm::SmallVector convertedResults; + + for (auto type : callOp->getResultTypes()) { + const auto oldNumFlattenedResults = convertedResults.size(); + if (failed(getTypeConverter()->convertTypes(type, convertedResults))) { + return failure(); + } + resultReplacementGrouping.push_back(convertedResults.size() - + oldNumFlattenedResults); + } + + auto newCallOp = rewriter.create( + callOp->getLoc(), callOp.getCallee(), convertedResults, + flattenValues(adaptor.getOperands())); + // Preserve any additional attributes that may have been set on the op + newCallOp->setAttrs(callOp->getAttrs()); + + SmallVector replacements; + std::size_t offset = 0; + for (auto groupSize : resultReplacementGrouping) { + replacements.push_back(newCallOp->getResults().slice(offset, groupSize)); + offset += groupSize; + } + + rewriter.replaceOpWithMultiple(callOp, replacements); + return success(); + } +}; + +struct ReturnOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp returnOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReturnOp = rewriter.create( + returnOp->getLoc(), flattenValues(adaptor.getOperands())); + // Preserve any additional attributes that may have been set on the op + newReturnOp->setAttrs(returnOp->getAttrs()); + + rewriter.replaceOp(returnOp, newReturnOp); + return success(); + } +}; + +} // namespace + +void populateFunctionTypeConversions(const TypeConverter &converter, + RewritePatternSet &patterns) { + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, converter); + patterns.add(converter, + patterns.getContext()); +} + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp new file mode 100644 index 000000000..6c56e478a --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp @@ -0,0 +1,176 @@ +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/EquivalenceClasses.h" + +using namespace mlir; + +namespace mlir::triton { +#define GEN_PASS_DEF_TRITONLOOPAWARECSE +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" +} // namespace mlir::triton + +namespace { +class ValueEquivalence { +public: + std::optional getKnownEquivalence(Value a, Value b) { + if (auto it = equalValues.find(normalizeKey(a, b)); it != equalValues.end()) + return it->second; + return std::nullopt; + } + void setKnownEquivalence(Value a, Value b, bool eq) { + equalValues.insert_or_assign(normalizeKey(a, b), eq); + } + +private: + // Commutatively query the equivalence of two values by sorting the key by + // pointer value. + std::pair normalizeKey(Value a, Value b) { + if ((uintptr_t)a.getAsOpaquePointer() < (uintptr_t)b.getAsOpaquePointer()) + return {a, b}; + return {b, a}; + } + + DenseMap, bool> equalValues; +}; + +struct LoopCSEDriver { + LoopCSEDriver(scf::ForOp loop) : loop(loop) {} + + bool areIterArgsEqual(int i, int j); + bool areEqualInLoop(Value a, Value b); + + scf::ForOp loop; + SmallVector> argStack; +}; +} // namespace + +bool LoopCSEDriver::areIterArgsEqual(int i, int j) { + if (i == j) + return true; + if (loop.getInitArgs()[i] != loop.getInitArgs()[j]) + return false; + if (llvm::is_contained(argStack, std::make_pair(i, j))) + return true; + BlockArgument aArg = loop.getRegionIterArg(i); + BlockArgument bArg = loop.getRegionIterArg(j); + // First, assume the arguments are equal. This is how recursion is broken. + argStack.push_back({i, j}); + bool result = + areEqualInLoop(loop.getYieldedValues()[i], loop.getYieldedValues()[j]); + argStack.pop_back(); + return result; +} + +bool LoopCSEDriver::areEqualInLoop(Value a, Value b) { + // Check trivial case. + if (a == b) + return true; + if (a.getType() != b.getType()) + return false; + + Block *aBlock = a.getParentBlock(); + Block *bBlock = b.getParentBlock(); + // Values from outside the loop must have been equal. + if (aBlock != loop.getBody() || bBlock != loop.getBody()) { + return false; + } + // Both must be block arguments or not. + if (isa(a) != isa(b)) + return false; + // Both must be the inductor var or not. + if (a == loop.getInductionVar() || b == loop.getInductionVar()) + return false; + + if (auto aArg = dyn_cast(a)) { + auto bArg = cast(b); + bool result = + areIterArgsEqual(aArg.getArgNumber() - 1, bArg.getArgNumber() - 1); + return result; + } + + Operation *aDef = a.getDefiningOp(); + Operation *bDef = b.getDefiningOp(); + // For it to be known that the operation results have the same value, they + // must be side effect free. + if (!isMemoryEffectFree(aDef) || !isMemoryEffectFree(bDef)) + return false; + // Don't bother with operations with regions. + if (aDef->getNumRegions() || bDef->getNumRegions()) + return false; + + bool result = OperationEquivalence::isEquivalentTo( + aDef, bDef, + [&](Value a, Value b) { return success(areEqualInLoop(a, b)); }, + /*markEquivalent=*/nullptr, OperationEquivalence::IgnoreLocations); + return result; +} + +static void loopCSE(scf::ForOp loop) { + int numIterArgs = loop.getNumRegionIterArgs(); + // Group equivalent iter args together. + llvm::EquivalenceClasses equivalentArgs; + LoopCSEDriver driver(loop); + for (int i = 0; i != numIterArgs; ++i) { + for (int j = i + 1; j != numIterArgs; ++j) { + if (driver.areIterArgsEqual(i, j)) + equivalentArgs.unionSets(i, j); + } + } + + // For each equivalence class, replace all other args in the class with one. + for (auto it = equivalentArgs.begin(), end = equivalentArgs.end(); it != end; + ++it) { + if (!(*it)->isLeader()) + continue; + SmallVector eqArgs; + for (auto mIt = equivalentArgs.member_begin(**it); + mIt != equivalentArgs.member_end(); ++mIt) + eqArgs.push_back(*mIt); + assert(eqArgs.size() > 1); + // Sort the indices so the pass is deterministic. + llvm::sort(eqArgs); + BlockArgument unique = loop.getRegionIterArg(eqArgs.front()); + Value uniqueResult = loop.getResult(eqArgs.front()); + for (int j : llvm::drop_begin(eqArgs)) { + BlockArgument other = loop.getRegionIterArg(j); + other.replaceAllUsesWith(unique); + // Short-circuit the value. The canonicalizer will clean this up. Leftover + // subcomputations can now be removed by normal CSE. + (*loop.getYieldedValuesMutable())[j].set(other); + loop.getResult(j).replaceAllUsesWith(uniqueResult); + } + } +} + +namespace { +struct LoopAwareCSE + : public triton::impl::TritonLoopAwareCSEBase { + using TritonLoopAwareCSEBase::TritonLoopAwareCSEBase; + + void runOnOperation() override { + // LoopAwareCSE doesn't recursively CSE ops outside of loops, so run CSE + // first to make sure values from outside loops that are equivalent are made + // pointer equal. + IRRewriter rewriter(&getContext()); + auto &domInfo = getAnalysis(); + eliminateCommonSubExpressions(rewriter, domInfo, getOperation()); + + // CSE region iter args within loop bodies. + getOperation().walk(loopCSE); + + // Now that equivalent iter args have been made pointer equal, run CSE again + // to clean up the loop body. + eliminateCommonSubExpressions(rewriter, domInfo, getOperation()); + + // Run the `scf.for` canonicalizer to clean up the loops (short-circuited + // values, unused results, etc.). + RewritePatternSet patterns(&getContext()); + scf::ForOp::getCanonicalizationPatterns(patterns, &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp new file mode 100644 index 000000000..260732230 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp @@ -0,0 +1,83 @@ +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONLOOPINVARIANTCODEMOTION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-licm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +class LoopInvariantCodeMotionPass + : public impl::TritonLoopInvariantCodeMotionBase< + LoopInvariantCodeMotionPass> { + + DenseMap isLoopMemoryEffectFreeOrOnlyRead; + + bool isMemoryEffectFreeOrOnlyRead(Operation *op) { + std::optional> effects = + getEffectsRecursively(op); + if (!effects) + return false; + return llvm::all_of(*effects, + [&](const MemoryEffects::EffectInstance &effect) { + return isa(effect.getEffect()); + }); + } + + void runOnOperation() override { + // Walk through all loops in a function in innermost-loop-first order. + // This way, we first LICM from the inner loop, and place the ops in the + // outer loop, which in turn can be further LICM'ed. + getOperation()->walk([&](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode( + loopLike.getLoopRegions(), + // isDefinedOutsideOfRegion + [&](Value value, Region *region) { + return loopLike.isDefinedOutsideOfLoop(value); + }, + // shouldMoveOutOfRegion + [&](Operation *op, Region *region) { + if (!isa(op)) + return isSpeculatable(op) && isMemoryEffectFree(op); + if (!isLoopMemoryEffectFreeOrOnlyRead.contains(loopLike)) + isLoopMemoryEffectFreeOrOnlyRead[loopLike] = + isMemoryEffectFreeOrOnlyRead(loopLike); + return isMemoryEffectFreeOrOnlyRead(op) && + isLoopMemoryEffectFreeOrOnlyRead[loopLike]; + }, + // moveOutOfRegion + [&](Operation *op, Region *) { + // Create the new mask for load op. + if (auto loadOp = dyn_cast(op)) { + Value mask = loadOp.getMask(); + IRRewriter rewriter(loopLike); + Location loc = loopLike->getLoc(); + Value cond; + if (auto forOp = dyn_cast(loopLike.getOperation())) { + cond = rewriter.create( + loc, arith::CmpIPredicate::slt, forOp.getLowerBound(), + forOp.getUpperBound()); + } else if (auto whileOp = + dyn_cast(loopLike.getOperation())) { + // TODO: Support Load Op hoisting for while loop. + return; + } else { + return; + } + Value newMask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), cond); + loadOp.getMaskMutable().assign(newMask); + } + loopLike.moveOutOfLoop(op); + }); + }); + } +}; + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopPeeling.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopPeeling.cpp new file mode 100644 index 000000000..59028a50b --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopPeeling.cpp @@ -0,0 +1,68 @@ +#include "triton/Dialect/Triton/Transforms/LoopPeeling.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +using namespace mlir; + +namespace mlir { +namespace triton { + +void peelLoopEpilogue( + scf::ForOp forOp, + function_ref + processPeeledOp) { + SmallVector loopBodyOps; + IRRewriter rewriter(forOp); + Location loc = forOp.getLoc(); + Type type = forOp.getStep().getType(); + + // Fetch loop bounds and step + Value lowerBound = forOp.getLowerBound(); + Value upperBound = forOp.getUpperBound(); + Value step = forOp.getStep(); + Value newUpperBound = rewriter.create(loc, upperBound, step); + + rewriter.setInsertionPointAfter(forOp); + Value lastIV = getLastInductionValue(rewriter, forOp); + + auto cond = rewriter.create(loc, arith::CmpIPredicate::slt, + lowerBound, upperBound); + + // Create an if op to execute the peeled iteration + IRMapping map; + map.map(forOp.getRegionIterArgs(), forOp.getResults()); + map.map(forOp.getInductionVar(), lastIV); + auto ifOp = rewriter.create(loc, forOp.getResultTypes(), cond, + /*hasElse=*/true); + ifOp.getThenRegion().front().erase(); + forOp.getBodyRegion().cloneInto(&ifOp.getThenRegion(), map); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + rewriter.create(loc, forOp.getResults()); + + forOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) { + return !ifOp->isAncestor(operand.getOwner()); + }); + + forOp.getUpperBoundMutable().assign(newUpperBound); + + if (processPeeledOp) { + for (auto &op : + llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { + Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/false); + if (newOp && newOp != &op) { + op.replaceAllUsesWith(newOp); + } + } + for (auto &op : llvm::make_early_inc_range( + ifOp.getThenRegion().front().without_terminator())) { + Operation *newOp = processPeeledOp(rewriter, &op, /*isEpilogue=*/true); + if (newOp && newOp != &op) { + op.replaceAllUsesWith(newOp); + } + } + } +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopUnroll.cpp new file mode 100644 index 000000000..294dff873 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -0,0 +1,62 @@ +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONLOOPUNROLL +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-loop-unroll" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +class LoopUnrollPass : public impl::TritonLoopUnrollBase { + + int getUnrollFactorOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise set the + // factor to 1 to suppress the unrolling. + if (auto factor = + forOp->getAttrOfType(loopUnrollFactorAttrName)) + return factor.getInt(); + return 1; + } + + const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + const char *pipelineStagesAttrName = "tt.num_stages"; + +public: + void runOnOperation() override { + LDBG("Loop unroll pass"); + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with unroll factor <= 1. + if (getUnrollFactorOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + auto ctx = getOperation()->getContext(); + for (auto loop : loops) { + auto unrollFactor = getUnrollFactorOrDefault(loop); + loop->removeAttr(loopUnrollFactorAttrName); + LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); + auto resultLoops = loopUnrollByFactor(loop, unrollFactor); + // Do not pipeline the epilog loop. + if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) { + (*resultLoops->epilogueLoopOp) + ->setAttr(pipelineStagesAttrName, + mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1)); + } + } + } +}; + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 000000000..8338d4cf6 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,232 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + + if (op->getNumOperands() <= 0) + return failure(); + + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + return success(); + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + if (!seenBroadcast) + return failure(); + + auto loc = op->getLoc(); + + // Find broadcast op + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto bcSrcShape = srcTy.getShape(); + auto srcEncoding = srcTy.getEncoding(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = RankedTensorType::get(bcSrcShape, elemTy, srcEncoding); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + rewriter.create(loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back( + RankedTensorType::get(bcSrcShape, elemTy, srcEncoding)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + return success(); + } +}; + +} // namespace + +class ReorderBroadcastPass + : public impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + BroadcastOp::getCanonicalizationPatterns(patterns, context); + ExpandDimsOp::getCanonicalizationPatterns(patterns, context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp new file mode 100644 index 000000000..b9012b85e --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp @@ -0,0 +1,508 @@ +#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h" +#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include + +#include + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREWRITETENSORDESCRIPTORTOPOINTER +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +bool hasATensorDescriptorType(mlir::TypeRange types) { + return llvm::any_of(types, [](mlir::Type t) { + return llvm::isa(t); + }); +} + +using namespace mlir; + +/** + * @brief Filter out operand segment sizes from the list of attributes since + * this attribute is operation specific and shouldn't be set arbitrarily. + */ +mlir::SmallVector +filterSegmentSizes(mlir::ArrayRef attrs) { + mlir::SmallVector ret; + llvm::copy_if(attrs, std::back_inserter(ret), [](const NamedAttribute &attr) { + auto attrName = attr.getName().getValue(); + return attrName != "operandSegmentSizes"; + }); + return ret; +} + +struct Descriptor { + Value base; + ValueRange shape; + ValueRange strides; +}; + +Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) { + int rank = type.getBlockType().getRank(); + assert(pack.size() == 1 + 2 * rank && "Expected tensor descriptors to be " + "broken down into a ptr and " + "`rank` shapes and `rank` strides"); + Descriptor res; + res.base = pack[0]; + res.shape = pack.slice(1, rank); + res.strides = pack.slice(1 + rank, rank); + return res; +} + +Value expandOffsets(OpBuilder &builder, Location loc, + ArrayRef blockShape, Value offsets, unsigned dim) { + Value expandedResult = offsets; + for (size_t j = 0; j < blockShape.size(); ++j) { + if (j == dim) { + continue; + } + expandedResult = + builder.create(loc, expandedResult, j); + } + + return expandedResult; +} + +Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, + Value offset, unsigned dim) { + // Add range + auto indexI32RowType = + RankedTensorType::get({blockShape[dim]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({blockShape[dim]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offset); + Value range = builder.create(loc, indexI32RowType, 0, + blockShape[dim]); + Value i64Range = builder.create(loc, indexRowType, range); + + Value offsets = builder.create(loc, splatOffset, i64Range); + return expandOffsets(builder, loc, blockShape, offsets, dim); +} + +Value generatePtrFromOffsetRanges(OpBuilder &builder, Location loc, + ArrayRef blockShape, + Descriptor &desc, ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + auto indexTensorType = + RankedTensorType::get(blockShape, builder.getI64Type()); + auto ptrType = cast(desc.base.getType()); + auto ptrTensorType = RankedTensorType::get(blockShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, desc.base); + for (unsigned i = 0; i < blockShape.size(); ++i) { + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsets[i].getType(), desc.strides[i]); + Value offsetWithStride = + builder.create(loc, offsets[i], splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = + builder.create(loc, ptrTensorType, ptr, broadcasted); + } + + return ptr; +} + +Value generatePtr(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, Descriptor &desc, + ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + SmallVector offsetRanges; + for (unsigned i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = + getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i); + offsetRanges.push_back(offsetWithRange); + } + + return generatePtrFromOffsetRanges(builder, loc, blockShape, desc, + offsetRanges); +} + +Value generateMaskFromOffsetRanges(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, + Descriptor &desc, ValueRange offsetRanges) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsetRanges.size()); + + // Generate mask per dimension + auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type()); + Value mask; + for (std::size_t i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = offsetRanges[i]; + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetRanges[i].getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetRanges[i], splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetRanges[i].getType(), desc.shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetRanges[i], splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; +} + +Value generateMask(OpBuilder &builder, const Location &loc, + ArrayRef blockShape, Descriptor &desc, + ValueRange offsets) { + assert(blockShape.size() == desc.shape.size()); + assert(blockShape.size() == offsets.size()); + SmallVector offsetRanges; + for (unsigned i = 0; i < blockShape.size(); ++i) { + auto offsetWithRange = + getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i); + offsetRanges.push_back(offsetWithRange); + } + + return generateMaskFromOffsetRanges(builder, loc, blockShape, desc, + offsetRanges); +} + +Value generateOther(OpBuilder &builder, Location loc, Type scalarTy, + ArrayRef blockShape) { + auto blockTy = RankedTensorType::get(blockShape, scalarTy); + auto attr = builder.getZeroAttr(blockTy); + return builder.create(loc, attr); +} + +Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) { + auto blockTy = descTy.getSignlessBlockType(); + return generateOther(builder, loc, blockTy.getElementType(), + blockTy.getShape()); +} + +SmallVector castToI64(OpBuilder &builder, + mlir::ValueRange values) { + auto i64Type = builder.getI64Type(); + return llvm::map_to_vector(values, [&](mlir::Value v) { + return builder.createOrFold(v.getLoc(), i64Type, v); + }); +} + +struct RewriteMakeTensorDesc : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector ptrShapeStrides; + llvm::append_values(ptrShapeStrides, adaptor.getBase()); + llvm::append_range(ptrShapeStrides, + castToI64(rewriter, adaptor.getShape())); + llvm::append_range(ptrShapeStrides, adaptor.getStrides()); + rewriter.replaceOpWithMultiple(op, {ptrShapeStrides}); + return mlir::success(); + } +}; + +struct RewriteLoadPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorLoadOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + const auto blockShape = op.getDesc().getType().getBlockType().getShape(); + auto descTy = op.getDesc().getType(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + + auto newLoad = rewriter.replaceOpWithNewOp( + op, generatePtr(rewriter, loc, blockShape, desc, offsets), + generateMask(rewriter, loc, blockShape, desc, offsets), + generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE, + triton::EvictionPolicy::NORMAL, false); + newLoad->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +struct RewriteStorePattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorStoreOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = descTy.getBlockType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + + auto newStore = rewriter.replaceOpWithNewOp( + op, generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(), + generateMask(rewriter, loc, blockShape, desc, offsets), + triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL); + newStore->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +std::pair +generateGatherScatterPtrMask(OpBuilder &builder, Location loc, + ArrayRef blockShape, Descriptor &desc, + Value xOffsets, Value yOffset) { + Value xOffsetRange = + expandOffsets(builder, loc, blockShape, xOffsets, /*dim=*/0); + yOffset = castToI64(builder, {yOffset})[0]; + auto xOffsetI64Ty = RankedTensorType::get( + cast(xOffsetRange.getType()).getShape(), + yOffset.getType()); + xOffsetRange = + builder.create(loc, xOffsetI64Ty, xOffsetRange); + auto yOffsetRange = + getExpandedOffsetWithRange(builder, loc, blockShape, yOffset, /*dim=*/1); + auto ptr = generatePtrFromOffsetRanges(builder, loc, blockShape, desc, + {xOffsetRange, yOffsetRange}); + auto mask = generateMaskFromOffsetRanges(builder, loc, blockShape, desc, + {xOffsetRange, yOffsetRange}); + return {ptr, mask}; +} + +struct RewriteGatherPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = op.getResult().getType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto [ptr, mask] = generateGatherScatterPtrMask( + rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset()); + auto other = generateOther(rewriter, loc, + descTy.getSignlessBlockType().getElementType(), + blockShape); + auto newLoad = rewriter.replaceOpWithNewOp( + op, ptr, mask, other, triton::CacheModifier::NONE, + triton::EvictionPolicy::NORMAL, false); + newLoad->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +struct RewriteScatterPattern + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = op.getSrc().getType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto [ptr, mask] = generateGatherScatterPtrMask( + rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset()); + auto newStore = rewriter.replaceOpWithNewOp( + op, ptr, op.getSrc(), mask, triton::CacheModifier::NONE, + triton::EvictionPolicy::NORMAL); + newStore->setAttrs(filterSegmentSizes(op->getAttrs())); + + return llvm::success(); + } +}; + +std::optional translateReduceKind(DescriptorReduceKind kind, + TensorDescType ty) { + auto scalarTy = ty.getBlockType().getElementType(); + switch (kind) { + case DescriptorReduceKind::ADD: + return scalarTy.isInteger() ? RMWOp::ADD : RMWOp::FADD; + case DescriptorReduceKind::MIN: + if (scalarTy.isUnsignedInteger()) { + return RMWOp::UMIN; + } else if (scalarTy.isSignedInteger()) { + return RMWOp::MIN; + } + return {}; + case DescriptorReduceKind::MAX: + if (scalarTy.isUnsignedInteger()) { + return RMWOp::UMAX; + } else if (scalarTy.isSignedInteger()) { + return RMWOp::MAX; + } + return {}; + case DescriptorReduceKind::AND: + return RMWOp::AND; + case DescriptorReduceKind::OR: + return RMWOp::OR; + case DescriptorReduceKind::XOR: + return RMWOp::XOR; + default: + break; + } + return {}; +} + +struct RewriteReducePattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::LogicalResult + matchAndRewrite(triton::DescriptorReduceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto descTy = op.getDesc().getType(); + const auto blockShape = descTy.getBlockType().getShape(); + auto desc = unpackDescriptor(descTy, adaptor.getDesc()); + auto offsets = castToI64(rewriter, op.getIndices()); + auto rmwOp = translateReduceKind(op.getKind(), descTy); + if (!rmwOp) { + std::string msgstring; + llvm::raw_string_ostream msg(msgstring); + msg << "Cannot fallback on descriptor atomic op, unsupported for type " + << descTy.getBlockType().getElementType(); + return op->emitError(msgstring); + } + + auto newStore = rewriter.create( + loc, descTy.getSignlessBlockType(), *rmwOp, + generatePtr(rewriter, loc, blockShape, desc, offsets), op.getSrc(), + generateMask(rewriter, loc, blockShape, desc, offsets), + MemSemantic::RELEASE, MemSyncScope::GPU); + op.erase(); + return success(); + } +}; + +/** + * @brief This implements the pass for converting triton tensor descriptor + * loads/stores into indexed loads/stores. + * + * The key idea is that each tensor descriptor can be broken down into multiple + * values. Suppose we have a tensor pointer with rank r, we can cast that tensor + * descriptor value to and from 1+2r values: a tensor pointer value and two i32 + * value for each dimension representing the dynamic shape and strides. + * + * As in normal conversion patterns, individual operations can be converted + * using casted tensor descriptors and offsets and casting the results back to + * tensor pointers. + * + * We have special handling for TMA loads/stores and the make tensor descriptor + * op. + * + * @note Why use the conversion pattern rewriter? In most cases the defining + * operation of a tensor descriptor will be a make tensor descriptor op. + * However, this isn't always true - for example, if the tensor descriptor is a + * function argument or is in a conditional statement, we need better tracking + * of the pointer, shape, and strides. + */ +class TritonRewriteTensorDescriptorToPointerPass + : public impl::TritonRewriteTensorDescriptorToPointerBase< + TritonRewriteTensorDescriptorToPointerPass> { + void runOnOperation() override { + auto op = getOperation(); + + mlir::ConversionTarget target(getContext()); + target.addDynamicallyLegalDialect( + [](mlir::Operation *op) { + return !hasATensorDescriptorType(op->getOperandTypes()) && + !hasATensorDescriptorType(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([](triton::FuncOp funcOp) { + return !hasATensorDescriptorType(funcOp.getFunctionType().getInputs()) && + !hasATensorDescriptorType(funcOp.getFunctionType().getResults()); + }); + + mlir::TypeConverter converter; + + converter.addConversion([](mlir::Type t) { + // Most types don't require any conversion + return t; + }); + converter.addConversion([](mlir::triton::TensorDescType t, + llvm::SmallVectorImpl &out) { + // We convert a tensor descriptor into an pointer, and a shape and stride + // for each dimension, i.e., we create 1+2*rank values. Note that tensor + // descriptors may be signed/unsigned integers whereas pointers should + // always be signless. + auto tensorType = t.getSignlessBlockType(); + out.push_back(triton::getPointerType(tensorType.getElementType())); + out.insert(out.end(), 2 * tensorType.getRank(), + mlir::IntegerType::get(t.getContext(), 64)); + return mlir::success(); + }); + + mlir::RewritePatternSet patterns(op->getContext()); + + // Populate conversion patterns to handle loops, function calls, and arith + // ops. + triton::populateFunctionTypeConversions(converter, patterns); + mlir::scf::populateSCFStructuralTypeConversions(converter, patterns); + triton::populateArithTypeConversions(converter, patterns); + + patterns + .add( + converter, &getContext()); + + ConversionConfig config; + config.buildMaterializations = false; + + if (mlir::failed(mlir::applyPartialConversion( + op, target, std::move(patterns), config))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000..4e107dd0a --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,614 @@ +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +namespace mlir::triton { + +#define GEN_PASS_DEF_TRITONREWRITETENSORPOINTER +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + bool hasRowStride; + unsigned rowStrideParamIndex; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape, + bool hasRowStride, unsigned rowStrideParamIndex) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape), hasRowStride(hasRowStride), rowStrideParamIndex(rowStrideParamIndex) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + unsigned getRowStrideParamIndex() const { return rowStrideParamIndex; } + + bool getHasRowStride() const { return hasRowStride; } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = builder.getZeroAttr(elementType); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +// 成功返回true +bool findRowStrideParamIndexFromValue(Value v, unsigned* paramIdx) { + Operation* op = v.getDefiningOp(); + while(op != nullptr) { + if(op->getOperands().size() < 1) { + return false; + } + v = op->getOperands()[0]; + op = v.getDefiningOp(); + } + if(auto blockArg = dyn_cast(v)) { + *paramIdx = blockArg.getArgNumber(); + Block* parentBlock = blockArg.getOwner(); + if(parentBlock->isEntryBlock()) { + return true; + } + } + return false; +} + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public impl::TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static void generateNewOperands(SmallVector &oldOperands, + unsigned index, ArrayRef newValues) { + size_t size = oldOperands.size(); + assert(index < size); + SmallVector operands = oldOperands; + oldOperands.reserve(size - 1 + newValues.size()); + oldOperands.clear(); + if (index != 0) { + oldOperands.append(operands.begin(), operands.begin() + index); + } + oldOperands.append(newValues.begin(), newValues.end()); + if (index != size - 1) { + oldOperands.append(operands.begin() + index + 1, operands.end()); + } + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + unsigned rowStrideParamIdx = 0; + bool hasRowStride = false; + auto strides = op.getStrides(); + if(strides.size() == 2) { + hasRowStride = findRowStrideParamIndexFromValue(strides[0], &rowStrideParamIdx); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape(), hasRowStride, rowStrideParamIdx); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + MLIRContext *ctx = builder.getContext(); + if(info.getHasRowStride() == true) { + newResult->setAttr("sunrise.rowStrideParamIdx", IntegerAttr::get(IntegerType::get(ctx, 32), info.getRowStrideParamIndex())); + } + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + const auto &info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + auto opResults = op.getResults(); + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx)); + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewritten information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + newForOp->setAttrs(op->getAttrs()); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewritten info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + builder.clone(opInFor, mapping); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewritten info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (isa(op->getDialect())) { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : llvm::make_early_inc_range(block)) { + if (auto newOp = rewriteOp(&nestedOp, eraser)) { + visitOperation(newOp, eraser); + } + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + + // 此时loadOp不再有boundaryCheck,根据参数个数判断是否有mask + getOperation()->walk([](Operation* op){ + if(isa(*op) == false) { + return; + } + OpBuilder builder(op); + MLIRContext *ctx = builder.getContext(); + int hasOriMask = 0; + if(op->getNumOperands() > 1) { + hasOriMask = 1; + } + op->setAttr("sunrise.hasOriMask", IntegerAttr::get(IntegerType::get(ctx, 32), hasOriMask)); + }); + } +}; + +} // namespace mlir::triton diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..e462f0933 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(FlagTree_sunrise_TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + TritonGPUTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonTools +) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..45c340716 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,3391 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Interfaces.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc" + + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Utility +namespace mlir { +namespace triton { +namespace gpu { + +LinearEncodingAttr TritonGPUDialect::toLinearEncoding(ArrayRef shape, + Attribute layout) { + CacheKey key{std::vector(shape.begin(), shape.end()), layout}; + if (auto result = leCache.get(key)) { + return *result; + } + auto linearLayout = toLinearLayout(shape, layout); + auto linearEncoding = + LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout)); + leCache.set(key, linearEncoding); + return linearEncoding; +} + +LinearEncodingAttr toLinearEncoding(RankedTensorType type) { + return toLinearEncoding(type.getEncoding(), type.getShape()); +} + +LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef shape) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearEncoding(shape, + layout); +} + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape) { + return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getElemsPerThread(shape); +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), + tensorType.getShape()); +} + +SmallVector getThreadsPerWarp(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getThreadsPerWarp(); +} + +SmallVector getWarpsPerCTA(Attribute layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getWarpsPerCTA(); +} + +SmallVector getContigPerThread(RankedTensorType type) { + return toLinearEncoding(type).getContigPerThread(); +} + +SmallVector getShapePerCTATile(RankedTensorType type) { + return toLinearEncoding(type).getShapePerCTATile(); +} + +bool isExpensiveView(Type srcType, Type dstType) { + auto tensorSrcType = cast(srcType); + auto tensorDstType = cast(dstType); + auto llSrc = + toLinearLayout(tensorSrcType.getShape(), tensorSrcType.getEncoding()); + auto llDst = + toLinearLayout(tensorDstType.getShape(), tensorDstType.getEncoding()); + // In case there are replicated value we need to make sure the new and old + // layout have matching masks. + for (auto [srcMask, dstMask] : + llvm::zip(llSrc.getFreeVariableMasks(), llDst.getFreeVariableMasks())) { + assert(srcMask.first == dstMask.first); + if (srcMask.second != dstMask.second) + return true; + } + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by get.*Order methods of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getMatrixOrder(unsigned rank, bool rowMajor) { + // Return the order that represents that the batch is in row-major or + // column-major order for a batch of matrices of shape [*, m, n] with + // len(shape) == rank. + assert(rank >= 2); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + if (!rowMajor) { + std::swap(order[0], order[1]); + } + return order; +} + +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kContig) { + // kContig: if true, the matrix is fastest-running on k, + // otherwise it is on m (resp. n) + // opIdx=0: [*batch, m, k] + // opIdx=1: [*batch, k, n] + assert(opIdx == 0 || opIdx == 1); + auto rowMajor = bool(opIdx) != kContig; + return getMatrixOrder(rank, rowMajor); +} + +SmallVector getRepOrder(RankedTensorType type) { + auto layout = type.getEncoding(); + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getRepOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getRepOrder"); + return {}; +} + +// Legacy impl for now +// This one's not terribly bad as we don't broadcast ShareEncodings +SmallVector getOrder(SharedEncodingTrait layout, + ArrayRef shape) { + if (auto swizzledLayout = + mlir::dyn_cast(layout)) { + return llvm::to_vector(swizzledLayout.getOrder()); + } + if (auto sharedLayout = mlir::dyn_cast(layout)) { + if (shape.size() == 1) { + return {0}; + } + return getMatrixOrder(shape.size(), !sharedLayout.getTransposed()); + } + if (auto sharedLayout = + mlir::dyn_cast(layout)) { + return llvm::to_vector(sharedLayout.getOrder()); + } + llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType"); + return {}; +} + +SmallVector getOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getOrder(); +} + +SmallVector getOrderForMemory(DistributedEncodingTrait layout, + ArrayRef shape) { + auto linear = toLinearEncoding(layout, shape); + auto order = linear.getOrder(); + auto threadOrder = linear.getThreadOrder(); + if (order == threadOrder) { + return order; + } + // Heuristic: + // If the element contiguity does not align with the thread order + // because the thread order dimension has contiguity of 1---meaning that + // the order position of this dimension is irrelevant---we prefer + // to use the thread order for the memory layout + auto contig = linear.getElemsPerThread(shape); + if (contig[threadOrder[0]] == 1) { + return threadOrder; + } + return order; +} + +SmallVector getThreadOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getThreadOrder(); +} + +SmallVector getWarpOrder(DistributedEncodingTrait layout, + ArrayRef shape) { + return toLinearEncoding(layout, shape).getWarpOrder(); +} + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto ttgLayout = mlir::dyn_cast(layout)) { + return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(ttgLayout), + getCTASplitNum(ttgLayout), + getCTAOrder(ttgLayout)); + } + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto ttgLayout = mlir::dyn_cast(layout)) + return ttgLayout.getCTAsPerCGA(); + else + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + return ttgLayout.getCTASplitNum(); + } else if (auto tmemLayout = + mlir::dyn_cast( + layout)) { + res.resize(2); + res[0] = tmemLayout.getCTASplitM(); + res[1] = tmemLayout.getCTASplitN(); + } else if (auto tmemScaleLayout = mlir::dyn_cast< + triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(layout)) { + res.resize(2); + res[0] = tmemScaleLayout.getCTASplitM(); + res[1] = tmemScaleLayout.getCTASplitN(); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto ttgLayout = mlir::dyn_cast(layout)) { + res = ttgLayout.getCTAOrder(); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (mlir::isa(layout)) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise + // too many places would have to be modified in pipeline pass. Maybe we + // need to refactor this logic in the future. + auto CTASplitNum = cast(layout).getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + SmallVector splitNum = getCTASplitNum(layout); + if (auto tmem = dyn_cast(layout)) { + if (shape.size() > splitNum.size()) { + splitNum.insert(splitNum.begin(), shape.size() - splitNum.size(), 1); + } + } + return getShapePerCTA(splitNum, shape); +} + +SmallVector getAllocationShapePerCTA(Attribute layout, + ArrayRef shapeLogical) { + SmallVector shape(shapeLogical); + if (auto sharedMMALayout = mlir::dyn_cast(layout)) { + if (sharedMMALayout.getFp4Padded()) { + auto packedAxis = getOrder(sharedMMALayout, shapeLogical)[0]; + shape[packedAxis] *= 2; + } + } + return getShapePerCTA(layout, shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +SmallVector getAllocationShapePerCTA(Type type) { + auto tensorType = cast(type); + return getAllocationShapePerCTA(tensorType.getEncoding(), + tensorType.getShape()); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape); + return newTotalElemsPerThread < totalElemsPerThread; +} + +static LogicalResult +verifyLayoutOrder(function_ref emitError, + ArrayRef order) { + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +LogicalResult CTALayoutAttr::verify( + function_ref emitError, ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, ArrayRef CTAOrder) { + if (!llvm::all_equal( + {CTAsPerCGA.size(), CTASplitNum.size(), CTAOrder.size()})) { + return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " + "the same rank."; + } + + if (failed(verifyLayoutOrder(emitError, CTAOrder))) + return failure(); + + if (llvm::any_of(CTAsPerCGA, [](unsigned x) { return x == 0; })) { + return emitError() << "Every element in CTAsPerCGA must be greater than 0."; + } + + if (llvm::any_of(CTASplitNum, [](unsigned x) { return x == 0; })) { + return emitError() + << "Every element in CTASplitNum must be greater than 0."; + } + + return success(); +} + +LogicalResult +BlockedEncodingAttr::verify(function_ref emitError, + ArrayRef sizePerThread, + ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, + ArrayRef order, CTALayoutAttr CTALayout) { + if (!llvm::all_equal({sizePerThread.size(), threadsPerWarp.size(), + warpsPerCTA.size(), order.size()})) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + + // Empty CTALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (order.size() != CTALayout.getRank()) { + return emitError() << "BlockedEncodingAttr and CTALayout's fields must " + "have the same rank."; + } + return verifyLayoutOrder(emitError, order); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl, + LinearLayout &outLl, bool fwdInference, int axis, + std::optional loc) { + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(inLl.getOutDimNames()); + if (fwdInference) { + auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]); + outLl = split * inLl; + } else { + // Assert that there is a dimension with size 2 in the axis + // that has contiguous elements + // Note that this is more general than the fwdInference case in that + // - It allows the dimension not to be the fastest running + // - It allows broadcasting + // In general, this allows us to split along any axis as long as + // the basis (0, 0, ..., 0, 1, 0, ..., 0) is in the registers. + bool found = false; + LinearLayout::BasesT newBases; + for (const auto &basesDim : inLl.getBases()) { + std::vector> newBasesDim; + for (auto base : basesDim.second) { + if (base[axis] == 1 && basesDim.first == kRegister) { + found = true; + continue; + } + base[axis] /= 2; + newBasesDim.push_back(std::move(base)); + } + newBases.insert({basesDim.first, std::move(newBasesDim)}); + } + if (!found) + return emitOptionalError(loc, + "Fp4ToFpOp/SplitOp requires at least 2 elements " + "per thread in the axis/last dimension"); + outLl = LinearLayout(std::move(newBases), std::move(outDims)); + } + return success(); +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +// Print the CTALayout if it's not equal to the default. +static void maybePrintCTALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, CTALayoutAttr layout, + unsigned rank) { + if (layout != CTALayoutAttr::getDefault(context, rank)) { + printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" + << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" + << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; + } +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector BlockedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector BlockedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +SmallVector SliceEncodingAttr::getRepOrder() const { + auto parentRepOrder = getParent().getRepOrder(); + return eraseOrder(parentRepOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} + +// Wmma encoding + +int32_t SwizzledSharedEncodingAttr::getAlignment() const { return 16; } + +SmallVector SwizzledSharedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector SwizzledSharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector SwizzledSharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +int32_t AMDRotatingSharedEncodingAttr::getAlignment() const { return 16; } + +SmallVector AMDRotatingSharedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDRotatingSharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDRotatingSharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + res[kDim] = 1; + return res; +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "ttg.dot_op parent parameter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 16 && parentAttr.getVersion() == 1 || + kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2) + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " + "gfx11 and 8/16 for gfx12"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + return success(); + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; +} + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +static std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, *CTALayout); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getSizePerThread().size()); + + printer << "}>"; +} + +// FIXME Can we take the LinearLayout by const&? +LogicalResult +LinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + // Example of LinearEncodingAttr + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + static const auto expectedInDims = + SmallVector({"register", "lane", "warp", "block"}); + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + return success(); +} + +void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + // We don't use the default implementation as it's a bit too verbose + // This prints in the following format that is shape agnostic, in the sense + // that we don't print explicitly the outShape of the LL + // We always assume LLs to be surjective + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + auto ll = getLinearLayout(); + printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }) << "}>"; +} + +Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + std::vector inDimNames = {"register", "lane", "warp", "block"}; + + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of '") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + // To implement this we'd need to serialise the rank as well. + // We can do this if we ever need it + if (rank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + // Generate standared outDimNames (dim0, dim1, ...) + SmallVector outDimNames; + for (int i = 0; i < rank; ++i) { + outDimNames.push_back( + StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); + } + + // Create LinearLayout + LinearLayout linearLayout(std::move(bases), std::move(outDimNames)); + + // Create and return the LinearEncodingAttr + return parser.getChecked(parser.getContext(), + std::move(linearLayout)); +} + +SmallVector basesPerDimImpl(const LinearLayout::BasesT &namedBases, + StringAttr dimName, size_t rank, + bool skipBroadcast = true) { + const auto &bases = namedBases.find(dimName)->second; + + if (bases.empty()) { + return SmallVector(rank, 1); + } + + SmallVector ret(rank, 1); + auto nonZero = [](auto val) { return val != 0; }; + int nonZeroIdx = 0; + for (const auto &basis : bases) { + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + if (it != basis.end()) { + nonZeroIdx = it - basis.begin(); + ret[nonZeroIdx] *= 2; + } else if (!skipBroadcast) { + // If we've seen a non-zero basis, we double the size of the previous dim + // This is just needed to count the CTAsPerCGA + ret[nonZeroIdx] *= 2; + } + } + return ret; +} + +SmallVector +LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const { + auto ll = getLinearLayout(); + auto rank = ll.getNumOutDims(); + return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); +} + +SmallVector +LinearEncodingAttr::orderPerDim(StringAttr dimName, + ArrayRef defaultOrder) const { + auto ll = getLinearLayout(); + const auto &bases = ll.getBases().find(dimName)->second; + llvm::SetVector order; + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &basis : bases) { + // Bases can have one or zero non-zero elements + // Skip a basis if it's broadcasting (all zeros) + // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + if (it != basis.end()) { + auto i = it - basis.begin(); + order.insert(i); + } + } + // If any dim is missing, we add them in the defaultOrder + for (auto i : defaultOrder) { + order.insert(i); + } + return SmallVector(order.begin(), order.end()); +} + +// [Note. Divergence of methods wrt. legacy layouts] +// For smaller shapes where the CTATile is larger than the output +// tensor, some methods return different values than the legacy layouts. I think +// this is benign tho. An example: what is the vector of `warpsPerCTA` if +// all the warps hold the same data? I think it should be [1, 1], even if we +// have 4 warps. But perhaps for this we have to add some masking in some +// places... We'll see +SmallVector LinearEncodingAttr::getRepOrder() const { + // This is not correct, but: + // - It happens to agree in most places with the legacy layout + // - getRepOrder does not make sense for LinearEncodingAttr as it already has + // the same shape as the tensor that uses it + return getOrder(); +} +SmallVector LinearEncodingAttr::getCTAsPerCGA() const { + // CTAs are split into an identity part (SplitNum) and a broadcast part + return basesPerDim(StringAttr::get(getContext(), "block"), + /*skipBroadcast=*/false); +} +SmallVector LinearEncodingAttr::getCTAOrder() const { + return orderPerDim(StringAttr::get(getContext(), "block"), getOrder()); +} +SmallVector LinearEncodingAttr::getCTASplitNum() const { + return basesPerDim(StringAttr::get(getContext(), "block")); +} +SmallVector LinearEncodingAttr::getWarpsPerCTA() const { + return basesPerDim(StringAttr::get(getContext(), "warp")); +} +SmallVector LinearEncodingAttr::getWarpOrder() const { + return orderPerDim(StringAttr::get(getContext(), "warp"), getOrder()); +} +SmallVector LinearEncodingAttr::getThreadsPerWarp() const { + return basesPerDim(StringAttr::get(getContext(), "lane")); +} +SmallVector LinearEncodingAttr::getThreadOrder() const { + return orderPerDim(StringAttr::get(getContext(), "lane"), getOrder()); +} +SmallVector LinearEncodingAttr::getSizePerThread() const { + auto rank = getOrder().size(); + auto ll = getLinearLayout(); + auto ctx = getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + + // We canonicalize on the spot, as if we use CGAs the regs are not in + // canonical form The order is [reg, lane, warp, rep, block], so we first + // remove the blocks + llvm::SmallVector ctaShape; + for (auto [shape, cgaNum] : + llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { + ctaShape.push_back(shape / cgaNum); + } + LinearLayout::BasesT bases = ll.getBases(); + + llvm::SetVector reverseRepOrder; + auto nonZero = [](auto val) { return val != 0; }; + auto ®isters = bases[kRegister]; + while (!registers.empty()) { + auto &basis = registers.back(); + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // If there's broadcasting (base == zeros) there are no more reps + if (it == basis.end()) { + break; + } + auto dim = it - basis.begin(); + reverseRepOrder.insert(dim); + // As soon as we stop finding reps, we stop + if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { + break; + } + ctaShape[dim] /= 2; + registers.pop_back(); + } + return basesPerDimImpl(bases, kRegister, rank); +} + +SmallVector LinearEncodingAttr::getShapePerCTATile() const { + auto sizePerThread = getSizePerThread(); + auto threadsPerWarp = getThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shape; + for (auto [size, thread, warp] : + llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) { + shape.push_back(size * thread * warp); + } + return shape; +} + +SmallVector LinearEncodingAttr::getOrder() const { + auto rank = getLinearLayout().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 0] as the default order in case + // there are dims that do not move in the register + // This order is as good as any really + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(StringAttr::get(getContext(), "register"), order); +} + +LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ll = getLinearLayout(); + auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); + llvm::SmallDenseMap namedShape; + llvm::SmallVector permutedDims; + for (auto dim : getRepOrder()) { + permutedDims.push_back(canonicalDims[dim]); + namedShape[canonicalDims[dim]] = shape[dim]; + } + ll = ll.transposeOuts(permutedDims); + ll = ensureLayoutNotSmallerThan(ll, namedShape); + ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); + ll = ll.transposeOuts(canonicalDims); + return ll; +} + +SmallVector +LinearEncodingAttr::getElemsPerThread(ArrayRef shape) const { + // When broadcasting the layout the shape changes, otherwise the shape is + // the same as the shape of the tensor + // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep + // the invariant that the shape of the LL is that of the tensor + // We choose the former for BC + auto scaledLayout = get(getContext(), toLinearLayout(shape)); + auto kRegister = StringAttr::get(getContext(), "register"); + return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false); +} + +SmallVector +LinearEncodingAttr::getContig(const char *inDim, + SmallVector lowerContig) const { + auto ll = getLinearLayout(); + const auto &bases = + ll.getBases().find(StringAttr::get(getContext(), inDim))->second; + auto order = getOrder(); + auto rank = order.size(); + + SmallVector contig(lowerContig); + auto basisIt = bases.begin(); + for (unsigned dim : order) { + std::vector basis(rank, 0); + basis[dim] = contig[dim]; + + while (basisIt != bases.end() && *basisIt == basis) { + contig[dim] *= 2; + basis[dim] *= 2; + ++basisIt; + } + } + return contig; +} + +SmallVector LinearEncodingAttr::getContigPerThread() const { + SmallVector contig(getOrder().size(), 1); + return getContig("register", contig); +} + +SmallVector LinearEncodingAttr::getContigPerWarp() const { + return getContig("lane", getContigPerThread()); +} + +unsigned +LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape) const { + return product(getElemsPerThread(shape)); +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getRank()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, + instrShape[0], instrShape[1], isTransposed, *CTALayout); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() // + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << getWarpsPerCTA() << "]" // + << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" // + << ", isTransposed = " << getIsTransposed(); + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getRank()); + printer << "}>"; +} + +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 4)) { + return emitError() << "major version must be in the [0, 4] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned version = 0; + bool isTransposed = false; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) + return {}; + } + if (attr.getName() == "isTranspose") { + if (parseBool(parser, attr, isTransposed, "isTranspose").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), version, isTransposed, warpsPerCTA, *CTALayout); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "version = " << getVersion() + << ", isTranspose = " << getIsTransposed() << ", warpsPerCTA = [" + << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDWmmaEncodingAttr::verify(function_ref emitError, + unsigned version, bool isTransposed, + llvm::ArrayRef warpsPerCTA, + mlir::triton::gpu::CTALayoutAttr) { + if (version != 1 && version != 2) { + return emitError() << "WMMA version must be in the [1, 2] range"; + } + // Transposed layout is needed for bypassing LDS between multiple dots. + // Version 1 tt.dot results and tt.dot operand layouts are different, + // therefore we test and support transposed only for version 2. + if (version != 2 && isTransposed) { + return emitError() << "Transposed WMMA is supported only for version 2"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Sunrise_MMA Encoding +//===----------------------------------------------------------------------===// +Attribute SunriseMmaEncodingAttr::parse(AsmParser &parser, Type type) { + DictionaryAttr dict; + if (parser.parseLess().failed()) { return {}; } + if (parser.parseAttribute(dict).failed()) { return {};} + if (parser.parseGreater().failed()) {return {};} + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SunriseMmaEncodingAttr::TMMAOutLayout outLayout; + unsigned outLayoutUint = static_cast(SunriseMmaEncodingAttr::TMMAOutLayout::NotAvailable); + unsigned inputElemBitWidth = 0; + unsigned outputElemBitWidth = 0; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) {return {};} } + if (attr.getName() == "versionMinor") { if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) {return {};} } + if (attr.getName() == "warpsPerCTA") { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) {return {};} } + if (attr.getName() == "CTAsPerCGA") { if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA").failed()) {return {};} } + if (attr.getName() == "CTASplitNum") { if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum").failed()){return {};} } + if (attr.getName() == "CTAOrder") { if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder").failed()) {return {};} } + if (attr.getName() == "outLayout") { if(parseUInt(parser, attr, outLayoutUint, "outLayout").failed()) {return {}; } } + if (attr.getName() == "inputElemBitWidth") { if(parseUInt(parser, attr, inputElemBitWidth, "inputElemBitWidth").failed()) {return {}; } } + if (attr.getName() == "outputElemBitWidth") { if(parseUInt(parser, attr, outputElemBitWidth, "outputElemBitWidth").failed()) {return {}; } } + } + outLayout = static_cast(outLayoutUint); + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, outLayout, inputElemBitWidth, outputElemBitWidth); +} + +void SunriseMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + // printer << ", order = " << sv2str(getOrder(*this)) + // << ", threadsPerWarp = " << sv2str(getThreadsPerWarp()) + // << ", threadOrder = " << sv2str(getThreadOrder()) + // << ", sizePerThread = " << sv2str(getSizePerThread()) + // << ", shapePerCTATile = " << sv2str(getShapePerCTATile()) + // << ", outLayout = " << static_cast(getOutLayout()) + // << ", inputElemBitWidth = " << getInputElemBitWidth() + // << ", outputElemBitWidth = " << getOutputElemBitWidth(); + + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + auto parent = mlir::dyn_cast(attrs.get("parent")); + if (!parent) { + parser.emitError(parser.getNameLoc(), + "expected a distributed encoding trait"); + return {}; + } + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Helper shared encoding functions +//===----------------------------------------------------------------------===// + +template +Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), vec, perPhase, + maxPhase, order, *CTALayout); +} + +//===----------------------------------------------------------------------===// +// SwizzledShared encoding +//===----------------------------------------------------------------------===// + +LogicalResult +SwizzledSharedEncodingAttr::verify(function_ref emitError, + unsigned vec, unsigned perPhase, + unsigned maxPhase, ArrayRef order, + CTALayoutAttr ctaLayout) { + if (order.size() != ctaLayout.getRank()) { + return emitError() << "order size (" << order.size() + << ") must match CTALayout rank (" << ctaLayout.getRank() + << ")"; + } + return verifyLayoutOrder(emitError, order); +} + +Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) { + return parseSwizzledEncoding(parser, type); +} + +void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// NVMMAShared encoding +//===----------------------------------------------------------------------===// + +Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned swizzlingByteWidth; + bool transposed = false; + bool fp4Padded = false; + unsigned elementBitWidth; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "swizzlingByteWidth") { + if (parseUInt(parser, attr, swizzlingByteWidth, "swizzlingByteWidth") + .failed()) + return {}; + } else if (attr.getName() == "transposed") { + if (parseBool(parser, attr, transposed, "transposed").failed()) + return {}; + } else if (attr.getName() == "elementBitWidth") { + if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed()) + return {}; + } else if (attr.getName() == "fp4Padded") { + if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/2); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth, + fp4Padded, *CTALayout); +} + +void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "swizzlingByteWidth = " << getSwizzlingByteWidth() // + << ", transposed = " << getTransposed() // + << ", elementBitWidth = " << getElementBitWidth(); + if (getFp4Padded()) { + // Print only in this case to reduce the noise for the more common case. + printer << ", fp4Padded = true"; + } + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/2); + printer << "}>"; +} + +int NVMMASharedEncodingAttr::getVec() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return 128 / getElementBitWidth(); +} + +int NVMMASharedEncodingAttr::getPerPhase() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return 128 / getSwizzlingByteWidth(); +} + +int NVMMASharedEncodingAttr::getMaxPhase() const { + if (getSwizzlingByteWidth() == 0) + return 1; + return getSwizzlingByteWidth() / 16; +} + +int32_t NVMMASharedEncodingAttr::getAlignment() const { + return 128 * getMaxPhase(); +} + +SmallVector NVMMASharedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NVMMASharedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NVMMASharedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +//===----------------------------------------------------------------------===// +// AMDRotatingShared encoding +//===----------------------------------------------------------------------===// + +Attribute AMDRotatingSharedEncodingAttr::parse(AsmParser &parser, Type type) { + return parseSwizzledEncoding(parser, type); +} + +void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector +AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { + unsigned mDim = getMDim(); + unsigned nDim = getNDim(); + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps. + int kGroups = -1; + if (mDim == nDim) + kGroups = warpSize / mDim; + if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + kGroups = 1; + int64_t kDim = kWidth * kGroups; + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector AMDMfmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( + CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + ArrayRef sharedOrder, unsigned vectorSize, unsigned elemBitWidth, + bool needTrans) const { + int kDimIndex = operandIdx == 0 ? 1 : 0; + if (needTrans) + kDimIndex = 1 - kDimIndex; + + bool isKContig = sharedOrder[0] == kDimIndex; + // GFX950 supports LDS transpose load instructions, so we need swizzling even + // when K dimension is not the contiguous dimension. + bool isGFX950 = getVersionMajor() == 4; + bool swizzleNonKContig = + isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16); + + if (!isKContig && !swizzleNonKContig) { + // Do not swizzle. In this case accesses will go in different banks even + // without swizzling. + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + ctaLayout); + } + + const unsigned numBanks = isGFX950 ? 64 : 32; + const unsigned bankBitWidth = 32; + const unsigned simdWidth = 16; + + // Number of inner dimension rows per one pattern repeat + int innerDimLength = operandShape[sharedOrder[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int maxPhase = + std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (getMDim() == 4) + maxPhase = 4; + + return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase, + maxPhase, sharedOrder, ctaLayout); +} + +//===----------------------------------------------------------------------===// +// Wmma encoding +//===----------------------------------------------------------------------===// + +SmallVector AMDWmmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const { + return {16, 16}; +} + +SmallVector +AMDWmmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SmallVector AMDWmmaEncodingAttr::getMNKDimPerInstr() { + // TODO: move magic numbers out of the code + return {16, 16, 16}; +} + +unsigned AMDWmmaEncodingAttr::getKWidthForOperands() const { + SmallVector sizePerThread(getRank(), 1); + auto numReplicated = getVersion() == 1 ? 2 : 1; + auto elemsPerInstr = + numReplicated * product(getElemsPerInstrForOperands()) / 32; + return elemsPerInstr; +} + +//===----------------------------------------------------------------------===// +// sunrise_mma encoding +//===----------------------------------------------------------------------===// +SmallVector SunriseMmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} + +SmallVector +SunriseMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +SunriseMmaEncodingAttr::getInstrShapeForOperand(unsigned opIdx) const { + // A, B矩阵每个warp加载的尺寸,只支持行主序 + int packPer32bit = 32 / this->getInputElemBitWidth(); + if(opIdx == 0) { + switch(packPer32bit) { + case 1: return {8, 4}; + case 2: return {8, 8}; + case 4: return {8, 16}; + case 8: return {8, 32}; + default: llvm_unreachable("unsupported packPer32bit"); + } + } + else { + switch(packPer32bit) { + case 1: return {4, 8}; + case 2: return {8, 8}; + case 4: return {16, 8}; + case 8: return {32, 8}; + default: llvm_unreachable("unsupported packPer32bit"); + } + } + return {0, 0}; +} + +SmallVector +SunriseMmaEncodingAttr::getShapePerCTATileForOperand(unsigned opIdx) const { + // CTA内所有warp执行一次mma操作对应输入tensor的尺寸 + unsigned k = 0; // mma指令k维度的元素个数 + int elemBitWidth = this->getInputElemBitWidth(); + switch (elemBitWidth) { + case 4: k = 32; break; + case 8: k = 16; break; + case 16: k = 8; break; + case 32: k = 4; break; + default: + llvm::report_fatal_error("SunriseMmaEncodingAttr::getShapePerCTATileForOperand " + "unsuppored inputElemBitWidth"); + } + auto shapePerCTATile = getShapePerCTATile(); + SmallVector ret; + if (opIdx == 0) { + ret = {shapePerCTATile[0], k}; + } else { + ret = {k, shapePerCTATile[1]}; + } + return ret; +} + +SmallVector SunriseMmaEncodingAttr::getShapePerCTATile() const { + // CTA内所有warp执行一次mma操作的结果的尺寸 + auto warpsPerCTA = getWarpsPerCTA(); + assert(warpsPerCTA.size() == 2); + SmallVector shapePerCTATile(2); + shapePerCTATile[0] = warpsPerCTA[0] * 8; + shapePerCTATile[1] = warpsPerCTA[1] * 8; + return shapePerCTATile; +} + +SmallVector SunriseMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector SunriseMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector SunriseMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +// 获取mma的两个输入a、b的每个维度的切割数,即每个维度有几个CTATile +// 【注意】两个a和b的元素类型应该一致 +// 返回值:[repM, repK]或[repK, repN] +SmallVector SunriseMmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + Type elemType, int opIdx) const { + int instrM = 8, instrN = 8, instrK = 0; + if(elemType.isF32() || elemType.isInteger(32)) { instrK = 4; } + else if(elemType.isF16() || elemType.isBF16()) { instrK = 8; } + else if(elemType.isInteger(8)) { instrK = 16; } + else if(elemType.isInteger(4)) { instrK = 32; } + else { + llvm::report_fatal_error("unsupported tensor data type for tmma!"); + return {}; + } + + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector ret; + if (opIdx == 0) + ret = {std::max(1, operandShape[0] / (instrM * warpsPerCTA[0])), + std::max(1, operandShape[1] / instrK)}; + else { + assert(opIdx == 1); + ret = {std::max(1, operandShape[0] / instrK), + std::max(1, operandShape[1] / (instrN * warpsPerCTA[1]))}; + } + return ret; +} +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { + return getMatrixOrder(getRank(), /*rowMajor*/ true); +} +SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); +} + +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int kWidth, int opIdx) const { + assert( + kWidth >= 32 / bitwidth && + "kWidth must be >= 32 / bitwidth for this function to be well-defined"); + auto rank = shape.size(); + // Broadcast long K + auto warpsPerCTA = to_vector(getWarpsPerCTA()); + auto kDim = opIdx == 0 ? rank - 1 : rank - 2; + warpsPerCTA[kDim] = 1; + + SmallVector tileSize; + if (rank == 3) { + tileSize.push_back(1); + } + if (opIdx == 0) { + // m x k + tileSize.push_back(16); + tileSize.push_back(4 * 64 / bitwidth); + } else { + // k x n + // Hopper path never uses the n value, since this method is only invoked + // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF + // so it's fine if the n is incorrect here + tileSize.push_back(4 * 64 / bitwidth); + tileSize.push_back(8); + } + + SmallVector numRep; + // Lezcano: This is odd. Why do we always return a vector of size 3? + if (rank != 3) { + numRep.push_back(1); + } + for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) { + numRep.push_back(std::max(1, s / (size * warp))); + } + return numRep; +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getRepOrder() const { + if (auto mma = mlir::dyn_cast(getParent())) { + return mma.getRepOrderForOperand(getOpIdx()); + } + llvm::report_fatal_error( + "getRepOrder not implemented for DotOperandEncodingAttr"); + return {}; +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + // Encoding attributes + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } else if (auto linearAttr = mlir::dyn_cast(attr)) { + os << "linear"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + // Memory space attributes + if (auto smem = mlir::dyn_cast(attr)) { + os << "smem"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional loc) const override { + resultEncoding = + SliceEncodingAttr::get(getDialect()->getContext(), axis, + cast(operandEncoding)); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef shape, + ArrayRef order, Attribute &resultEncoding, + std::optional loc) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + if (isIota(order)) { + resultEncoding = operandEncoding; + return success(); + } + if (shape.size() != order.size()) { + return emitOptionalError(loc, "shape and order rank do not match: ", + shape.size(), " vs ", order.size()); + } + auto checkRank = [&](unsigned rank) { + if (rank != order.size()) { + return emitOptionalError(loc, "rank of encoding does not match order: ", + rank, " vs ", order.size()); + } + return success(); + }; + + auto *ctx = getDialect()->getContext(); + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + if (auto enc = dyn_cast(operandEncoding)) { + if (failed(checkRank(enc.getRank()))) + return failure(); + + CTALayoutAttr ctaLayout = + permuteCTALayout(ctx, enc.getCTALayout(), order); + resultEncoding = SwizzledSharedEncodingAttr::get( + ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), + applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout); + return success(); + } + + if (auto enc = dyn_cast(operandEncoding)) { + if (failed(checkRank(enc.getRank()))) + return failure(); + if (order != ArrayRef({1, 0})) { + return emitOptionalError( + loc, "NVMMSharedEncoding can only be transposed in 2D"); + } + + CTALayoutAttr ctaLayout = + permuteCTALayout(ctx, enc.getCTALayout(), order); + resultEncoding = NVMMASharedEncodingAttr::get( + ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(), + enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout); + return success(); + } + + if (auto enc = dyn_cast(operandEncoding)) { + if (failed(checkRank(enc.getRank()))) + return failure(); + + CTALayoutAttr ctaLayout = + permuteCTALayout(ctx, enc.getCTALayout(), order); + resultEncoding = BlockedEncodingAttr::get( + ctx, applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout); + return success(); + } + + auto ll = toLinearLayout(shape, operandEncoding); + auto transposedLl = transposeLinearLayout(ll, order); + resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl)); + return success(); + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa(operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + auto mmaAEncoding = + mlir::dyn_cast_or_null(aEncoding.getParent()); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // Using legacy layouts, a dst encoding that satisfies this property may not + // exist. Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // With linear layouts, we can always find a dst encoding that satisfies + // this property. See inferReshapeOpEncoding. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + Attribute &dstEnc) const { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return failure(); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Feature flag to disable this routine while it's relatively new. + // TODO(jlebar): Remove this once we're confident in the code. + if (triton::tools::getBoolEnv( + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) { + return failure(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || + !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + return failure(); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return failure(); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return failure(); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return failure(); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CTALayout can be all 1's because we bailed on multi-CTA layouts above. + auto CTALayout = CTALayoutAttr::get( + src.getContext(), + /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), + /*CTASplitNum=*/SmallVector(dstShape.size(), 1), + /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + + dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, + dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CTALayout); + + return success(); + } + + LogicalResult + verifyLayoutsAreEqual(ArrayRef shape, Attribute expected, + Attribute got, + std::optional loc) const override { + if (expected == got) { + return success(); + } + if (!expected || !got) + return failure(); + + // Check whether the encodings are structurally the same. + if (!areLayoutsEquivalent(shape, expected, got)) { + return emitOptionalError(loc, "Expected result encoding ", expected, + " but was ", got); + } + return success(); + } + + LogicalResult + inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + if (product(srcShape) != product(dstShape)) { + return emitOptionalError(loc, "numel of dst shape does not match " + "numel of src shape"); + } + auto result = + inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); + if (succeeded(result)) { + return result; + } + // If the legacy encoding failed use LinearLayouts. + // Once LinearLayouts are more widely used, we can remove + // inferReshapeOpLegacyEncoding and simply use LLs. + LinearLayout ll = inferReshapeLinearLayout(srcShape, srcEnc, dstShape); + + dstEnc = LinearEncodingAttr::get(srcEnc.getContext(), ll); + return success(); + } + + LogicalResult + inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + if (auto enc = mlir::dyn_cast(srcEnc)) { + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is the fastest running + // dimension. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMajorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), append(enc.getSizePerThread(), 2), + append(enc.getThreadsPerWarp(), 1), append(enc.getWarpsPerCTA(), 1), + appendMajorDim(enc.getOrder()), + CTALayoutAttr::get(enc.getContext(), append(enc.getCTAsPerCGA(), 1), + append(enc.getCTASplitNum(), 1), + appendMajorDim(enc.getCTAOrder()))); + return success(); + } + + auto ctx = getContext(); + + // Append dim to shape + auto ll = toLinearLayout(shape, srcEnc); + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.push_back(1); + ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + + // Try join on last dim + auto axis = dstShape.size() - 1; + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc); + + assert(result.succeeded()); + dstEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + ArrayRef shape, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (enc) { + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be the fastest running dimension. The result + // encoding is the same as the input, but with the last dimension removed. + if (enc.getSizePerThread().back() != 2) { + return emitOptionalError( + loc, "SplitOp requires 2 elements per thread in the " + "last dimension of the input"); + } + if (enc.getThreadsPerWarp().back() != 1 || + enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " + "and CTAsPerCGA = 1 for the last dimension of the input"); + } + if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, + "SplitOp requires the last dimension to be most-minor in CTAOrder"); + } + SmallVector newOrder(enc.getOrder()); + int splitDim = newOrder.size() - 1; + // Remove splitDim from order. + newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim), + newOrder.end()); + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder), + CTALayoutAttr::get(enc.getContext(), // + ArrayRef(enc.getCTAsPerCGA()).drop_back(1), + ArrayRef(enc.getCTASplitNum()).drop_back(1), + ArrayRef(enc.getCTAOrder()).drop_front(1))); + return success(); + } + + auto axis = shape.size() - 1; + if (shape[axis] != 2) { + return emitOptionalError( + loc, "SplitOp input shape should have 2 in the last dim"); + } + + auto ctx = getContext(); + + // Split on last dim + auto ll = toLinearLayout(shape, srcEnc); + auto newLl = LinearLayout::empty(); + auto result = + tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc); + if (!result.succeeded()) { + return failure(); + } + + // Remove last dim from newLl (which should be 1) + SmallVector dstShape(shape.begin(), shape.end()); + dstShape.pop_back(); + newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape)); + dstEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } + + LogicalResult + inferFp4ToFpOpEncoding(ArrayRef shape, int axis, Attribute inEnc, + Attribute &outEnc, bool fwdInference, + std::optional loc) const override { + // We implement two legacy layout propagations + // Once we fully migrate to LinearLayouts, we can remove these. + auto *ctx = getContext(); + // The output encoding will only be a legacy encoding if the axis is the + // fastest running dimension. + // FIXME: We should make sure that there are enough elements along the axis + // axis whenever fwdInference is false + if (getOrder(cast(inEnc), shape)[axis] == 0) { + // Dot operand: double kWidth if kDim == axis. + if (auto dotEnc = mlir::dyn_cast(inEnc)) { + auto kWidth = dotEnc.getKWidth(); + if (fwdInference) { + kWidth *= 2; + } else { + if (kWidth > 1) { + // bwd inference + kWidth /= 2; + } else { + return emitOptionalError(loc, + "Fp4ToFpOp requires at least 2 elements " + "per thread in the axis dimension"); + } + } + outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(), + dotEnc.getParent(), kWidth); + return success(); + } + + // Blocked layout: double elemsPerThread[axis]. + if (auto blockedEnc = mlir::dyn_cast(inEnc)) { + auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread()); + if (fwdInference) { + sizePerThread[axis] *= 2; + } else { + if (sizePerThread[axis] > 1) { + sizePerThread[axis] /= 2; + } else { + return emitOptionalError( + loc, "Fp4ToFpOp requires at least 2 elements per " + "thread in the axis dimension"); + } + } + outEnc = BlockedEncodingAttr::get( + ctx, sizePerThread, blockedEnc.getThreadsPerWarp(), + blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(), + blockedEnc.getCTALayout()); + return success(); + } + } + + auto ll = toLinearLayout(shape, inEnc); + auto newLl = LinearLayout::empty(); + auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc); + if (!result.succeeded()) + return result; + outEnc = LinearEncodingAttr::get(ctx, newLl); + return success(); + } +}; + +struct TritonGPUVerifyTensorLayoutInterface + : public triton::DialectVerifyTensorLayoutInterface { + using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface; + + LogicalResult verifyTensorLayout( + Attribute layout, RankedTensorType rankedTy, Operation *op, + function_ref makeErr) const override { + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + ModuleOp module = op->getParentOfType(); + + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + if (blocked.getRank() != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << blocked.getRank() + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + std::optional moduleWarpsPerCTA = maybeLookupNumWarps(op); + if (!moduleWarpsPerCTA) { + return makeErr() + << "Could not determine the number of warps per CTA. Operation " + "is not in a context with `ttg.num-warps`."; + } + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != *moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the context requires " + << *moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Layout debug printing +//===----------------------------------------------------------------------===// + +// Return N-D delinearized indices from a linear index. +static SmallVector delinearizeIndex(int64_t idx, + ArrayRef shape) { + SmallVector ret(shape.size()); + for (int i = shape.size() - 1; i >= 0; i--) { + ret[i] = idx % shape[i]; + idx /= shape[i]; + } + return ret; +} + +// Returns how many padding characters are needed for the string representation +// of value to be the same as max. +static int numCharacterPadding(int value, int max) { + return std::to_string(max).size() - std::to_string(value).size(); +} + +// return the string padded to have the same length as max. +static std::string paddedString(int value, int max) { + int nbChar = numCharacterPadding(value, max); + std::string str; + for (int i = 0; i < nbChar; i++) + str += " "; + str += std::to_string(value); + return str; +} + +std::string getSharedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); + + StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + int64_t tensorSize = product(tensorType.getShape()); + unsigned numBlocks = getNumCTAs(layout); + int32_t blockSize = tensorSize / numBlocks; + + // elementMapping is for the non-hw layout, offsetMapping for hw-layout + std::vector elementMapping(tensorSize); + std::vector offsetMapping; + + // Shared layouts are a mapping of (block, offset) --> (...) + + // We can just use a single int to index into elementMapping because + // the 'swizzle' operation rearranges the indices---and we want to keep it + // that way + int32_t idx = 0; + // Enumerate all the offsets for each block + for (int32_t block = 0; block < numBlocks; block++) { + for (int32_t offset = 0; offset < blockSize; offset++) { + SmallVector> inputs = { + {kBlock, block}, + {kOffset, offset}, + }; + + SmallVector> outputs = ll.apply(inputs); + + std::string sharedInfo = "("; + std::string &value = elementMapping[idx]; + + if (!value.empty()) + value += "|"; + + value += "("; + // We can build up both strings (for hw/non-hw layouts) concurrently + for (int i = 0; i < outputs.size(); i++) { + // Based on the formatting from LinearLayout::toString, the format for + // the hw layout is slightly different. HW layouts use "," vs ":". + if (i > 0) { + sharedInfo += ","; + value += ":"; + } + auto index = paddedString(outputs[i].second, tensorType.getDimSize(i)); + sharedInfo += index; + value += index; + } + value += ")"; + sharedInfo += ")"; + + offsetMapping.push_back(sharedInfo); + + idx++; + } + } + + std::string layoutStr; + + if (!useHWPointOfView) { + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ","; + } + } + } else { + // For the HW view here, print the (block, offset) --> (r,c) mapping + uint32_t idx = 0; + for (int32_t block = 0; block < numBlocks; block++) { + layoutStr += "Block: " + std::to_string(block) + ":\n"; + for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) { + layoutStr += "Offset: " + std::to_string(offset) + " -> "; + layoutStr += offsetMapping[idx]; + layoutStr += "\n"; + idx++; + } + } + } + + return layoutStr; +} + +std::string getDistributedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register"); + StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane"); + StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); + int64_t tensorSize = product(tensorType.getShape()); + std::vector elementMapping(tensorSize); + std::vector threadMapping; + unsigned threadsPerWarp = ll.getInDimSize(kLane); + unsigned numWarpsPerCTA = ll.getInDimSize(kWarp); + unsigned numBlocks = ll.getInDimSize(kBlock); + int numElementsPerThreads = ll.getInDimSize(kRegister); + for (int blockId = 0; blockId < numBlocks; ++blockId) { + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + SmallVector> inputs = { + {kBlock, blockId}, + {kWarp, warpId}, + {kLane, tid}, + {kRegister, idx}}; + SmallVector> outputs = + ll.apply(inputs); + int32_t linearizedIdx = 0; + int stride = 1; + for (int i = outputs.size() - 1; i >= 0; i--) { + linearizedIdx += outputs[i].second * stride; + stride *= tensorType.getDimSize(i); + } + std::string &value = elementMapping[linearizedIdx]; + if (!value.empty()) + value += "|"; + int padding = numCharacterPadding(blockId, numBlocks) + + numCharacterPadding(tid + warpId * threadsPerWarp, + numWarpsPerCTA * threadsPerWarp) + + numCharacterPadding(idx, numElementsPerThreads); + for (int i = 0; i < padding; i++) + value += " "; + if (numBlocks > 1) + value += "B" + std::to_string(blockId) + ":"; + value += "T" + std::to_string(tid + warpId * threadsPerWarp) + ":" + + std::to_string(idx); + // Now also compute the thread mapping. + std::string threadInfo = "("; + for (int i = 0; i < outputs.size(); i++) { + if (i > 0) + threadInfo += ","; + threadInfo += + paddedString(outputs[i].second, tensorType.getDimSize(i)); + } + threadInfo += ")"; + threadMapping.push_back(threadInfo); + } + } + } + } + std::string layoutStr; + if (!useHWPointOfView) { + // Printing the threads containing each elements of the tensor. + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ", "; + } + } + } else { + // Printing the elements in each physical reg/warps/threads. + for (int blockId = 0; blockId < numBlocks; blockId++) { + if (numBlocks > 1) + layoutStr += "Block" + std::to_string(blockId) + ":\n"; + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + layoutStr += "Warp" + std::to_string(warpId) + ":\n"; + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + int linearizedIdx = + blockId * numWarpsPerCTA * threadsPerWarp * + numElementsPerThreads + + warpId * threadsPerWarp * numElementsPerThreads + + tid * numElementsPerThreads + idx; + layoutStr += threadMapping[linearizedIdx]; + if (tid < threadsPerWarp - 1) + layoutStr += ", "; + } + layoutStr += "\n"; + } + } + } + } + return layoutStr; +} + +template +llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef s) { + auto rank = s.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(s); + return {1, s[0], s[1]}; +} + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +template llvm::SmallVector +mlir::triton::gpu::expandMatrixShapeWithBatch( + llvm::ArrayRef s); + +llvm::SmallVector +mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { + int rank = o.size(); + assert(rank == 2 || rank == 3); + if (rank == 3) + return llvm::SmallVector(o); + llvm::SmallVector expanded(3, 0); + for (int i = 0; i < rank; ++i) + expanded[i] += o[i] + 1; + return expanded; +} + +std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + + // tensorType is needed later on (e.g., getDimSize(j)), so we still have to + // pass it as a param + if (mlir::isa(layout)) { + return getSharedLayoutStr(tensorType, useHWPointOfView); + } else if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return getDistributedLayoutStr(tensorType, useHWPointOfView); + } + + // else unimplemented, return error + llvm::report_fatal_error("Unimplemented usage of getLayoutStr"); + return ""; +} + +void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); +} + +void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/true); +} + +namespace { +struct TensorModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public triton::gpu::TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; +} // namespace + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); + addInterfaces(); + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Verify that dialect attributes are attached to the right ops. + if (llvm::is_contained( + {AttrNumCTAsName, AttrTargetName, AttrNumThreadsPerWarp}, + attr.getName()) && + !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() << " which is expected only on `module` ops"; + } + if (attr.getName() == AttrNumWarpsName && !isa(op)) { + return op->emitOpError("has unexpected attribute ") + << attr.getName() + << " which is expected only on `module` or `tt.func` ops"; + } + + return success(); +} + +int TritonGPUDialect::getNumCTAs(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumCTAsName)) + return attr.getInt(); + return 1; +} + +int TritonGPUDialect::getThreadsPerWarp(ModuleOp module) { + if (auto attr = module->getAttrOfType(AttrNumThreadsPerWarp)) + return attr.getInt(); + return 32; +} + +std::optional triton::gpu::maybeLookupNumWarps(Operation *op) { + if (isa(op)) { + if (auto attr = op->getAttrOfType(AttrNumWarpsName)) + return attr.getInt(); + } else if (auto partitions = + dyn_cast(op->getParentOp())) { + unsigned idx = op->getParentRegion()->getRegionNumber(); + return partitions.getParentOp().getPartitionNumWarps()[idx]; + } + if (Operation *parent = op->getParentOp()) + return maybeLookupNumWarps(parent); + return {}; +} + +int triton::gpu::lookupNumWarps(Operation *op) { + std::optional numWarps = maybeLookupNumWarps(op); + if (!numWarps) { + op->emitOpError( + "is not contained within a context that specifies the number of warps"); + llvm::report_fatal_error("failed to lookup the number of warps, the " + "surrounding module should contain a " + + Twine(AttrNumWarpsName) + " attribute"); + } + return *numWarps; +} + +int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) { + assert(rewriter.getInsertionBlock() && "expected an insertion point"); + Operation *op = rewriter.getInsertionBlock()->getParentOp(); + while (op && !isa(op)) + op = op->getParentOp(); + assert(op && "cannot create thread ID outside of module"); + return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast(op)); +} + +bool triton::gpu::areLayoutsEquivalent(ArrayRef shape, Attribute lhs, + Attribute rhs) { + auto lhsLL = triton::gpu::toLinearLayout(shape, lhs); + auto rhsLL = triton::gpu::toLinearLayout(shape, rhs); + return lhsLL == rhsLL; +} + +bool triton::gpu::isInnermostContiguous(MemDescType type, unsigned numElems) { + ArrayRef shape = type.getShape(); + Attribute enc = type.getEncoding(); + MLIRContext *ctx = enc.getContext(); + + LinearLayout actual = toLinearLayout(shape, enc); + StringAttr fastestIn = *actual.getInDimNames().begin(); + + // Flatten actual outs in reverse order to produce a row-major flattening + // of the layout + auto outNames = actual.getOutDimNames(); + SmallVector revOut(outNames.begin(), outNames.end()); + std::reverse(revOut.begin(), revOut.end()); + actual = actual.transposeOuts(revOut).flattenOuts(); + + return actual.getNumConsecutiveInOut() >= numElems; +} + +LinearLayout triton::gpu::inferReshapeLinearLayout(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape) { + auto *ctx = srcEnc.getContext(); + auto src = toLinearLayout(srcShape, srcEnc); + assert(product(srcShape) == product(dstShape)); + auto dst = reshapeLayout(ctx, src, dstShape); + return dst; +} diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 000000000..eacded9b7 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,1958 @@ +#include + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +using mlir::triton::ScaleDotElemType; + +using mlir::triton::ScaleDotElemType; + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one block, i.e. input dims [register, lane, warp] +// for register layouts, and input dims [offset] for shared layouts. +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. +// +// Note that this is inconsistent with the type name CTALayoutAttr. That type +// is equivalent to our cgaLayout. +// +// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway, +// then we'd have to rename ctaLayout to "warpLayout". I think that's more +// confusing than being inconsistent about "cgaLayout", especially when we have +// to consider the size of the warpLayout (surely that's not the "warpSize"). + +#define S(v) StringAttr::get(ctx, (v)) + +SmallVector getDefaultMmaOrder(MmaEncodingTrait layout) { + auto rank = layout.getRepOrderForOperand(0).size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +// TODO Have order be a mandatory argument of standardOutDimNames. +SmallVector permuteDimNames(const SmallVector &names, + const SmallVector &order) { + assert(names.size() == order.size()); + SmallVector ret; + for (unsigned i : order) { + ret.push_back(names[i]); + } + return ret; +} + +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the file for an explanation of why +// this is called makeCgaLayout when it accesunrise a CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. + int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of the file for why the variable with +// type CTALayoutAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CTALayoutAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(llvm::to_vector(ctaLayout.getOutDimNames()) == + llvm::to_vector(cgaLayout.getOutDimNames())); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout swizzledSharedToLinearLayout(ArrayRef shape, + SwizzledSharedEncodingAttr shared) { + MLIRContext *ctx = shared.getContext(); + + auto shapePerCTA = getShapePerCTA(shared, shape); + + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shapePerCTA[colDim]; + int numRows = shapePerCTA[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int col = 1; col < numCols; col *= 2) { + bases2D.push_back({0, col}); + } + for (int row = 1; row < numRows; row *= 2) { + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= LinearLayout::identity1D(shapePerCTA[dim], S("offset"), + outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); +} + +LinearLayout +sharedToLinearLayoutAMDRotating(ArrayRef shape, + AMDRotatingSharedEncodingAttr shared) { + MLIRContext *ctx = shared.getContext(); + + auto shapePerCTA = getShapePerCTA(shared, shape); + + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shape[colDim]; + int numRows = shape[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int col = 1; col < numCols; col *= 2) { + bases2D.push_back({0, col}); + } + for (int row = 1; row < numRows; row *= 2) { + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + + int phase = (row / perPhase) % maxPhase; + int blockNo = row / maxPhase / perPhase % maxPhase; + int combinedPhase = phase ^ blockNo; + bases2D.push_back({row, (vec * combinedPhase) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); +} + +// Returns the layout of a single core matrix which tiles the nvmma layout +LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, + bool disableSwizzle) { + auto *ctx = shared.getContext(); + + int elemBitWidth = shared.getElementBitWidth(); + int tileWidthBytes = shared.getSwizzlingByteWidth(); + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + + int tileRows = 8; + int tileCols = 8 * tileWidthBytes / elemBitWidth; + bool isFp4Padded = shared.getFp4Padded(); + + std::vector> bases2D; + for (int col = 1; col < tileCols; col *= 2) { + if (isFp4Padded) { + // Each group of 16 offsets consists of 8 "real" and 8 "padded" offsets. + // We represent the padded layout by mapping 8 padded offsets to the same + // coordinates as the real ones. When computing the inverse of this LL, + // the offsets correspoding to the real ones are picked in the image by + // invertAndCompose. + int colPacked = col / 16 * 8 + col % 8; + bases2D.push_back({0, colPacked}); + } else { + bases2D.push_back({0, col}); + } + } + for (int row = 1; row < tileRows; row *= 2) { + if (disableSwizzle) { + bases2D.push_back({row, 0}); + } else if (isFp4Padded) { + int colPadded = vec * ((row / perPhase) % maxPhase); + int colPacked = colPadded / 16 * 8 + colPadded % 8; + bases2D.push_back({row, colPacked}); + } else { + bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); + } + } + auto outDimNames = standardOutDimNames(ctx, 2); + return LinearLayout({{S("offset"), bases2D}}, outDimNames); +} + +} // namespace + +LinearLayout nvmmaSharedToLinearLayout(ArrayRef shape, + NVMMASharedEncodingAttr shared, + bool disableSwizzle) { + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + auto shapePerCTA = getShapePerCTA(shared, shape); + auto kOffset = S("offset"); + auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA, + /*packedSize=*/true); + if (shared.getSwizzlingByteWidth() == 0) { + auto outDimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset, + outDimNames[rank - 1]); + for (int i = rank - 2; i >= 0; --i) { + layout *= LinearLayout::identity1D(tmaShape[i], kOffset, outDimNames[i]); + } + layout = ensureLayoutNotSmallerThan(layout, outDimNames, shapePerCTA); + return combineCtaCgaWithShape(layout, shared.getCTALayout(), shape); + } + assert(rank >= 2); + + // Collapse all the outer dim into one. We will then create a layout for this + // shape and reshape it to the original shape. + std::array collapsedTmaShape{1, tmaShape.back()}; + for (int i = 0; i + 1 < rank; i++) + collapsedTmaShape[0] *= tmaShape[i]; + if (shared.getTransposed()) { + std::swap(collapsedTmaShape[0], collapsedTmaShape[1]); + } + + auto tileLayout = getCoreMatrixLinearLayout(shared, disableSwizzle); + auto outDimNames = standardOutDimNames(ctx, 2); + auto kRow = outDimNames[0]; + auto kCol = outDimNames[1]; + auto tileRows = tileLayout.getOutDimSize(kRow); + auto tileCols = tileLayout.getOutDimSize(kCol); + + int packingFactor = shared.getFp4Padded() ? 2 : 1; + if (collapsedTmaShape[1] * packingFactor < tileCols || + collapsedTmaShape[0] < tileRows) { + llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to " + "be at least [" + << tileRows << ", " << (tileCols / packingFactor) + << "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", " + << collapsedTmaShape[1] << "]\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + // Distribute the remaining rows and cols. + auto layout = + ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape); + + // Reshape the layout to the N-D pre-transposed shape per CTA. + SmallVector maybeTransposedTmaShape = tmaShape; + if (shared.getTransposed()) { + // Move the outer dim to the inner position. + // TODO: we should move back to using `order` instead of transposed to make + // the order more explicit. + std::rotate(maybeTransposedTmaShape.begin(), + maybeTransposedTmaShape.begin() + 1, + maybeTransposedTmaShape.end()); + } + auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape); + + if (shared.getTransposed()) { + SmallVector order = {rank - 1}; + for (int i = 0; i < rank - 1; i++) { + order.push_back(i); + } + reshapedLayout = transposeLinearLayout(reshapedLayout, order); + } + + reshapedLayout = ensureLayoutNotSmallerThan( + reshapedLayout, standardOutDimNames(ctx, shapePerCTA.size()), + shapePerCTA); + return combineCtaCgaWithShape(reshapedLayout, shared.getCTALayout(), shape); +} + +/// Function to generate lane and warp layout for dot operands. +static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, + ArrayRef shape, + ArrayRef order, + unsigned kDim, + StringAttr inDimName) { + // Let warpsPerCTAMma = {2, 2}, then + // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB + // assume warpOrder = {1, 0} + // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that + // the C is owned as per the following layout: + // C: 0 | 1 + // - | - + // 2 | 3 + // In order to be able to compute C, we need the following warp tiling of + // A and B: + // A: 0 1 | 0 1 B: 0 2 | 1 3 + // - - | - - - - | - - + // 2 3 | 2 3 0 2 | 1 3 + // In other words, we need to broadcast along K + auto rank = shape.size(); + auto dimNames = standardOutDimNames(ctx, rank); + LinearLayout layout = LinearLayout::empty(); + + // We have to broadcast along the inner dimension + // For A, when moving along M we go from 0 to 2. + // For B, when moving along N we go from 0 to 1. + // As such, choosing the order of A {1, 0}, gives us the correct broadcasting + // Same happens if the warpOrder is {0, 1}, like in Hopper + for (auto d : order) { + if (d == kDim) { + layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]); + } else { + layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]); + } + } + return layout; +} + +LinearLayout +AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getRank()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + assert(((getMDim() == 32 && getNDim() == 32) || + (getMDim() == 16 && getNDim() == 16)) && + "Unsupported mfma type"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector order = getDefaultMmaOrder(*this); + auto tileLayout = LinearLayout::empty(); + + if (getMDim() == 32) { + // For mfma with 32x32 output, each of the 64 threads holds 16 elements. + // + // For the register (i.e., element) dimension, these 16 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows + // and then the next 4 rows being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 32 consecutive threads covering a whole + // row and the next 32 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } else { + assert(getMDim() == 16); + // For mfma with 16x16 output, each of the 64 threads holds 4 elements. + // + // For the register (i.e., element) dimension, these 4 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // And each warp takes the same register and lane sub-layout. So multiply with + // an identity layout for the warp. + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), order); + LinearLayout ctaLayout = tileLayout * warpLayout; + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape, + int32_t elemBitWidth) { + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + auto mDim = mfmaLayout.getMDim(); + assert(mDim == 16 || mDim == 32); + assert(elemBitWidth == 16 || elemBitWidth == 8); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int32_t kWidthDot = dotMfmaLayout.getKWidth(); + // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit + // loads for most element sizes (16b, 8b, 4b). + const int32_t ldsReadWidth = 64; + int32_t kWidthTransRead = ldsReadWidth / elemBitWidth; + const int elemByteWidth = elemBitWidth / 8; + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // Regular dot mfma order for both cases is [k, nonk]/[k, nonk, batch] + // For LDS transpose layout swap order to [nonk, k]/[nonk, k, batch] + SmallVector order = + getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ false); + + // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes) + // of data. The smallest unit for transposition is a + // [non-K, K] = {16, kWidthTransRead} sub-tile of elements, + // where each thread reads kWidthTransRead elements along the non-K dimension. + // Due to the transposition mechanism, each thread ends up with + // kWidthTransRead elements along the K dimension. + // + // The MFMA selection logic prioritizes double-rate MFMA instructions whenever + // possible: + // + // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k + // is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice. + // + // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is + // selected; otherwise (blockK ≤ k), mfma32x32xk is used. + // + // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA + // instructions are used. + // + // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead + // elements along the K dimension: + // - The first kWidthTransRead elements belong to the first sub-tile. + // - The next kWidthTransRead elements belong to the second sub-tile. + // + // These elements are then grouped into larger tiles, each consisting of + // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data + // for one MFMA instruction. The shape of these tiles depends on the MFMA + // instruction used. + // + // For single-rate MFMA instructions, each thread holds kWidthTransRead + // elements along the K dimension. This means that the larger tile + // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead} + // sub-tiles. + std::vector> registerBase; + std::vector> laneBase; + + // Populate register base for first subtile + for (int i = 1; i < kWidthTransRead; i *= 2) { + registerBase.push_back({i, 0}); + } + + const int threadsPerSubtileNonK = 16 / kWidthTransRead; + const int threadsPerSubtileK = kWidthTransRead; + + // Populate lane base for first subtile + for (int i = 1; i < threadsPerSubtileNonK; i *= 2) { + laneBase.push_back({i * kWidthTransRead, 0}); + } + for (int i = 1; i < threadsPerSubtileK; i *= 2) { + laneBase.push_back({0, i}); + } + + // Function to extend register base for multiple tiles K dim. + auto extendRegisterBaseForKDim = [&](int kTileSize, int numSubtilesPerTile) { + const int regsPerTile = kWidthTransRead * numSubtilesPerTile; + int totalRegs = (kSize / kTileSize) * regsPerTile; + + for (int reg = regsPerTile; reg < totalRegs; reg *= 2) { + registerBase.push_back({0, (reg / regsPerTile) * kTileSize}); + } + }; + + const bool isMfma32 = (mDim == 32); + const bool isMfma16 = (mDim == 16); + + // kDoubleTileSize is the k dimension of a tile when double rated + // mfma instructions are used. + const int kDoubleTileSize = + isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth; + // kTileSize is the actually k dimention of a tile, which is + // determined by kWidthDot. + const int kTileSize = kWidthDot * 64 / mDim; + // We use kDoubleTileSize as a reference to check whether the given + // kWidthDot leads to double or single sub-tiles in each tile. + const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1; + + // Extend register base for large K sizes. + if (numSubtilesPerTile == 2) + registerBase.push_back({0, threadsPerSubtileK}); // Second subtile + + extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile); + + // Extend lane base based on MFMA size. + std::vector> laneBaseExt; + + if (isMfma32) { + laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}}; + } else { + laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK}, + {0, 2 * numSubtilesPerTile * threadsPerSubtileK}}; + } + + laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end()); + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions we associate with register + // `order` which is also [nonk, k] given we set kContig to false. + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + auto warpOrder = getDefaultMmaOrder(mfmaLayout); + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); +} + +LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + + int32_t kWidth = dotMfmaLayout.getKWidth(); + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + auto order = + getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ true); + + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + auto warpOrder = getDefaultMmaOrder(mfmaLayout); + + // Lane holds kWidth consecutive elements along k dimension, so + // base register vectors for one tile are initialized in following way: + // {1, 0}, {2, 0} ... {kWidth/2, 0} + std::vector> registerBase; + for (int32_t elem = 1; elem < kWidth; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + std::vector> laneBase; + int32_t kTileSize = -1; + + if (mfmaLayout.getMDim() == 32) { + // Canonical MFMA linear layout handles 4 consecutive elements along + // the register dimension. Dot operand handles variable kWidth consecutive + // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2, + // 32}, this means that mapping of first 5 base (up to thread 16) vectors + // will be an identity along N dim. Thread 32 will be mapped to element + // kWidth in K dimension. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}}; + kTileSize = kWidth * 2; + } else { + assert(mfmaLayout.getMDim() == 16); + // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this + // means that mapping of first 4 base (up to thread 16) vectors will be an + // identity along N dim. Thread 16 will be mapped to element kWisth in K + // dimension. Thread 32 is mapped to element 2*kWidth in K dim. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}}; + kTileSize = kWidth * 4; + } + assert(kTileSize != -1); + // Add repeats of registers along K dimension to register base vectors + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + // Base vectors above are defined in a fixed order [k-dim, non-k-dim]. + // To assign them to actual matrix dimensions we assoicate with register + // `order` which is also also [k, nonk]. + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); +} + +LinearLayout +AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getRank()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + SmallVector mnkDim = getMNKDimPerInstr(); + unsigned mDim = mnkDim[0], nDim = mnkDim[1]; + (void)mDim, (void)nDim; + + assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) && + (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && + "Unsupported tensor shape for given wmma layout"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + auto threadOrder = getMatrixOrder(rank, /*rowMajor*/ !getIsTransposed()); + assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex); + assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex); + + // For wmma with 16x16 output, each of the 32 threads holds 8 elements. + // + // The first version of WMMA layout has following specific: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive elements + // spanning 1 row and then the next 1 row being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start at the next row. + // + // The second version of wmma layout is less tricky: + // for the register dimension 8 elements are along the matrix C's M + // dimension. First 16 lanes take 0-8 elems along M, second 16 take 8-15. + // We have 16 pair of threads in each warp, one pair covers the whole + // column. + // + // Please also check explaining comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned ver = getVersion(); + assert(ver == 1 || ver == 2); + LinearLayout tileLayout = + ver == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]}); + + if (hasBatchDim) { + int batchIndex = 0; + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= + LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]); + } + + // And each warp takes the same register and lane sub-layout. So multiply with + // an identity layout for the warp. + auto warpOrder = getDefaultMmaOrder(*this); + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder); + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = getRepOrder(); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, + ArrayRef shape) { + auto wmmaLayout = llvm::cast(dotWmmaLayout.getParent()); + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + auto kDim = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; + MLIRContext *ctx = dotWmmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + // lane order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + auto laneOrder = + getOrderForDotOperand(dotWmmaLayout.getOpIdx(), rank, /*kContig*/ true); + // generate continuous part of register bases(i.e. kWidth) + std::vector> registerBase; + const int32_t kWidth = dotWmmaLayout.getKWidth(); + for (int i = 1; i < kWidth; i *= 2) + registerBase.push_back(std::vector{i, 0}); + std::vector> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}}; + switch (wmmaLayout.getVersion()) { + case 1: + // WMMA version 1 duplicates values in lanes 0-15 and 16-31 + laneBase.push_back({0, 0}); + break; + case 2: + // WMMA version 2 offset values in lanes 0-15 and 16-31 across k dimensions + laneBase.push_back({kWidth, 0}); + break; + default: + assert(false && "unexpected version"); + } + // Generate layout for one wmma instruction + LinearLayout tileLayout( + {{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]}); + if (hasBatchDim) { + assert(laneOrder[2] == 0); + // Extend the base vector with one value to accommodate for the batch + // dimension, which appears at the last. + tileLayout *= + LinearLayout::identity1D(1, kRegister, outDimNames[laneOrder[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[laneOrder[2]]); + } + + // Generate warp layout + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto warpOrder = getDefaultMmaOrder(wmmaLayout); + LinearLayout warpLayout = + broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp")); + + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = wmmaLayout.getRepOrderForOperand(dotWmmaLayout.getOpIdx()); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + + // join instruction layout and warps using repetition order of dimensions + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape); +} + +LinearLayout sunrisemmaDotOperandToLinearLayout(DotOperandEncodingAttr dotEncAttr, ArrayRef shape) { + MLIRContext *ctx = dotEncAttr.getContext(); + auto rank = shape.size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + unsigned dotOpIdx = dotEncAttr.getOpIdx(); + SunriseMmaEncodingAttr sunriseMmaAttr = cast(dotEncAttr.getParent()); + // auto order = getOrderForDotOperand(dotOpIdx, rank, /*kContig*/ true); //???是否应该写死[1,0] + SmallVector order({1,0}); + unsigned elemBitWidth = sunriseMmaAttr.getInputElemBitWidth(); + auto tileLayout = LinearLayout::empty(); + switch(elemBitWidth) { + case 32: + if(dotOpIdx == 0) { + tileLayout = LinearLayout( + {{S("register"), {}}, + {S("lane"), {{1,0}, {2,0}, {0,1}, {0,2}, {0,4}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } else { + tileLayout = LinearLayout( + {{S("register"), {}}, + {S("lane"), {{1,0}, {2,0}, {4,0}, {0,1}, {0,2}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } + break; + case 16: + if(dotOpIdx == 0) { + tileLayout = LinearLayout( + //{{S("register"), {{1,0}, {8,0}}}, // 有么有8,0好像都行? + {{S("register"), {{1,0}}}, + {S("lane"), {{2,0}, {4,0}, {0,1}, {0,2}, {0,4}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } else { + tileLayout = LinearLayout( + {{S("register"), {{1,0}}}, + {S("lane"), {{2,0}, {4,0}, {0,1}, {0,2}, {0,4}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } + break; + case 8: + if(dotOpIdx == 0) { + tileLayout = LinearLayout( + {{S("register"), {{1,0}, {2,0}}}, + {S("lane"), {{4,0}, {8,0}, {0,1}, {0,2}, {0,4}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } else { + tileLayout = LinearLayout( + {{S("register"), {{1,0}, {2,0}}}, + {S("lane"), {{4,0}, {0,1}, {0,2}, {0,4}, {0,8}}} + }, {outDimNames[order[0]], outDimNames[order[1]]} + ); + } + break; + case 4: + default: + llvm::report_fatal_error("linearlayout not implemented!"); + break; + } + + auto kDim = dotOpIdx == 0 ? rank - 1 : rank - 2; + auto warpOrder = getDefaultMmaOrder(sunriseMmaAttr); + // SmallVector warpOrder = dotOpIdx == 0 ? SmallVector({1,0}) : SmallVector({0,1}); + + // LinearLayout warpLayout = identityStandardND(S("warp"), sunriseMmaAttr.getWarpsPerCTA(), warpOrder); + LinearLayout warpLayout = broadcastedDotOperandLayout(ctx, sunriseMmaAttr.getWarpsPerCTA(), warpOrder, kDim, S("warp")); + + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = sunriseMmaAttr.getRepOrderForOperand(dotOpIdx); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + + // join instruction layout and warps using repetition order of dimensions + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + return combineCtaCgaWithShape(ctaLayout, sunriseMmaAttr.getCTALayout(), shape); +} + +LinearLayout SunriseMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == 2); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + // SmallVector order = getDefaultMmaOrder(*this); + // SmallVector order = getOrderForDotOperand(dotWmmaLayout.getOpIdx(), rank, /*kContig*/ true); + SmallVector order = SmallVector({1,0}); + + SunriseMmaEncodingAttr::TMMAOutLayout outLayout = getOutLayout(); + if ((outLayout != SunriseMmaEncodingAttr::TMMAOutLayout::Row_2B) && (outLayout != SunriseMmaEncodingAttr::TMMAOutLayout::ARow_4B_8x4)) { + assert(0 && "Unsupport SunriseMmaEncodingAttr::TMMAOutLayout yet"); + } + auto tileLayout = LinearLayout::empty(); + if(outLayout == SunriseMmaEncodingAttr::TMMAOutLayout::Row_2B) { + // fp16 row + tileLayout = LinearLayout( + {{S("register"), {{1, 0}}}, + //{{S("register"), {}}, + {S("lane"), {{2,0}, {4,0}, {0,1}, {0,2}, {0,4}}} + }, + {outDimNames[order[0]], outDimNames[order[1]]} + ); + } else if(outLayout == SunriseMmaEncodingAttr::TMMAOutLayout::ARow_4B_8x4) { + // fp32 8x4 + tileLayout = LinearLayout( + {{S("register"), {{4, 0}}}, + {S("lane"), {{1,0}, {2,0}, {0,1}, {0,2}, {0,4}}} + }, + {outDimNames[order[0]], outDimNames[order[1]]} + ); + } else { + llvm::report_fatal_error("Unsupport sunrisemma outLayout"); + } + LinearLayout warpLayout = identityStandardND(S("warp"), getWarpsPerCTA(), order); + // reorder dim names in rep order, so combineCtaCgaWithShape generate proper + // extension of layout + auto repOrder = getRepOrder(); + SmallVector repDimNames; + for (auto dim : repOrder) + repDimNames.push_back(outDimNames[dim]); + + // join instruction layout and warps using repetition order of dimensions + LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) * + warpLayout.transposeOuts(repDimNames); + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout +BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); + auto order = getOrder(); + LinearLayout ctaLayout = + identityStandardND(S("register"), getSizePerThread(), order) * + identityStandardND(S("lane"), getThreadsPerWarp(), order) * + identityStandardND(S("warp"), getWarpsPerCTA(), order); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, + ArrayRef shape) { + int rank = shape.size(); + auto blocked = cast(operandLayout.getParent()); + MLIRContext *ctx = operandLayout.getContext(); + + // TODO: introduce registerOrder or use getDefaultOrder(operandLayout) + // Currently this order is used in legacy converter, because we do not + // have access to full dot operand layout, only parent part. + auto regOrder = blocked.getOrder(); + auto threadOrder = blocked.getOrder(); + auto warpOrder = blocked.getOrder(); + auto repOrder = blocked.getRepOrder(); + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + auto threadSize = llvm::to_vector(blocked.getSizePerThread()); + auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + threadSize[kDimIdx] = shape[kDimIdx]; + auto threadShape = blocked.getThreadsPerWarp(); + auto warpShape = blocked.getWarpsPerCTA(); + + SmallVector repDimNames = + permuteDimNames(standardOutDimNames(ctx, rank), repOrder); + + auto registersLayout = identityStandardND(kReg, threadSize, regOrder); + auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder, + kDimIdx, kLane); + auto warpsLayout = + broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp); + + LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) * + lanesLayout.transposeOuts(repDimNames) * + warpsLayout.transposeOuts(repDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape); +} + +LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, + unsigned kWidth, ArrayRef order, + ArrayRef repOrder) { + // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder + // Like LinearLayout::empty() but with a rank and an order + int rank = repOrder.size(); + auto dimNames = standardOutDimNames(ctx, rank); + auto trivialShape = SmallVector(rank, 1); + LinearLayout ctaLayout = + identityStandardND(S("register"), trivialShape, repOrder); + + assert(rank >= 2); + auto inner = order[0]; + auto outer = order[1]; + + assert(tileShape.size() == rank); + int m = tileShape[outer]; + int n = tileShape[inner]; + + // The relative order of registers and lanes is given by: + // - Inner dim: kWidth registers + // - Inner dim: 4 lanes + // - Outer dim: 8 lanes + // - Outer dim: repeat m / 8 times + // - Inner dim: repeat n / (kWidth * 4) times + assert(m % 8 == 0); + assert(n % (kWidth * 4) == 0); + // There is at least one subtile on the inner-most dimension + // FIXME. We should implement operator* in terms of operator*= + // and chain *= instead of using * + auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames()); + ctaLayout = ctaLayout * + LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) * + LinearLayout::identity1D(4, S("lane"), dimNames[inner]) * + LinearLayout::identity1D(8, S("lane"), dimNames[outer]) * + LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) * + LinearLayout::identity1D(n / (kWidth * 4), S("register"), + dimNames[inner]); + return ctaLayout; +} + +LinearLayout +NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ctx = getContext(); + int rank = shape.size(); + assert(rank == getRank()); + + SmallVector tileShape; + if (isAmpere()) { + // Ampere.getInstrShape() returns the tile shape + tileShape = SmallVector(getInstrShape()); + } else { + assert(isHopper()); + auto instrShapeMNK = getInstrShape(); + tileShape = SmallVector({instrShapeMNK[0], instrShapeMNK[1]}); + } + // nvidiamma layout always assumes kWidth = 2 + constexpr auto kWidth = 2; + auto order = getDefaultMmaOrder(*this); + auto ctaLayout = nvidiaMmaTile(ctx, tileShape, kWidth, order, getRepOrder()); + + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ !isHopper()); + ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + MLIRContext *ctx = mma.getContext(); + + SmallVector tileShape(rank, 1); + if (isA) { + tileShape[rank - 2] = 16; + tileShape[rank - 1] = kWidth * 8; + } else { + // Hopper takes the rhs via shared memory + assert(mma.isAmpere()); + tileShape[rank - 2] = kWidth * 8; + tileShape[rank - 1] = 8; + } + auto order = getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig*/ true); + auto ctaLayout = + nvidiaMmaTile(ctx, tileShape, kWidth, order, dot.getRepOrder()); + auto kDim = isA ? rank - 1 : rank - 2; + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ !mma.isHopper()); + ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(), warpOrder, + kDim, S("warp")) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); +} + +LinearLayout +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto parent = getParent(); + if (auto blockedLayout = mlir::dyn_cast(parent)) { + return fmaDotToLinearLayout(*this, shape); + } else if (auto mfmaLayout = mlir::dyn_cast(parent)) { + return mfmaDotToLinearLayout(*this, shape); + } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { + return wmmaDotOperandToLinearLayout(*this, shape); + } else if (auto sunrisemmaLayout = mlir::dyn_cast(parent)){ + return sunrisemmaDotOperandToLinearLayout(*this, shape); + } else { + auto mma = mlir::cast(parent); + return nvidiaDotToLinearLayout(shape, *this); + } +} + +LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + getDim(), 1); + LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); + + // Remove dimension getDim() from the parent layout. + // + // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims + // that removes the relevant out-dim. + // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps + // from parent in-dims to slice out-dims. + // 3. Fix up duplicate registers introduced by slicing. + auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); + LinearLayout transform = LinearLayout::empty(); + for (auto [idx, outDim] : llvm::enumerate(parentLL.getOutDimNames())) { + if (idx == getDim()) { + // Because we're multiplying by all zeros, we could replace outDimNames[0] + // with any other valid out-dim; the layout will be the same. + transform *= LinearLayout::zeros1D(parentLL.getOutDimSize(outDim), outDim, + outDimNames[0]); + } else { + transform *= + LinearLayout::identity1D(parentLL.getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < getDim() ? 0 : 1)]); + } + } + LinearLayout sliceLL = parentLL.compose(transform); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + return LinearLayout(std::move(bases), + llvm::to_vector(sliceLL.getOutDimNames())); +} + +LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef shape, + Attribute layout) { + CacheKey key{std::vector(shape.begin(), shape.end()), layout}; + if (auto result = llCache.get(key)) { + return *result; + } + + // Layouts are distributed or shared in triton core + // To add a new layout add an else-if clause + LinearLayout result = LinearLayout::empty(); + if (auto distributed = dyn_cast(layout)) { + result = distributed.toLinearLayout(shape); + } else { + if (auto shared = dyn_cast(layout)) { + result = swizzledSharedToLinearLayout(shape, shared); + } else if (auto shared = dyn_cast(layout)) { + result = nvmmaSharedToLinearLayout(shape, shared); + } else if (auto sbl = dyn_cast(layout)) { + result = sharedToLinearLayoutAMDRotating(shape, sbl); + } else { + assert(0 && "unknown layout"); + } + } + + llCache.set(std::move(key), result); + return result; +} + +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout) { + auto *ctx = layout.getContext(); + return ctx->getLoadedDialect()->toLinearLayout(shape, + layout); +} + +LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { + assert(!layout.getInDimNames().empty()); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + + StringAttr kBlock = S("block"); + assert(layout.hasInDim(kBlock)); + auto bases = layout.getBases(); + bases[kBlock] = {}; + return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames())); +} + +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order) { + auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); + LinearLayout layout = LinearLayout::empty(); + SmallVector kRepDims; + SmallVector kOffsetDims; + auto totalIters = 1; + auto totalOffsets = 1; + for (int i = 0; i < tensorShape.size(); i++) { + int dim = order[i]; + StringAttr kIteration = S("iteration" + std::to_string(dim)); + StringAttr kOffset = S("offset" + std::to_string(dim)); + kRepDims.push_back(kIteration); + kOffsetDims.push_back(kOffset); + assert(llvm::isPowerOf2_32(repShape[dim])); + assert(llvm::isPowerOf2_32(tensorShape[dim])); + auto numIters = tensorShape[dim] / repShape[dim]; + layout *= + LinearLayout::identity1D(repShape[dim], kOffset, outDimNames[dim]); + layout *= LinearLayout::identity1D(numIters, kIteration, outDimNames[dim]); + totalIters *= numIters; + totalOffsets *= repShape[dim]; + } + StringAttr kOffset = S("offset"); + StringAttr kIteration = S("iteration"); + StringAttr kBlock = S("block"); + SmallVector newDims; + newDims.append(kOffsetDims.begin(), kOffsetDims.end()); + newDims.append(kRepDims.begin(), kRepDims.end()); + // Transpose layout from [offset0, rep0, offset1, rep1, ...] to + // [offset0, offset1, ..., rep0, rep1, ...] + auto ret = layout.transposeIns(newDims); + // Reshape layout from [offset0, offset1, ..., rep0, rep1, ...] to + // [offset, rep, block] + return ret.reshapeIns( + {{kOffset, totalOffsets}, {kIteration, totalIters}, {kBlock, 1}}); +} + +namespace { +LinearLayout chooseStMatrixLayoutNVMMA(MLIRContext *ctx, + RankedTensorType tensorTy, + int swizzleByteSize) { + int perPhase; + int maxPhase; + if (swizzleByteSize == 32) { + perPhase = 4; + maxPhase = 2; + } else if (swizzleByteSize == 64) { + perPhase = 2; + maxPhase = 4; + } else if (swizzleByteSize == 128) { + perPhase = 1; + maxPhase = 8; + } else { + llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n"; + llvm::report_fatal_error("Illegal swizzleByteSize"); + } + + // stmatrix only supports 16-bit elements, and each vector has 8 elements + int elemBitWidth = 16; + int vecSize = 8; + int numRowsPerTile = 16; + int numColsPerChunk = 8 * swizzleByteSize / elemBitWidth; + + // Construct a single stmatrix.x4 (16x16) tile + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane; + for (int row = 1; row < numRowsPerTile; row *= 2) { + basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row}); + } + basesLane.push_back({8, 0}); + + auto mma = cast(tensorTy.getEncoding()); + assert(mma.getVersionMajor() >= 3 && "Only MMAv3 is supported"); + int instrM = mma.getInstrShape()[0]; + int instrN = mma.getInstrShape()[1]; + + // TODO(Keren): The following logic can be simplified by using the + // `divideLeft` function in `LinearLayout` once it's available. + // Construct the bases for a single chunk + // In theory the following situation is valid but it will be + // suboptimal. Swizzling should happen within a warp. + assert(instrN >= numColsPerChunk && + "Each chunk is filled in with a single warp"); + for (int col = 1; col < numColsPerChunk / 16; col *= 2) { + basesReg.push_back({16 * col, 0}); + } + + // Construct the bases for warpsPerCTA[0] + std::vector> basesWarp; + auto warpsPerCTA = mma.getWarpsPerCTA(); + auto shapePerCTA = getShapePerCTA(tensorTy); + for (int warp = 1; warp < warpsPerCTA[0]; warp *= 2) { + basesWarp.push_back({0, warp * instrM}); + } + + // Expand the `register` dimension so the size of columns matches `shape[1] / + // warpsPerCTA[1]` + auto numColsPerWarp = std::max(instrN, shapePerCTA[1] / warpsPerCTA[1]); + assert(warpsPerCTA[1] * instrN >= shapePerCTA[1] && + "There must be enough columns to use MMAv3"); + auto numCols = numColsPerWarp / numColsPerChunk; + for (int col = 1; col < numCols; col *= 2) { + int basis = col * shapePerCTA[0]; + basesReg.push_back({0, basis}); + } + + // Expand the `register` dimension so that the size of rows matches `shape[0]` + assert(warpsPerCTA[0] * instrM <= shapePerCTA[0] && + "There must be enough rows to use MMAv3"); + auto numRows = shapePerCTA[0] / (warpsPerCTA[0] * instrM); + for (int row = 1; row < numRows; row *= 2) { + int basis = row * warpsPerCTA[0] * instrM; + basesReg.push_back({0, basis}); + } + + // Expand the `warp` dimension so that the size of cols matches `shape[1]` + for (int warp = 1; warp < warpsPerCTA[1]; warp *= 2) { + if (warp * numColsPerWarp >= shapePerCTA[1]) { + basesWarp.push_back({0, 0}); + } else { + int basis = (warp * numColsPerWarp) / numColsPerChunk * shapePerCTA[0]; + basesWarp.push_back({0, basis}); + } + } + + auto layout = LinearLayout({{S("register"), basesReg}, + {S("lane"), basesLane}, + {S("warp"), basesWarp}, + {S("block"), {}}}, + {S("offset1"), S("offset0")}); + return layout.reshapeOuts( + {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +LinearLayout chooseStMatrixLayoutSwizzled(MLIRContext *ctx, Attribute encoding, + ArrayRef shape) { + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kCol = S("dim1"); + StringAttr kRow = S("dim0"); + StringAttr kBlock = S("block"); + + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane = { + {0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}; + LinearLayout layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); + + // Expand the `register` dimension so the size of columns matches `n`. + auto mma = cast(encoding); + int n = mma.getInstrShape()[1]; + layout *= + LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol); + + // Expand the `warp` dimension according to warpsPerCTA. + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape); + auto tensorShapePerCTA = getShapePerCTA(mma, shape); + llvm::SmallDenseMap namedTensorShape; + namedTensorShape[kRow] = tensorShapePerCTA[0]; + namedTensorShape[kCol] = tensorShapePerCTA[1]; + ret = ensureLayoutNotSmallerThan(ret, namedTensorShape); + ret = ensureLayoutNotLargerThan(ret, namedTensorShape); + return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) + .reshapeOuts( + {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot, + ArrayRef shape, bool needTrans, + int32_t elemBitWidth) { + auto ctx = dot.getContext(); + auto mma = cast(dot.getParent()); + auto rank = shape.size(); + auto opIdx = dot.getOpIdx(); + int kDim = (opIdx == 0) ? rank - 1 : rank - 2; + int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1; + + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kBlock = S("block"); + StringAttr kInner = opIdx == 0 ? (needTrans ? S("dim0") : S("dim1")) + : (needTrans ? S("dim1") : S("dim0")); + StringAttr kOuter = opIdx == 0 ? (needTrans ? S("dim1") : S("dim0")) + : (needTrans ? S("dim0") : S("dim1")); + + std::vector> basesReg; + for (int reg = 1; reg < 8 * 16 / elemBitWidth; reg *= 2) { + basesReg.push_back({0, reg}); + } + std::vector> basesLane = { + {1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}}; + bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth; + bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth; + bool nonKX2 = shape[nonKDim] > 8; + // Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix + // efficiently. opIdx=0 and opIdx=1 are handled differently. + if (opIdx == 0) { + // The matrix elements of thread 0 are distributed in the following pattern + // (fp16): + // + // col0 col8 + // row0 reg[0-1] reg[4-5] + // row8 reg[2-3] reg[6-7] + if (needTrans) { + assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are " + "supported in the transposed mode"); + if (nonKX2) + basesLane[3] = {0, 8}; + if (kX2) + basesLane[4] = {8 * 16 / elemBitWidth, 0}; + } else { + if (nonKX2) + basesLane[3] = {8, 0}; + if (kX2) + basesLane[4] = {0, 8 * 16 / elemBitWidth}; + } + } else { + // The matrix elements of thread 0 are distributed in the following pattern + // (fp16): + // + // col0 col8 col16 col24 + // row0 reg[0-1] reg[2-3] reg[4-5] reg[6-7] + if (needTrans) { + assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are " + "supported in the transposed mode"); + if (kX2) + basesLane[3] = {8, 0}; + if (kX4) + basesLane[4] = {16, 0}; + } else { + if (kX2) + basesLane[3] = {0, 8 * 16 / elemBitWidth}; + if (kX4) + basesLane[4] = {0, 16 * 16 / elemBitWidth}; + } + } + int numTileCols = + (8 * 16 / elemBitWidth) + << (static_cast(kX2) + static_cast(kX4 && opIdx == 1)); + // Expand the `register` dimension so the size of columns matches `K`. + auto layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, + {kOuter, kInner}) * + LinearLayout::identity1D(shape[kDim] / numTileCols, kReg, + S("dim" + std::to_string(kDim))); + // Expand the `warp` dimension according to warpsPerCTA. + auto warpsPerCTA = mma.getWarpsPerCTA(); + auto warpOrder = getMatrixOrder(rank, /*rowMajor*/ !mma.isHopper()); + layout *= + broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, kWarp) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + return combineCtaCgaWithShape(layout, getCTALayout(dot), shape); +} + +} // anonymous namespace + +LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + int swizzleByteSize) { + if (swizzleByteSize == 0) + return chooseStMatrixLayoutSwizzled(ctx, tensorTy.getEncoding(), + tensorTy.getShape()); + else + return chooseStMatrixLayoutNVMMA(ctx, tensorTy, swizzleByteSize); +} + +LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef shape, + bool needTrans, int32_t elemBitWidth) { + auto dot = cast(enc); + return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth); +} + +LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef shape, + int32_t elemBitWidth) { + auto dot = cast(enc); + return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth); +} + +LinearLayout chooseScaledMfmaScaleLayout( + MLIRContext *ctx, int dotOperandIdx, + const std::vector> &dotOperandWarpBasis, + ArrayRef dotOperandShape, unsigned mfmaMDim) { + using basisT = std::vector>; + unsigned rank = dotOperandShape.size(); + auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true); + auto standardOutDims = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + // Init register layout. Will be adjusted later + auto regs = mlir::triton::identityStandardND(kRegister, {1, 1}, order); + LinearLayout lanes = LinearLayout::empty(); + // In scaled dot, the shapes of operands(without batch dimension) are, + // respectively: + // - A: [M, K] + // - B: [K, N] + // - aScale: [M, K / 32] + // - bScale: [N, K / 32] + // + // To correctly feed A/B and its scale into instruction, we need to + // distribute aScale/bScale among warps in the same way as A/B. But bScale + // is not transposed like B. So we need to transpose the warp layout of + // bScale. + // + // The tricky part is, our desired outputs are [dim0, dim1], but + // at this position, the layouts are transposed to [dim1, dim0]. So + // instead of reverse bScale's layout, we need to reverse aScale's. There + // will be a transpose in the end to correct everything. + basisT warps = dotOperandWarpBasis; + if (dotOperandIdx == 0) { + for (auto &basis : warps) { + std::reverse(basis.begin(), basis.end()); + } + } + // In general, for both 32x32 and 16x16 scaled mfma, and no matter what + // data type the A/B operand is, each lane takes 32 elements from A/B + // alone K dim, and 1 or 2 elements from scale accordingly. The number of + // scale's elements in a lane varies because the 32 elements from A/B may + // not be consecutive. + // + // For mxfp4, these 32 elements are consecutive, so only 1 scale element + // is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements + // blocks, so 2 scale elements are required. + if (mfmaMDim == 32) { + // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes + // collectively handle A[0:32][32:64]. Each lane take 1 scale element + // accordingly. Similar to B and bScale. + lanes = LinearLayout( + {{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}}, + {kWarp, warps}, + {kBlock, {}}}, + {standardOutDims[order[0]], standardOutDims[order[1]]}); + } else { + assert(mfmaMDim == 16); + // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane + // takes 32 consecutive elements from A alone K dimension. The first + // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes + // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale + // element accordingly. Similar to B and bScale. + lanes = + LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {kWarp, warps}, + {kBlock, {}}}, + {standardOutDims[order[0]], standardOutDims[order[1]]}); + } + LinearLayout newLL = regs * lanes; + + // Adjust register-level layout to fill the shape, at this level, both + // aScale and bScale should align with A operand. + SmallVector repOrder = {1, 0}; + for (auto d : repOrder) { + auto outDim = standardOutDims[d]; + auto dimSize = newLL.getOutDimSize(outDim); + newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister, + outDim); + } + newLL = newLL.transposeOuts(standardOutDims); + return newLL; +} + +std::optional +chooseMfmaLikeStoreLayout(RankedTensorType valType) { + // TODO: WMMA Support on RDNA + if (!isa(valType.getEncoding())) + return {}; + auto mfmaLayout = cast(valType.getEncoding()); + + // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on + // CDNA4. + bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; + bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16; + + auto valShape = valType.getShape(); + // For mfma16x16, to use in-wavefront swap, we need to make sure the tiles + // used are in one wavefront if there are multiple tiles, which means + // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For + // now, it is only possible for FA-like kernels since during mfma generation, + // the WarpsPerCTA of the head dot in the chain will be reshaped to [numWaprs, + // 1]. + // TODO: For gemm-like kernel, the transformation here cannot be applied for + // now and will support it. + bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 && + mfmaLayout.getWarpsPerCTA().back() == 1; + + Type elemType = valType.getElementType(); + if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && + mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() && + (isMfma32 || validForMfma16))) + return {}; + + LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape); + auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames()); + StringAttr dimM = mfmaOutDims[0]; + StringAttr dimN = mfmaOutDims[1]; + auto swapLL = LinearLayout::empty(); + // The rows are kept as is with an identity linear layout. + swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM); + /* + clang-format off + In transposed mfma32 layout, Each thread holds 4 consecutive values along N + dim. We want to exchange column 4-7 (owned by thread 32-63, BLK0) and column + 8-11 (owned by thread 0-31, BLK1) every 16 columns to make each thread holds 8 + elements. This would mean exchange the 2nd and 3rd basis vector from an + identity linear layout on tensor elements. + + Correspondingly, the transposed mfma16 layout, the output of + transposed of mfma16x16 is: + + N/register + M/Lane v0 v1 v2 v3 v4 v5 v6 v7 + ------------------------------------------------------------------------- + row0: 0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + which means: + The columns from v0 to v3 are in the one output of mfma16x16 and + the columns from v4 to v7 are in the one output of mfma16x16, + + The following graph is the same as the one above, execept the tile number is replaced with coordinates in the tenor, + N/register + ----------------------------------------------- + M/lane |(0, 0) ... (0, 3) | (0, 16) ... (0, 19) | + |.... | sub-tensor-0 | + |(15, 0) ... (15, 3) | (15, 16) ... (15, 19) | + ----------------------------------------------- + |(0, 4) ... (0, 7) | (0, 20) ... (0, 23) | + |sub-tensor-1 | .... | + |(15, 0) ... (15, 3) | (15, 20) ... (15, 23) | + ----------------------------------------------- + |(0, 8) ... (0, 11)| (0, 24) ... (0, 27) | + |.... | sub-tensor-2 | + |(15, 8) ... (15, 11)| (15, 24) ... (15, 27) | + ----------------------------------------------- + |(0, 12) ... (0, 15)| (0, 28) ... (0, 31) | + |sub-tensor-3 | .... | + |(15, 12) ... (15, 15)| (15, 28) ... (15, 31) | + ----------------------------------------------- + The basis vector for lane and register are: + Register = {{0, 1}, {0, 2}} + Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}} + With this layout, only 4xfp16 can be packed in the final global store. + + To use 128-bits global store, we need to pack 8 elements, which means the layout looks like: + N/register + M/Lane v0 v1 v2 v3 v4 v5 v6 v7 + ------------------------------------------------------------------------- + row0: 0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | + ------------------------------------------------------------------------- + row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | + ------------------------------------------------------------------------- + row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | + ------------------------------------------------------------------------- + + The following graph is the same as the one above, execept the tile number is replaced with coordinates in the tenor: + N/register + ----------------------------------------------- + |(0, 0) ... (0, 3) | (0, 4) ... (0, 7) | + |.... | sub-tensor-1 | + |(15, 0) ... (15, 3) | (15, 16) ... (15, 19) | + ----------------------------------------------- + |(0, 16) ... (0, 19) | (0, 20) ... (0, 23) | + |sub-tensor-0 | .... | + |(15, 16) ... (15, 19)| (15, 20) ... (15, 23) | + ----------------------------------------------- + |(0, 8) ... (0, 11)| (0, 12) ... (0, 15) | + |.... | sub-tensor-3 | + |(15, 8) ... (15, 11)| (15, 12) ... (15, 15) | + ----------------------------------------------- + |(0, 24) ... (0, 27)| (0, 28) ... (0, 31) | + |sub-tensor-2 | .... | + |(15, 24) ... (15, 27)| (15, 28) ... (15, 31) | + ----------------------------------------------- + which means we need to exchange sub-tensor-0 with sub-tensor-1 and sub-tensor-2 and sub-tensor-3. + And basis vector for lane and register are: + Register = {{0, 1}, {0, 2}, {0, 4}} + Lane = {{1, 0}, {2, 0, [4, 0}, {8, 0}, {0, 16}, {0, 8}} + + The steps to get this layout are, firstly we check the last dim of WarpsPerCTA is 1, so we can use v_permlane16. + Then, we exchange the 2nd and 4th elements in the basis vector of an identity linear and then it will be composed with + the original mfma16 LL. + clang-format on + */ + auto destIdxInBases = isMfma32 ? 3 : 4; + std::vector> dimNBases(mfmaLL.getOutDimSizeLog2(dimN)); + std::generate(dimNBases.begin(), dimNBases.end(), + [i = 0]() mutable { return std::vector{1 << i++}; }); + std::swap(dimNBases[2], dimNBases[destIdxInBases]); + swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN}); + + return mfmaLL.compose(swapLL); +} + +LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, + int numWarps) { + assert(numWarps == 4 || numWarps == 8); + MLIRContext *ctx = scaleType.getContext(); + + using basisT = std::vector>; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + + int64_t M = scaleType.getDimSize(0); + int64_t N = scaleType.getDimSize(1); + auto CTALayout = getCTALayout(scaleType.getEncoding()); + basisT regBase; + + // Pick a layout that will be trivial to store into the following TMEM layout: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + // Pack 4 scales together, if there are less than 4 we replicate the data. + for (int i = 1; i < 4; i = i << 1) { + if (i >= N) + regBase.push_back({0, 0}); + else + regBase.push_back({0, i}); + } + // Distribute 32 elements of M along a warp. + basisT laneBase = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}; + // The data are replicated across all the warps of each warpgroups. + basisT warpBase = {{0, 0}, {0, 0}}; + for (int i = 32; i < M; i = i << 1) { + regBase.push_back({i, 0}); + } + for (int i = 4; i < N; i = i << 1) { + regBase.push_back({0, i}); + } + // If we have 8 warps distribute the last dimension on the second warp group. + if (numWarps == 8) { + warpBase.push_back(regBase.back()); + regBase.pop_back(); + } + + SmallVector outDimNames = standardOutDimNames(ctx, 2); + auto regLanes = + LinearLayout({{kRegister, regBase}, {kLane, laneBase}, {kWarp, warpBase}}, + {outDimNames[0], outDimNames[1]}); + + return combineCtaCgaWithShape(regLanes, CTALayout, scaleType.getShape()); +} + +std::optional +getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType, + int numWarps) { + // Too small to distribute on two warp groups while using 16x256 message. + if (numWarps == 8 && M == 64 && N <= 16 && + oldType.getElementTypeBitWidth() < 32) { + return {}; + } + assert(numWarps == 4 || numWarps == 8); + auto ctaLayout = getCTALayout(oldType.getEncoding()); + SmallVector shape = getShapePerCTA(oldType); + MLIRContext *ctx = ctaLayout.getContext(); + + using basisT = std::vector>; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + SmallVector outDimNames = standardOutDimNames(ctx, 2); + + unsigned numElementsPerThread = 256 / oldType.getElementTypeBitWidth(); + int kWidth = 64 / oldType.getElementTypeBitWidth(); + // Follow the layout given by a tmem load using this layout for the inner + // shape: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + LinearLayout innerTile = + nvidiaMmaTile(ctx, {8, numElementsPerThread}, kWidth, {1, 0}, {0, 1}); + innerTile = + innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]); + // Then distribute the rest along warpgroups and registers. + // Then the last warp distribute along M or N following the same order as + // in getTmemLoadStoreLayout32x32b. This allows us to use the same lowering to + // tmem for load and store. This part could be generalized by making the + // lowering of tmem load and store rely more on linear layout. + bool distributeMAlongWarps = false; + bool distributeNAlongWarps = false; + // Figure out how to distribute acorss warpgroups. + if (numWarps == 8) { + if (shape[0] > 128) { + distributeMAlongWarps = true; + } else { + distributeNAlongWarps = true; + } + } + int nBase = numElementsPerThread; + int maxRegN = + std::min(N, distributeNAlongWarps ? (int)shape[1] / 2 : (int)shape[1]); + if (maxRegN / nBase > 1) { + innerTile = innerTile * LinearLayout::identity1D(maxRegN / nBase, kRegister, + outDimNames[1]); + } + if (M != 64) { + innerTile = + innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]); + } + // Distribute M along 4 warps to satisfy TMEM requirements. + innerTile = innerTile * LinearLayout::identity1D(4, kWarp, outDimNames[0]); + + // Fill out the rest of the shape with M first then N. + int numMRegDim = std::min(128, (int)shape[0]) / M; + if (numMRegDim > 1) { + innerTile = innerTile * + LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]); + } + // Dim M=128 should be distributed on the second warp group. + int nextDim = 128; + if (distributeMAlongWarps) { + innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[0]); + nextDim <<= 1; + } + numMRegDim = shape[0] / nextDim; + if (numMRegDim > 1) { + innerTile = innerTile * + LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]); + } + int maxN = distributeNAlongWarps ? shape[1] / 2 : shape[1]; + int numNRegDim = maxN / maxRegN; + if (numNRegDim > 1) { + innerTile = innerTile * + LinearLayout::identity1D(numNRegDim, kRegister, outDimNames[1]); + } + if (distributeNAlongWarps) { + innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[1]); + } + return combineCtaCgaWithShape(innerTile, ctaLayout, oldType.getShape()); +} + +LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType, + int numWarps) { + assert(numWarps == 8); + auto ctaLayout = getCTALayout(oldType.getEncoding()); + SmallVector shape = getShapePerCTA(oldType); + MLIRContext *ctx = ctaLayout.getContext(); + + using basisT = std::vector>; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + + // Follow the layout given by a tmem load using this layout: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-1632b2 + basisT laneBase; + assert(M == 128); + for (int i = 1; i < 16; i = i << 1) { + laneBase.push_back({i, 0}); + } + basisT regBase; + for (int i = 1; i < N / 2; i = i << 1) { + regBase.push_back({0, i}); + } + laneBase.push_back({0, N / 2}); + // then replicate the pattern. + for (int i = N; i < shape[1]; i = i << 1) { + regBase.push_back({0, i}); + } + for (int i = M; i < shape[0]; i = i << 1) { + regBase.push_back({i, 0}); + } + // warp 0 and 4 can only access M[0:32], therefore we need to interleave the + // data. + basisT warpBase = {{32, 0}, {64, 0}, {16, 0}}; + SmallVector outDimNames = standardOutDimNames(ctx, 2); + auto regLanes = + LinearLayout({{kRegister, regBase}, {kLane, laneBase}, {kWarp, warpBase}}, + {outDimNames[0], outDimNames[1]}); + + return combineCtaCgaWithShape(regLanes, ctaLayout, oldType.getShape()); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..0418aefe3 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +include_directories(${PROJECT_SOURCE_DIR}/lib/Dialect/TritonGPU/Transforms/Pipeliner) + +add_triton_library(FlagTree_sunrise_TritonGPUTransforms + Pipeliner/PipeliningUtility.cpp + Coalesce.cpp + Prefetch.cpp + RemoveLayoutConversions.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen +) diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 000000000..605f5d141 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + if(order.size() == 2 && order[1] == 1) { + order = {1, 0}; + } + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + perThread = std::min( + perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(moduleOp); + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + int numWarps = lookupNumWarps(curr); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp new file mode 100644 index 000000000..a6ca73db2 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -0,0 +1,760 @@ +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "triton-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// Hoisting Utilities +//===----------------------------------------------------------------------===// + +bool triton::isPureScalarOp(Operation *op) { + auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); }; + return isPure(op) && llvm::all_of(op->getOperandTypes(), isScalar) && + llvm::all_of(op->getResultTypes(), isScalar); +} + +bool triton::getDominatingValueSetOpsToHoist( + DominanceInfo &domInfo, Operation *refOp, ArrayRef valueSet, + llvm::SetVector &toHoist, + function_ref canHoist) { + // The set of operations below `refOp` that are being checked if they can be + // hoisted. This set prevents checking operations twice but also if the + // computation can be hoisted, this becomes the set of operations to hoist. + llvm::SetVector visited; + + // Climb the use-def chain breadth-first so that operations can be hoisted in + // the reverse visitation order. + std::queue queue; + for (Value value : valueSet) + queue.push(value); + + while (!queue.empty()) { + Value value = queue.front(); + queue.pop(); + + // If the value properly dominates the outer loop, then it must be invariant + // to it. + if (domInfo.properlyDominates(value, refOp)) + continue; + // If the value is a block argument, it cannot be hoisted. + if (auto arg = dyn_cast(value)) + return false; + + Operation *op = value.getDefiningOp(); + // Check if the op was already visited. + if (visited.contains(op)) + continue; + // If the defining op cannot be hoisted, then the value cannot be made loop + // invariant. + if (!canHoist(op)) + return false; + visited.insert(op); + // Recurse on the operands of the op. + for (Value operand : op->getOperands()) + queue.push(operand); + } + + // The operations in `visited` must be hoisted. Note that operations are not + // added to `toHoist` unless all of `values` can be hoisted. This is to avoid + // hoisting operations for loops that don't end up getting fused if one of + // their bounds operands cannot be hoisted. + toHoist.insert(visited.begin(), visited.end()); + + return true; +} + +void triton::hoistOpsBefore(Operation *refOp, + const llvm::SetVector &toHoist) { + return hoistOpsBefore(refOp->getBlock(), refOp->getIterator(), toHoist); +} +void triton::hoistOpsBefore(Block *block, Block::iterator it, + const llvm::SetVector &toHoist) { + for (Operation *op : topologicalSort(toHoist)) { + op->moveBefore(block, it); + } +} + +//===----------------------------------------------------------------------===// +// Sinking Utilities +//===----------------------------------------------------------------------===// + +Value triton::sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out, + Block *block) { + OpBuilder::InsertionGuard guard(rewriter); + for (; block != in.getParentBlock(); + block = block->getParentOp()->getBlock()) { + Operation *op = block->getParentOp(); + rewriter.setInsertionPoint(op); + + // `in` is live into the loop body. `out` becomes the live-out if the + // loop executes at least once. + if (auto forOp = dyn_cast(op)) { + forOp = addIterArgsToLoop(rewriter, forOp, in); + appendToForOpYield(forOp, out); + out = forOp.getResults().back(); + continue; + } + + // `in` is live into both branches. `out` becomes the live-out if the + // particular branch is taken. + if (auto ifOp = dyn_cast(op)) { + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, out.getType()); + scf::YieldOp taken = newIfOp.thenYield(); + scf::YieldOp other = newIfOp.elseYield(); + if (block == newIfOp.elseBlock()) + std::swap(taken, other); + taken->insertOperands(taken.getNumOperands(), out); + other->insertOperands(other.getNumOperands(), in); + out = newIfOp.getResults().back(); + rewriter.eraseOp(ifOp); + continue; + } + + // TODO: Handle `scf.while`, etc. + llvm::report_fatal_error("FIXME: sinking into unhandled control flow op: " + + op->getName().getStringRef()); + } + + return out; +} + +//===----------------------------------------------------------------------===// +// Loop Pipelining Utilities +//===----------------------------------------------------------------------===// + +bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + }); +} + +bool mlir::triton::isOuterLoop(scf::ForOp forOp) { + return llvm::any_of(forOp.getBody()->getOperations(), [](Operation &op) { + return isa(op); + }); +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isConstantIntValue(pred, 1)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + bool hasOriMask = false; + if (op->hasAttr("sunrise.hasOriMask") == false || + mlir::cast(op->getAttr("sunrise.hasOriMask")).getInt() == 1) { + hasOriMask = true; + } + if (!hasOriMask) { + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + } + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + bool hasOriMask = false; + if (op->hasAttr("sunrise.hasOriMask") == false || + mlir::cast(op->getAttr("sunrise.hasOriMask")).getInt() == 1) { + hasOriMask = true; + } + if (!hasOriMask) { + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + } + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto gatherOp = dyn_cast(op)) { + rewriter.setInsertionPoint(gatherOp); + Value mask = getPredMask(rewriter, gatherOp.getPred().getType(), + gatherOp.getPred(), pred); + gatherOp.getPredMutable().assign(mask); + return op; + } + if (auto expectOp = dyn_cast(op)) { + rewriter.setInsertionPoint(expectOp); + Value mask = getPredMask(rewriter, expectOp.getPred().getType(), + expectOp.getPred(), pred); + expectOp.getPredMutable().assign(mask); + return op; + } + if (auto mmav5Op = dyn_cast(op)) { + rewriter.setInsertionPoint(mmav5Op); + auto currPred = mmav5Op.getPredicate(); + Value mask = getPredMask(rewriter, currPred.getType(), currPred, pred); + mmav5Op.setPredicate(mask); + return op; + } + if (auto tmemStoreOp = dyn_cast(op)) { + rewriter.setInsertionPoint(tmemStoreOp); + Value mask = getPredMask(rewriter, tmemStoreOp.getPred().getType(), + tmemStoreOp.getPred(), pred); + tmemStoreOp.getPredMutable().assign(mask); + return op; + } + if (auto waitBarrier = dyn_cast(op)) { + rewriter.setInsertionPoint(waitBarrier); + Value mask = pred; + Value currentPred = waitBarrier.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + waitBarrier.getPredMutable().assign(mask); + return op; + } + if (auto arriveBarrier = dyn_cast(op)) { + rewriter.setInsertionPoint(arriveBarrier); + Value mask = pred; + Value currentPred = arriveBarrier.getPred(); + if (currentPred) { + mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred); + } + arriveBarrier.getPredMutable().assign(mask); + return op; + } + if (auto storeOp = dyn_cast(op)) { + rewriter.setInsertionPoint(storeOp); + Value mask = getPredMask(rewriter, storeOp.getPtr().getType(), + storeOp.getMask(), pred); + storeOp.getMaskMutable().assign(mask); + return op; + } + if (auto atomicRMWOp = dyn_cast(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } + if (!op->isRegistered()) { + // Skip ops from unregistered dialects to make writing lit tests easier. + return op; + } + + op->emitOpError("pipeliner doesn't know how to predicate this op."); + llvm::report_fatal_error("Fatal pipeliner error"); + return op; +} + +// Return true if the given ForOp has the attribute +// `tt.disallow_acc_multi_buffer` set to true. +bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) { + return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName); +} + +std::pair +mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) { + int64_t distance = 0; + DenseSet seen; + while (auto arg = dyn_cast(value)) { + // Ignore implicit captures. + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + ++distance; + value = forOp.getYieldedValues()[arg.getArgNumber() - 1]; + if (!seen.insert(value).second) + return {nullptr, 0}; + } + return {cast(value), distance}; +} + +std::pair +mlir::triton::getDefiningOpAndDistance(scf::ForOp forOp, Value value) { + auto [definition, distance] = getDefinitionAndDistance(forOp, value); + return {definition ? definition.getDefiningOp() : nullptr, distance}; +} + +int mlir::triton::getCopyVecBytes(RankedTensorType registerTy, + ttg::SharedEncodingTrait sharedEnc) { + auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), + registerTy.getEncoding()); + auto sharedLayout = + triton::gpu::toLinearLayout(registerTy.getShape(), sharedEnc); + auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + const int vecElems = regToSharedLayout.getNumConsecutiveInOut(); + return vecElems * registerTy.getElementTypeBitWidth() / 8; +} + +bool mlir::triton::canBeConvertedToAsyncLoad( + tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return false; + auto ty = cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + return width >= 32; +} + +void mlir::triton::serializeLatencies(ModuleOp module, + DenseMap &opLatency) { + auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper(); + auto builder = Builder(module); + for (auto &[op, latency] : opLatency) { + helper.setAttr(op, builder.getI32IntegerAttr(latency)); + } +} + +void mlir::triton::serializeSelfLatencies( + ModuleOp module, DenseMap &opSelfLatency) { + auto helper = TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper(); + auto builder = Builder(module); + for (auto &[op, latency] : opSelfLatency) { + helper.setAttr(op, builder.getI32IntegerAttr(latency)); + } +} + +DenseMap mlir::triton::deserializeLatencies(Operation *op) { + DenseMap opLatency; + auto latencyHelper = TritonDialect::getLoaded(op)->getLatencyAttrHelper(); + op->walk([&](Operation *op) { + if (auto attr = latencyHelper.getAttr(op)) { + opLatency[op] = attr.getInt(); + latencyHelper.removeAttr(op); + } + }); + return opLatency; +} + +Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type, + unsigned numBuffers) { + MLIRContext *ctx = rewriter.getContext(); + unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs( + rewriter.getBlock()->getParentOp()->getParentOfType()); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(rewriter.getContext()); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(/*context=*/ctx, /*CTAsPerCGA=*/{numCTAs}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout); + ttg::MemDescType memDescType = ttg::MemDescType::get( + {numBuffers}, type, barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + return rewriter.create(memDescType, Value()); +} + +// Create an allocation and init the mbarriers. +Value mlir::triton::createBarrierAlloc(scf::ForOp forOp, int numBarriers, + int arriveCount) { + ImplicitLocOpBuilder rewriter(forOp.getLoc(), forOp); + + Value barrierAlloc = + createScalarAlloc(rewriter, rewriter.getI64Type(), numBarriers); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i); + rewriter.create(barrierView, arriveCount); + } + // Invalidate and deallocate the barriers. + rewriter.setInsertionPointAfter(forOp); + for (unsigned i = 0; i < numBarriers; i++) { + Value barrierView = createSingleBufferView(rewriter, barrierAlloc, i); + rewriter.create(barrierView); + } + rewriter.create(barrierAlloc); + return barrierAlloc; +} + +Value mlir::triton::createAlloc(Operation *insertBefore, RankedTensorType ty, + Location loc, + gpu::SharedEncodingTrait sharedEnc, + unsigned distance) { + OpBuilder builder(insertBefore); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(insertBefore->getContext()); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + Value alloc = builder.create(loc, memdescType); + + builder.setInsertionPointAfter(insertBefore); + builder.create(insertBefore->getLoc(), alloc); + return alloc; +} + +bool mlir::triton::isTMALoad(Operation *op) { + return isa(op); +} + +bool mlir::triton::canBeAsyncLoad(Operation *op) { + if (mlir::triton::isTMALoad(op)) { + return true; + } + assert(isa(op)); + ttg::SharedEncodingTrait sharedEncoding = mlir::triton::getSharedEncoding(op); + // Do not create async loads for small loads (cp.async requires at least 4 + // bytes) + int copyVecBytes = mlir::triton::getCopyVecBytes( + cast(op->getResultTypes()[0]), sharedEncoding); + if (copyVecBytes >= 4) { + return true; + } + return false; +} + +void mlir::triton::combineRedundantWaitOps( + llvm::SmallSetVector &waitOps) { + llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; + SmallVector waitGroup = {waitOp}; + SmallVector depTokens = waitOp.getOperands(); + unsigned minWaitNumber = waitOp.getNum(); + Operation *next = waitOp->getNextNode(); + while (next && !isa(next)) { + if (auto nextWait = dyn_cast(next)) { + waitGroup.push_back(nextWait); + minWaitNumber = std::min(minWaitNumber, nextWait.getNum()); + depTokens.append(nextWait.getOperands().begin(), + nextWait.getOperands().end()); + } + next = next->getNextNode(); + } + if (waitGroup.size() == 1) + continue; + OpBuilder builder(waitGroup.front()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); + for (auto waitOp : waitGroup) { + toDelete[waitOp] = newWaitOp; + } + } + for (auto waitOp : toDelete) { + waitOp.first->replaceAllUsesWith(waitOp.second); + waitOp.first->erase(); + } +} + +ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy) { + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(allocTy.getContext()); + return ttg::MemDescType::get(allocTy.getShape().drop_front(), + allocTy.getElementType(), allocTy.getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); +} + +ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) { + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto order = ttg::getOrder(ty); + // Use generic layout. This won't be optimal for 2D tensors. + return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); +} + +ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) { + // Try to use local alloc encoding if possible. + ttg::SharedEncodingTrait localAllocEnc; + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + })) { + for (auto user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) { + // Some users have different encoding than others. + // Use one of the encodings, and warn about the performance issue. + op->emitRemark() + << "Pipelining load with different use encodings. This will lead " + "to layout conversions and performance degradation."; + continue; + } + } + } + + auto ty = cast(op->getResultTypes()[0]); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto order = ttg::getOrder(ty); + if (isTMALoad(op)) { + // TMA encoding is set on the descriptor type + TypedValue desc; + if (auto load = dyn_cast(op)) { + desc = load.getDesc(); + } else if (auto gather = dyn_cast(op)) { + desc = gather.getDesc(); + } else { + op->emitError() << "unrecognized tma load type"; + llvm::report_fatal_error("unrecognized tma load type"); + } + return ttng::getEncodingFromDescriptor(op, ty, desc); + } + + if (localAllocEnc) + return localAllocEnc; + + // Try to use dot encoding if possible. + bool incompatible = false; + localAllocEnc = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + + if (localAllocEnc) + return localAllocEnc; + + // Use generic layout. This won't be optimal for 2D tensors. + return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); +} + +int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp, + int defaultNumStages) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + auto helper = TritonDialect::getLoaded(forOp)->getNumStagesAttrHelper(); + if (auto attr = helper.getAttr(forOp)) + return attr.getInt(); + return defaultNumStages; +} + +TypedValue +triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) { + assert(isa(alloc.getType()) && "Expected MemDescType"); + auto allocDescType = cast(alloc.getType()); + SmallVector shape; + if (allocDescType.getShape().size() > 1) { + shape.insert(shape.end(), allocDescType.getShape().begin() + 1, + allocDescType.getShape().end()); + } else { + shape.push_back(1); + } + auto viewDescType = ttg::MemDescType::get( + shape, allocDescType.getElementType(), allocDescType.getEncoding(), + allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), + /*allocShape=*/allocDescType.getAllocShape()); + SmallVector idxs = {idx}; + if (allocDescType.getShape().size() > 1) { + Value zero = builder.create(alloc.getLoc(), 0, 32); + for (unsigned i = 1; i < allocDescType.getShape().size(); i++) { + idxs.push_back(zero); + } + } + return builder.create(alloc.getLoc(), viewDescType, + alloc, idxs); +} + +TypedValue +triton::createSingleBufferView(OpBuilder &builder, Value alloc, int idx) { + Value idxVal = builder.create(alloc.getLoc(), idx, 32); + return createSingleBufferView(builder, alloc, idxVal); +} + +Value triton::createIncrementModulo(OpBuilder &builder, Location loc, + Value counter, Value modulus, Value zero, + Value one, Value *outWrapCond) { + Value addOne = builder.create(loc, counter, one); + Value outOfRangeCond = builder.create( + loc, arith::CmpIPredicate::sge, addOne, modulus); + if (outWrapCond) + *outWrapCond = outOfRangeCond; + return builder.create(loc, outOfRangeCond, zero, addOne); +} + +///////////////////////////// +// LOWER TMA DESCRIPTORS +///////////////////////////// + +static void +allocTMABuffers(scf::ForOp forOp, + llvm::MapVector &tmaBufferMapping, + int maxStage) { + IRRewriter rewriter(forOp); + + // Create a multi-buffered allocation for each MakeTensorDescOp call in the + // loop + forOp.walk([&](tt::MakeTensorDescOp op) { + // TODO peter: walk to loop yield to find the init value if this is a + // loop-carried value. That would save us from allocating another buffer + // just for the init value + auto loc = op.getLoc(); + Value alloc = rewriter.create( + loc, triton::getPointerType(rewriter.getI8Type()), + maxStage * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN); + tmaBufferMapping[op.getOperation()] = alloc; + }); +} + +static Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc, + Value counter) { + Value tmaSizeVal = + builder.create(loc, ttng::TMA_SIZE_BYTES, 32); + Value offset = builder.create(loc, tmaSizeVal, counter); + return builder.create(loc, alloc.getType(), alloc, offset); +} + +static LogicalResult rewriteTMABufferUpdates( + scf::ForOp forOp, + const llvm::MapVector &tmaBufferMapping, + ArrayRef tmaCounters, int numBuffers, Value one, Value zero, + triton::CoarseSchedule &schedule) { + assert(tmaBufferMapping.size() == tmaCounters.size()); + + Value numBuffersVal = mlir::OpBuilder(forOp).create( + forOp.getLoc(), numBuffers, 32); + + for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) { + auto &[op, alloc] = pair; + + // Rewriter MakeTensorDescOp as writing a TMA descriptor + auto makeDescOp = cast(op); + + triton::OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, + schedule); + + BlockArgument counter = tmaCounters[iOp]; + Value nextBuf = + subviewTMADescriptor(builder, builder.getLoc(), alloc, counter); + if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) { + return failure(); + } + builder.create(nextBuf); + Value nextDesc = builder.create( + makeDescOp.getType(), nextBuf); + + makeDescOp.getResult().replaceAllUsesWith(nextDesc); + + // Increment the buffer index counter + Value nextCounter = createIncrementModulo( + builder, builder.getLoc(), counter, numBuffersVal, zero, one); + + // If we are in a (potentially nested) if region, propagate the counter + // up to the main for op body scope + IRRewriter rewriter(forOp); + nextCounter = triton::sinkValueRedefinition(rewriter, counter, nextCounter, + op->getBlock()); + + // Finally, rewrite the loop level yield + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield.setOperand(counter.getArgNumber() - 1, nextCounter); + } + return success(); +} + +scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp, + CoarseSchedule &schedule) { + llvm::MapVector tmaBufferMapping; + int maxStage = schedule.getNumStages() - 1; + for (auto &op : forOp.getBody()->without_terminator()) { + if (auto wgMmaOp = dyn_cast(&op)) { + // Hopper only: Add one more buffer slice if there is a WarpGroupDotOp, + // as if it will be pipelined, we will effectively make the pipeline + // one stage longer. + maxStage += 1; + break; + } + } + allocTMABuffers(forOp, tmaBufferMapping, maxStage); + if (tmaBufferMapping.empty()) + return forOp; + + IRRewriter builder(forOp); + Location loc = forOp.getLoc(); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + SmallVector newOperands; + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Create one counter per TMA buffer. This allows the descriptors to be + // updated independently without needing to write duplicate of existing tma + // descriptors. + unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size(); + for (int i = 0; i < tmaBufferMapping.size(); ++i) { + newOperands.push_back(zero); + } + + forOp = addIterArgsToLoop(builder, forOp, newOperands); + + auto tmaCounters = ArrayRef(forOp.getBody()->getArguments()) + .slice(tmaCounterArgsStartIdx); + + // Update yield op with temporary yield values + auto forYield = cast(forOp.getBody()->getTerminator()); + for (unsigned i = 0; i < newOperands.size(); ++i) { + forYield.getResultsMutable().append(newOperands[i]); + } + + if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters, + maxStage, one, zero, schedule))) { + llvm_unreachable("Failed to rewrite TMA ops"); + } + return forOp; +} diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 000000000..4d3838ace --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,468 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.subview %a[0, 0] [128, 16] +// %a_prefetch = ttg.local_load %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_prefetch_arg, %b, %c +// %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] +// %a_prefetch_next = ttg.local_load %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-prefetch" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPREFETCH +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 32; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK = std::nullopt, + std::optional shapeK = std::nullopt); + + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[1], ret); + for (int i = 2; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + if (isa(curr.getType())) { + auto retType = RankedTensorType::get( + cast(ret.getType()).getShape(), + cast(curr.getType()).getElementType(), + cast(curr.getDefiningOp()->getOperand(0).getType()) + .getEncoding()); + curr.setType(retType); + } + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK, + std::optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + auto rank = shape.size(); + SmallVector offset(rank, 0); + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? rank - 1 : rank - 2; + + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + SmallVector offsetsVal; + for (int64_t off : offset) + offsetsVal.push_back( + builder.create(v.getLoc(), off, 32)); + Value newSmem = builder.create( + v.getLoc(), + triton::gpu::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()), + v, offsetsVal); + + LDBG("prolog newSmem: "<( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) { + // Only accepts dotOps encoded as Nvidia MMA v2 or AMD MFMA + auto dstMmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + auto dstMfmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + auto dstTmmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstTmmaEnc && !dstMfmaEnc && (!dstMmaEnc || dstMmaEnc.getVersionMajor() != 2)) + // Don't rewrite if any other type is found. + return failure(); + dotsInFor.push_back(dotOp); + } + + if (dotsInFor.empty()) + return failure(); + + // TODO: segfault (original for still has uses) + // when used in flash attention that has 2 dots in the loop + if (dotsInFor.size() > 1) + return failure(); + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + LDBG("Prefetch src: " << *op); + while (op) { + if (op->getNumOperands() != 1) { + break; + } + if (!op->getResult(0).hasOneUse()) { + break; + } + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + // NYI for other encodings, for example if we have transpose + // in the chain + if (isa(cvt.getType().getEncoding())) + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + if (op) + LDBG("op: " << *op); + } + std::reverse(rets.begin(), rets.end()); + if (foundConvertFromShared) + return rets; + return {}; + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = mlir::dyn_cast(v)) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getTiedLoopInit(arg)->get(); + return Value(); + }; + + auto getYieldOperand = [this](Value v) -> Value { + auto arg = mlir::cast(v); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + auto aType = dot.getA().getType(); + auto bType = dot.getB().getType(); + auto aEnc = + mlir::cast(aType.getEncoding()); + auto bEnc = + mlir::cast(bType.getEncoding()); + int aKWidth = aEnc.getKWidth(); + int bKWidth = bEnc.getKWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape().back(); + + // works better with nvidia tensor cores + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; + + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + LDBG("aVals.size:"< loopArgs; + for (auto v : forOp.getInitArgs()) + loopArgs.push_back(v); + for (triton::DotOp dot : dots) { + loopArgs.push_back(operand2headPrefetch[dot.getA()]); + loopArgs.push_back(operand2headPrefetch[dot.getB()]); + } + + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), loopArgs); + + builder.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + // The insertion point should be placed before the yield op + auto setInsertionPointBeforeYield = [](OpBuilder &builder, + scf::ForOp newForOp) { + if (newForOp.getBody()->mightHaveTerminator()) { + builder.setInsertionPoint(newForOp.getBody()->getTerminator()); + } else { + builder.setInsertionPointToEnd(newForOp.getBody()); + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + // If we're currently trying to sink a prefetched dot, we need to stop + // sinking it (by resetting the insertion point to the end) if we find + // control flow, or anything that depends on the dot op. + if (op.getNumRegions() > 0) { + setInsertionPointBeforeYield(builder, newForOp); + } + for (auto operand : op.getOperands()) { + if (auto def = operand.getDefiningOp()) { + auto dot = dyn_cast(def); + if (dot && dots.contains(dot)) { + setInsertionPointBeforeYield(builder, newForOp); + } + } + } + Operation *newOp = builder.clone(op, mapping); + auto dot = dyn_cast(&op); + if (dot && dots.contains(dot)) { + Attribute dotEncoding = dot.getType().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.getA())) + firstDot->setOperand( + 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.getB())) + firstDot->setOperand( + 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.getA().getType().getShape().back() - prefetchWidth; + Operation *prevDot = firstDot; + if (kRem == 0) { + // There is only one dot while prefetchWidth == kSize so delay issuing + // it. Meanwhile, newOp should be set to firstDot to make sure the dot + // result is updated to yield. + builder.setInsertionPoint(prevDot); + newOp = firstDot; + } + + while (kRem != 0) { + // int64_t kShape = largestPow2(kRem); + int64_t kShape = prefetchWidth; + auto insertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(prevDot); + Value aRem = + generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); + Value bRem = + generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); + builder.restoreInsertionPoint(insertionPoint); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + if (kRem == 0) { + // We want to delay issuing the last dot as long as possible, ideally + // until after the prefetch. To accomplish this, set the insertion + // point above the dot. If we find anything dependent on the dot (at + // the top of this loop), we resume inserting after it. + builder.setInsertionPoint(prevDot); + } + } + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); + } + // Update ops of yield + builder.setInsertionPointToEnd(newForOp.getBody()); + if (!yieldValues.empty()) + builder.create(yieldOp.getLoc(), yieldValues); + return newForOp; +} + +} // anonymous namespace + +struct PrefetchPass : public impl::TritonGPUPrefetchBase { + void runOnOperation() override { + + // Canonicalize convert ops to make the pattern matching easier. + RewritePatternSet cleanUpPatterns(&getContext()); + triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, + &getContext()); + if (mlir::applyPatternsGreedily(getOperation(), std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 000000000..811005997 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1689 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir::triton::gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + // Get the remat'ed value in the given encoding, if one already exists and + // is different then the layout conversion root. + Value getRematValue(Value value, Attribute encoding) const { + return rematMapping.lookup({value, encoding}); + } + + void cleanup(); + void backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + // TODO: Merge the three hoistConvert*(); functions as they are duplicate code + void hoistConvertDotOperand(); + void hoistConvertDotOperand(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void hoistConvertIntoConditionals(); + void hoistConvertIntoConditionals(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + + LogicalResult + getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding, + SetVector &slice, + DenseMap &layout, + std::function stopPropagation); + + LogicalResult getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; + DominanceInfo domInfo; + PostDominanceInfo postDomInfo; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + if (auto gatherOp = dyn_cast(op)) + return gatherOp.getEfficientLayout(); + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + + return false; +} + +void LayoutPropagation::initAnchorLayout() { + auto addAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + addAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + addAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + Attribute dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (auto dotWaitOp = dyn_cast(user)) { + unsigned opIndex = use.getOperandNumber(); + Value result = dotWaitOp->getResult(opIndex); + setEncoding(result, info, changed, user); + continue; + } + if (auto gatherOp = dyn_cast(user)) { + // Propagate the layout through the indices only, and if the layout does + // not have an efficient layout set. + if (!gatherOp.getEfficientLayout() && + &use == &gatherOp.getIndicesMutable()) { + setEncoding(gatherOp.getResult(), info, changed, user); + continue; + } + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + if(isa(user)) { + auto srcType = dyn_cast(user->getOperand(0).getType()); + auto dstType = dyn_cast(user->getResult(0).getType()); + if(srcType != nullptr && dstType != nullptr) { + auto srcElemType = srcType.getElementType(); + auto dstElemType = dstType.getElementType(); + if((srcElemType.getIntOrFloatBitWidth() == 32 && dstElemType.getIntOrFloatBitWidth() == 16) + || (dstElemType.getIntOrFloatBitWidth() == 8)){ + continue; + } + } + } + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + DBGS() << "changed: " << changed.size() << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + std::deque queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.front(); + queue.pop_front(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create(value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + Attribute operandEnc; + if (op->getNumOperands() > 0) { + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (auto gather = dyn_cast(op)) + return !gather.getEfficientLayout(); + + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (Value remat = getRematValue(v, layoutIt->second)) { + mapping.map(v, remat); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create(op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + cast(old.getType()).getShape(), + cast(old.getType()).getElementType(), it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult LayoutRematerialization::getConvertBackwardSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + // Allow re-using existing conversions for a value. Check dominance of any + // reusable materializations against the root value. This is sufficient + // because the conversions are processed in post-order. + auto getExistingConversion = [&](OpOperand &value, Attribute encoding) { + Value remat = getRematValue(value.get(), encoding); + if (!remat) + return Value(); + // `value` can be replaced with an existing rematerialization if it + // dominates the current use of value. + Operation *user = value.getOwner(); + if (domInfo.properlyDominates(remat, user)) { + return remat; + } + // FIXME: If the current user is a conversion, then we know it will become + // a no-op when its operand is replaced with `remat`, but we need to check + // that its users are all dominated by `remat` so the IR is valid. + // if (isa(user) && remat.getDefiningOp() && + // domInfo.properlyDominates(user, remat.getDefiningOp())) { + // for (Operation *op : user->getUsers()) { + // if (!domInfo.dominates(remat, op)) + // return Value(); + // } + // return remat; + // } + return Value(); + }; + + return mlir::getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion); +} + +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation) { + LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + +void LayoutRematerialization::backwardRematerialization() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertIntoConditionals() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertIntoConditionals(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +static bool isExpensiveMathOp(Operation *op) { + // These operations are either multiple instructions or have throughput + // lower than 16 according to the arithmetic instructions table in: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions + return isa(op); +} + +static int64_t getByteCount(Value result, int64_t minElementCount = 0, + int64_t minBitWidth = 0) { + int64_t elementCount = 0; + int64_t dtypeBitWidth = 0; + if (auto tensorTy = dyn_cast(result.getType())) { + elementCount = tensorTy.getNumElements(); + auto elemType = tensorTy.getElementType(); + if (elemType.isIntOrFloat()) { + dtypeBitWidth = elemType.getIntOrFloatBitWidth(); + } + } + if (elementCount < minElementCount) { + elementCount = minElementCount; + } + if (dtypeBitWidth < minBitWidth) { + dtypeBitWidth = minBitWidth; + } + return (elementCount * dtypeBitWidth) >> 3; +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp.getSrc(); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. Make sure it dominates the current conversion. + Value newV = getRematValue(oldV, targetType.getEncoding()); + if (newV && domInfo.properlyDominates(newV, convertOp)) { + // Replace it with the remat'ed value. + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + // 2. Determine whether rematerialisation is beneficial. + + // Identify all operations in the slice + SetVector sliceOps; + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + sliceOps.insert(op); + } + } + + // Compute single-use operations + DenseMap isSingleUse; + std::function isOpSingleUse; + isOpSingleUse = [&](Operation *op) -> bool { + // lookup in memoization array: + auto it = isSingleUse.find(op); + if (it != isSingleUse.end()) { + return it->second; + } + + bool singleUse = true; + + for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (user == convertOp) { + continue; + } + if (sliceOps.contains(user)) { + if (!isOpSingleUse(user)) { + singleUse = false; + break; + } + } else { + singleUse = false; + break; + } + } + if (!singleUse) { + break; + } + } + + // insert into memoization array: + isSingleUse[op] = singleUse; + return singleUse; + }; + + // Measure the number of bytes that we're manipulating with the + // ConvertLayoutOp. We pessimistically assume that we round-trip + // through shared memory and that we cannot vectorise sub-register + // loads/stores, so we set a minimum element count of 32 (the warp + // size and number of shared memory banks) and minimum bitwidth of + // 32 (the width per bank of the shared memory load/store unit). + int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32); + + // We measure costs in standardised milli-SM-cycles. The smem load + // and store each cost 8 * convertLayoutBytes, and then we double + // it to account for extra cost due to synchronisation. + int64_t convertLayoutCost = 32 * convertLayoutBytes; + int64_t rematerialisationCost = 0; + + // Evaluate single-use status for every operation in slice + for (Operation *op : sliceOps) { + auto dialect = op->getDialect(); + if (isOpSingleUse(op)) { + // when we rematerialise, this operation does not get duplicated + // so it does not contribute to our cost model: + continue; + } else if (isa(op)) { + // special-case: arith.constant has zero cost + continue; + } else if (isa(op)) { + // optimistically assume L1-cached: + for (Value result : op->getResults()) { + rematerialisationCost += 8 * getByteCount(result); + } + } else if (isa(dialect)) { + // this is an arithmetic operation; we distinguish between cheap + // operations (such as floating point add/mul which can be fused + // as halves of a single-cycle FMA instruction) and expensive + // operations which use the special function unit and/or involve + // multiple instructions. + int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1; + for (Value result : op->getResults()) { + rematerialisationCost += multiplier * getByteCount(result); + } + } + } + + LLVM_DEBUG({ + DBGS() << " convert layout cost: " << convertLayoutCost << "\n"; + DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n"; + }); + + if (rematerialisationCost > convertLayoutCost) { + LDBG(" skipped rematerialization due to higher cost"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +void LayoutRematerialization::hoistConvertDotOperand() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertDotOperand(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for reuse in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } + } +} + +void LayoutRematerialization::hoistConvertDotOperand( + ConvertLayoutOp convertOp) { + auto targetType = convertOp.getType(); + // The pass is targeted to MMA dot operands + + auto canBePipelined = [&](ConvertLayoutOp convertOp) { + // FIXME: Check that the parent is a for loop + auto parent = convertOp->getParentOp(); + if (!parent) + return false; + + // Find all the dot-like ops in the for loop that have a dot operand + // encoding on the lhs and check if any of them post-dominates the load + + // cvt + SmallVector dotLikeOps; + parent->walk([&](Operation *op) { + if (!isa(op)) + return; + auto opType = dyn_cast(op->getOperand(0).getType()); + if (!opType) + return; + auto dotEnc = dyn_cast(opType.getEncoding()); + if (!dotEnc) + return; + if (isa(dotEnc.getParent())) + dotLikeOps.push_back(op); + }); + if (dotLikeOps.empty()) + return false; + return llvm::any_of(dotLikeOps, [&](Operation *dot) { + return postDomInfo.postDominates(dot, convertOp); + }); + }; + + // We move convert #dot_operand next to their loads. This is done + // so that it's then easy to pipeline these loads + if (!canBePipelined(convertOp)) + return; + + // We hoist over any operation that can be done without data movement between + // threads We do views and elementwise pure ops for now + auto noDataMovement = [](Operation *op) { + return (op->hasTrait() && isMemoryEffectFree(op)) || + isa(op) || isView(op); + }; + // Stop the slice as soon as we find an operation that cannot be done without + // data movement between threads + auto stop = std::not_fn(noDataMovement); + + SetVector slice; + DenseMap layout; + // Set-up the conversion "cache" + LogicalResult result = getConvertBackwardSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop); + if (result.failed()) + return; + + IRMapping mapping; + OpBuilder builder(convertOp.getContext()); + SetVector innerSlice; + for (Value v : slice) { + if (!v.getDefiningOp()) { + LLVM_DEBUG( + { DBGS() << " Block arguments not supported. Got " << v << "\n"; }); + return; + } + + // We expect the leaves of the slice to be Load, DescriptorLoad or + // arith::Constant This could be generalised if necessary + if (!isa(v.getDefiningOp())) { + auto op = v.getDefiningOp(); + if (isa(op) || noDataMovement(op)) { + innerSlice.insert(v); + continue; + } else { + LLVM_DEBUG({ + DBGS() << " Leaves must be Load, DescriptorLoad or Constant. Got " + << v << "\n"; + }); + return; + } + } + Operation *loadOp = v.getDefiningOp(); + builder.setInsertionPointAfter(loadOp); + auto type = dyn_cast(loadOp->getResult(0).getType()); + if (!type) + continue; + auto newType = RankedTensorType::get(type.getShape(), type.getElementType(), + layout[loadOp->getResult(0)]); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, loadOp->getResult(0)); + mapping.map(loadOp->getResult(0), newConvertOp.getResult()); + } + + if (innerSlice.empty()) { + return; + } + + LLVM_DEBUG({ + DBGS() << " Hoisting " << convertOp << '\n'; + for (Value v : innerSlice) + DBGS() << " " << v << '\n'; + }); + + rewriteSlice(innerSlice, layout, convertOp, mapping); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // DotOperand is hoisted by hoistDotOperand + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(cast(fpToFpOp.getType())); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, + isExtOrBroadcastOp); + if (result.failed()) + return; + + Operation *extOrBroadcastOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOrBroadcastOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + Attribute srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + LogicalResult result = getRematerializableSlice( + op->getOpOperand(0), srcEncoding, tempSlice, tempLayout); + + // If a value is already assigned to a _different_ layout, + // we cannot propagate past this op (as it would conflict with + // an already-assigned layout). + for (auto [val, enc] : tempLayout) { + auto preexistingLayout = layout.find(val); + if (preexistingLayout != layout.end() && + preexistingLayout->second != enc) { + result = failure(); + break; + } + } + + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcastOp != nullptr) + return; + extOrBroadcastOp = op; + } + } + + if (extOrBroadcastOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcastOp->getResult(0)]; + Attribute srcEncoding = inferSrcEncoding(extOrBroadcastOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcastOp); + auto tensorType = + cast(extOrBroadcastOp->getOperand(0).getType()); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), srcEncoding); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOrBroadcastOp->getOperand(0)); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcastOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcastOp->getResult(0).getType()); + Type newExtOrBroadcasrType = RankedTensorType::get( + oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(), + dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType); + IRMapping mapping; + mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcastOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void LayoutRematerialization::hoistConvertIntoConditionals( + ConvertLayoutOp convertOp) { + // Take the backward slice of tensor dependencies rooted at the conversion, + // stopping at conditionals. This subslice is used to initialize the analysis. + SetVector slice; + DenseMap layout; + auto isIfOp = [](Operation *op) { return isa(op); }; + if (failed(getRematerializableSlice(convertOp.getSrcMutable(), + convertOp.getType().getEncoding(), slice, + layout, isIfOp))) + return; + + // These are the conditional edges above which conversions should be hoisted. + // The value represents the `scf.if` op result and the operand represents the + // edge into one of the branches. + SmallVector> hoistAbove; + + // The list of `scf.if` op results in the slice that are not rematerializable. + // Hoisting is terminated at these values. + SmallVector terminals; + + // This loop recurses through the subslices of the backwards dependencies, so + // re-query the size of `slice`. + for (unsigned i = 0; i != slice.size(); ++i) { + Value v = slice[i]; + auto ifOp = v.getDefiningOp(); + if (!ifOp) + continue; + + Attribute rootLayout = layout.at(v); + unsigned resIdx = cast(v).getResultNumber(); + + // Take the backward slice along each branch. + auto thenYield = + cast(ifOp.getThenRegion().front().getTerminator()); + auto elseYield = + cast(ifOp.getElseRegion().front().getTerminator()); + + OpOperand &thenRes = thenYield.getResultsMutable()[resIdx]; + OpOperand &elseRes = elseYield.getResultsMutable()[resIdx]; + + SetVector thenSlice, elseSlice; + DenseMap thenLayout, elseLayout; + + LogicalResult thenResult = getRematerializableSlice( + thenRes, rootLayout, thenSlice, thenLayout, isIfOp); + LogicalResult elseResult = getRematerializableSlice( + elseRes, rootLayout, elseSlice, elseLayout, isIfOp); + + // If propagation across both edges of this conditional succeeded, then we + // don't need to hoist across it. Merge into the current slice. + if (succeeded(thenResult) && succeeded(elseResult)) { + slice.insert(thenSlice.begin(), thenSlice.end()); + slice.insert(elseSlice.begin(), elseSlice.end()); + layout.insert(thenLayout.begin(), thenLayout.end()); + layout.insert(elseLayout.begin(), elseLayout.end()); + continue; + } + + // If propagation across both edges failed, then this conditional + // terminates backwards rematerialization. + if (failed(thenResult) && failed(elseResult)) { + terminals.push_back(cast(v)); + continue; + } + + // Only hoist into conditionals inside loops. The assumption is that an if + // inside a loop executes fewer than the total number of loop iterations, + // making this hoist profitable. + if (!isa(ifOp->getParentOp())) { + terminals.push_back(cast(v)); + continue; + } + + // The layout conversion can be rematerialized along one edge but not the + // other. We can hoist the conversion into the other branch. Push this + // into the subslice list for analysis. + if (succeeded(thenResult)) { + hoistAbove.emplace_back(v, &elseRes); + slice.insert(thenSlice.begin(), thenSlice.end()); + layout.insert(thenLayout.begin(), thenLayout.end()); + } else { + hoistAbove.emplace_back(v, &thenRes); + slice.insert(elseSlice.begin(), elseSlice.end()); + layout.insert(elseLayout.begin(), elseLayout.end()); + } + } + + // Exit early if there is nothing to do. + if (hoistAbove.empty()) + return; + + // Rematerialize failed hoists right before the condtional, and hoist those + // that succeeded into the branch and then rewrite the slice. + IRMapping mapping; + auto hoistRemat = [&](OpBuilder &b, Value v, Attribute encoding) { + auto tensorType = cast(v.getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value newCvt = b.create(convertOp.getLoc(), newType, v); + + mapping.map(v, newCvt); + slice.remove(v); + }; + for (Value v : terminals) { + OpBuilder b(v.getContext()); + b.setInsertionPointAfter(v.getDefiningOp()); + hoistRemat(b, v, layout.at(v)); + } + for (auto [result, edge] : hoistAbove) { + OpBuilder b(edge->getOwner()); + hoistRemat(b, edge->get(), layout.at(result)); + } + rewriteSlice(slice, layout, convertOp, mapping); +} + +void backwardRematerialization(ModuleOp module) { + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertIntoConditionals(); + layoutRemat.cleanup(); + + layoutRemat = LayoutRematerialization(funcOp); + layoutRemat.hoistConvertDotOperand(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + // Cleanup convert ops. + void cleanupConvertOps() { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + cleanupConvertOps(); + + // 2. For remaining convert ops, try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace mlir::triton::gpu diff --git a/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..ff86a7642 --- /dev/null +++ b/third_party/sunrise/backend/spec/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,1590 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; +namespace mlir { + +using namespace triton; + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + Type eltType, int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (llvm::isa( + eltType) || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else if (version == 5) { + unsigned m = shape[0] >= 128 ? 128 : 64; + // Right now default to distributing along N. TODO: For cases where we have + // dot followed by reduction we need to be able to distribute along M. + // if (numWarps > 4) + // m = 64; + unsigned n = shape[1] >= 256 ? 256 : shape[1]; + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); + return {m, n, k}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector argSort(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); // 16 + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); // 2 + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); // 2 + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); // 1 + unsigned maxContig = // 1 + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + unsigned currPerThread = std::min(alignment, std::max(32 / elemNumBits, 1u)); + return currPerThread; +} + +bool isView(Operation *op) { + return isa(op); +} + +bool isNoop(Operation *op) { + if (isa(op)) + return true; + if (auto cvt = dyn_cast(op)) { + // The conversion op is a noop if the conversion layout is trivial + return minimalCvtLayout(cvt.getSrc().getType(), + cvt.getResult().getType()) == LinearLayout::empty(); + } + return false; +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static Attribute inferDstEncoding(triton::ReduceOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferDstEncoding(triton::ExpandDimsOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getLhs().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferDefaultJoinOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(triton::ReduceOp op, Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return {}; + if (op.getAxis() != sliceEncoding.getDim()) + return {}; + return sliceEncoding.getParent(); +} + +static Attribute inferSrcEncoding(triton::ExpandDimsOp op, Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get( + op->getContext(), op.getAxis(), + cast(encoding)); +} + +static Attribute inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + auto shape = op.getResult().getType().getShape(); + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, shape, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + auto shape = op.getOutLHS().getType().getShape(); + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferDefaultJoinOpEncoding(dstEnc, srcEnc, shape, + /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return {}; +} + +static Attribute inferSrcEncoding(GatherOp op, Attribute dstEnc) { + // The index encoding is the same as the output encoding. + return dstEnc; +} + +static Attribute inferTransOpDstEncoding(Attribute srcEnc, + ArrayRef shape, + ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, shape, order, retEncoding, + /*loc=*/{}))) { + return retEncoding; + } + return {}; +} + +static Attribute inferDstEncoding(triton::gpu::Fp4ToFpOp op, Attribute srcEnc) { + Attribute dstEnc; + auto shape = op.getSrc().getType().getShape(); + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), srcEnc, dstEnc, + /*fwdInference*/ true, std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferSrcEncoding(triton::gpu::Fp4ToFpOp op, Attribute dstEnc) { + Attribute srcEnc; + auto shape = op.getSrc().getType().getShape(); + if (succeeded( + dstEnc.getDialect() + .getRegisteredInterface() + ->inferFp4ToFpOpEncoding(shape, op.getAxis(), dstEnc, srcEnc, + /*fwdInference*/ false, std::nullopt))) { + return srcEnc; + } + return {}; +} + +static Attribute inferDstEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + return inferTransOpDstEncoding( + encoding, cast(op.getSrc().getType()).getShape(), + op.getOrder()); +} + +static Attribute inferSrcEncoding(triton::TransposeOpInterface op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + auto shape = cast(op->getResult(0).getType()).getShape(); + return inferTransOpDstEncoding(encoding, shape, + triton::inversePermutation(op.getOrder())); +} + +static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, + Attribute srcEnc, + ArrayRef dstShape, + bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return {}; + + Attribute dstEnc; + auto result = + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, + /*loc=*/std::nullopt); + assert(succeeded(result)); + return dstEnc; +} + +static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static Attribute inferDstEncoding(GatherOp op, Attribute encoding) { + // The output encoding is the same as the index encoding. + // FIXME: This assumes `encoding` is the index encoding, which can be + // different than the source encoding. + return encoding; +} + +static Attribute inferSrcEncoding(triton::ReshapeOp op, Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +static bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +Attribute inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return {}; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferSrcEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferSrcEncoding(fp4ToFp, encoding); + + return {}; +} + +Attribute inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return {}; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + if (auto gather = dyn_cast(op)) + return inferDstEncoding(gather, encoding); + if (auto fp4ToFp = dyn_cast(op)) + return inferDstEncoding(fp4ToFp, encoding); + + return {}; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::lookupNumWarps(op); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + RankedTensorType::get(reshapeDstType.getShape(), + reshapeDstType.getElementType(), targetEncoding); + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + OpBuilder &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto [result, value] : replacements) { + result.replaceAllUsesWith(value); + } + return newForOp; +} + +scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + scf::ForOp newLoop = + replaceForOpWithNewSignature(rewriter, loop, newIterOperands); + // Save the caller from insertion point invalidation. + if (rewriter.getInsertionPoint() == loop->getIterator()) + rewriter.setInsertionPoint(newLoop); + loop.erase(); + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature( + OpBuilder &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInits()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + + // Result and operand types + SmallVector resultTypes; + SmallVector argsTypesBefore; + for (auto res : loop.getResults()) + resultTypes.push_back(res.getType()); + for (auto type : newResultTypes) + resultTypes.push_back(type); + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + scf::WhileOp newLoop = + rewriter.create(loop.getLoc(), resultTypes, operands); + newLoop->setAttrs(loop->getAttrs()); + + SmallVector bbArgLocsBefore(argsTypesBefore.size(), loop.getLoc()); + SmallVector bbArgLocsAfter(resultTypes.size(), loop.getLoc()); + rewriter.createBlock(&newLoop.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newLoop.getAfter(), {}, resultTypes, bbArgLocsAfter); + + // Copy regions + for (int i = 0; i < loop.getNumRegions(); ++i) + newLoop->getRegion(i).front().getOperations().splice( + newLoop->getRegion(i).front().getOperations().begin(), + loop->getRegion(i).front().getOperations()); + + // Remap arguments + for (auto [oldArg, newArg] : llvm::zip( + loop.getBeforeArguments(), newLoop.getBeforeArguments().take_front( + loop.getBeforeArguments().size()))) + oldArg.replaceAllUsesWith(newArg); + for (auto [oldArg, newArg] : llvm::zip(loop.getAfterArguments(), + newLoop.getAfterArguments().take_front( + loop.getAfterArguments().size()))) + oldArg.replaceAllUsesWith(newArg); + + // Stack the new results + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature(OpBuilder &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newWhileOp = replaceWhileOpWithNewSignature( + rewriter, loop, newIterOperands, newResultTypes, replacements); + for (auto &kv : replacements) { + std::get<0>(kv).replaceAllUsesWith(std::get<1>(kv)); + } + return newWhileOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + OpBuilder &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = rewriter.create(ifOp.getLoc(), resultTypes, + ifOp.getCondition()); + newIf->setAttrs(ifOp->getAttrs()); + + newIf.getThenRegion().takeBody(ifOp.getThenRegion()); + newIf.getElseRegion().takeBody(ifOp.getElseRegion()); + scf::IfOp::ensureTerminator(newIf.getThenRegion(), rewriter, ifOp.getLoc()); + scf::IfOp::ensureTerminator(newIf.getElseRegion(), rewriter, ifOp.getLoc()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +scf::IfOp replaceIfOpWithNewSignature(OpBuilder &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, newResultTypes, replacements); + for (auto &kv : replacements) + std::get<0>(kv).replaceAllUsesWith(std::get<1>(kv)); + return newIfOp; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be performed by reordering registers. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return cvtReordersRegisters(convertOp.getSrc().getType(), + convertOp.getType()); +} + +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation, + std::function getExistingConversion) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](OpOperand &operand, Attribute encoding) { + auto x = std::make_pair(&operand, encoding); + if (!seen.insert(x).second) { + return; // Already enqueued, skip + } + queue.push_back(x); + }; + enqueue(root, rootEncoding); + + auto updateLayout = [&](Value value, Attribute encoding) { + assert((isa(value.getType()))); + slice.insert(value); + Attribute &existing = layout[value]; + if (existing && existing != encoding) + return failure(); + existing = encoding; + return success(); + }; + + while (!queue.empty()) { + auto [currentValueUse, encoding] = queue.back(); + Value currentValue = currentValueUse->get(); + queue.pop_back(); + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + if (failed(updateLayout(currentValue, encoding))) + return failure(); + + Value existing; + if (getExistingConversion && + (existing = getExistingConversion(*currentValueUse, encoding))) { + if (failed(updateLayout(existing, encoding))) + return failure(); + currentValue = existing; + } + + if (auto ifOp = currentValue.getDefiningOp()) { + if (stopPropagation && stopPropagation(ifOp)) + continue; + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx); + OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx); + + enqueue(thenValue, encoding); + enqueue(elseValue, encoding); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + if (failed(updateLayout(result, encoding))) + return failure(); + } + if (isFreeConvert(definingOp)) { + enqueue(definingOp->getOpOperand(0), encoding); + continue; + } + if (canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + if (auto gather = dyn_cast(definingOp)) { + // Specially handle gather since its transfer function only applies + // between its index operand and result. + auto srcEncoding = inferSrcEncoding(gather, encoding); + if (!srcEncoding) + return failure(); + enqueue(gather.getIndicesMutable(), srcEncoding); + continue; + } + for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + // If the infered layout matches the original one we don't need to keep + // propagating. + if (auto operandType = + dyn_cast(operand.get().getType())) { + if (srcEncoding == operandType.getEncoding()) + continue; + } + enqueue(operand, srcEncoding); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + enqueue(*initOperand, encoding); + enqueue(yieldOperand, encoding); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + assert(targetAttr && "Expected a target attribute on the module operation"); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +StringRef getAMDArch(Operation *module) { + StringAttr targetAttr = + module->getAttrOfType(triton::gpu::AttrTargetName); + assert(targetAttr && "Expected a target attribute on the module operation"); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("hip:") && + "expected target attribute to be prefixed with \"hip:\""); + + return ref.drop_front(4); // drop the "hip:" +} + +inline ttg::SwizzledSharedEncodingAttr +swizzleDotOperandLike(RankedTensorType type, ttg::CTALayoutAttr ctaLayout) { + // We want to see if the linear layout has the same order as an mma microtile + // of shape (8, 4*kWidth) or (4*kWidth, 8). If so, we return a + // DotOperandEncodingAttr with a tile of this shape This works because + // SwizzledSharedEncodingAttr::get just looks at the microtile to determine + // the swizzling + + auto *ctx = type.getContext(); + auto layout = ttg::toLinearEncoding(type); + auto order = layout.getThreadOrder(); + auto rank = order.size(); + if (rank < 2) { + return {}; + } + int opIdx; + if (ttg::getOrderForDotOperand(0, rank, /*kContig=*/true) == order) { + opIdx = 0; + } else if (ttg::getOrderForDotOperand(1, rank, /*kContig=*/true) == order) { + opIdx = 1; + } else { + return {}; + } + auto kWidth = layout.getContigPerThread()[order[0]]; + SmallVector microtileShape(rank, 1); + microtileShape[order[0]] = 4 * kWidth; + microtileShape[order[1]] = 8; + // All the LinearLayouts contained within LinearEncoidngAttr have order [0, 1, + // 2, ...] + auto repOrder = to_vector(llvm::seq(rank)); + auto tile = ttg::nvidiaMmaTile(ctx, microtileShape, kWidth, order, repOrder); + if (!divideLeft(layout.getLinearLayout(), tile).has_value()) { + return {}; + } + return ttg::SwizzledSharedEncodingAttr::get( + ctx, opIdx, kWidth, type.getShape(), order, ctaLayout, + type.getElementTypeBitWidth(), false); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings, set incompatible to true. +std::optional +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { + ttg::SwizzledSharedEncodingAttr attr; + incompatible = false; + for (Operation *user : val.getUsers()) { + ttg::SwizzledSharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = + dyn_cast(memDesc.getEncoding()); + if (!tempAttr) + return std::nullopt; + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto dstTy = cast(user->getResult(0).getType()); + + // FIXME This may not be correct for multiple CTA, but getCTALayout is NYI + // for LinearEncodingAttr + auto CTALayout = isa(dstTy.getEncoding()) + ? ttg::getCTALayout(srcTy.getEncoding()) + : ttg::getCTALayout(dstTy.getEncoding()); + + if (auto dot = + dyn_cast(dstTy.getEncoding())) { + auto order = getOrderForMemory(srcTy); + unsigned bitWidth = srcTy.getElementTypeBitWidth(); + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), dot, srcTy.getShape(), order, CTALayout, bitWidth, + /*needTrans=*/false); + } else { + // Try to see if the layout is like an mma microtile + tempAttr = swizzleDotOperandLike(dstTy, CTALayout); + } + if (!tempAttr) + return std::nullopt; + } + // Check that the shared encodings needed by the users are compatible. + if (attr != nullptr && attr != tempAttr) { + incompatible = true; + return std::nullopt; + } + attr = tempAttr; + } + return attr; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + // mark condition as live. + markLive(nestedIf.getCondition()); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.modifyOpInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +ttg::LocalAllocOp findShmemAlloc(Value operand) { + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // allowed in between. + Value transitiveOperand = operand; + while (isa_and_nonnull( + transitiveOperand.getDefiningOp()) || + isa(transitiveOperand)) { + if (auto blockArg = dyn_cast(transitiveOperand)) { + assert(isa(blockArg.getOwner()->getParentOp()) && + "Block argument must come from a for loop"); + transitiveOperand = + cast(blockArg.getOwner()->getTerminator()) + .getOperand(blockArg.getArgNumber() - 1); + } else { + transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + } + } + if (auto subView = dyn_cast_or_null( + transitiveOperand.getDefiningOp())) { + // Multi-buffered operand + return dyn_cast_or_null( + subView.getSrc().getDefiningOp()); + } else { + // Single bufferred operand that does not require a subview (not loaded in + // the loop) + return dyn_cast_or_null( + transitiveOperand.getDefiningOp()); + } + return nullptr; +} + +SmallVector +getMMAsWithMultiBufferredOperands(scf::ForOp forOp, + SmallVector &mmaOps) { + // The A and B operands of the mmaOp should be multi-buffered + SmallVector eligible; + for (auto mmaOp : mmaOps) { + auto a = findShmemAlloc(mmaOp->getOperand(0)); + auto b = findShmemAlloc(mmaOp->getOperand(1)); + if (a && forOp.isDefinedOutsideOfLoop(a) && b && + forOp.isDefinedOutsideOfLoop(b)) { + eligible.push_back(mmaOp); + } + } + + return eligible; +} + +template +static Operation *findNearestCommonDominatorImpl( + ArrayRef ops, DomInfoT &domInfo, + function_ref isBefore) { + if (ops.size() == 0) { + return nullptr; + } + if (ops.size() == 1) { + return ops[0]; + } + llvm::SmallPtrSet blocks; + for (auto op : ops) { + blocks.insert(op->getBlock()); + } + Block *domBlock = domInfo.findNearestCommonDominator(blocks); + if (domBlock == nullptr) { + return nullptr; + } + SmallVector ancestorOps; + for (auto op : ops) { + ancestorOps.push_back(domBlock->findAncestorOpInBlock(*op)); + } + Operation *dom = ancestorOps[0]; + for (unsigned i = 1; i < ops.size(); i++) { + if (isBefore(ancestorOps[i], dom)) { + dom = ancestorOps[i]; + } + } + return dom; +} + +Operation *findNearestCommonDominator(ArrayRef ops, + DominanceInfo &domInfo) { + return findNearestCommonDominatorImpl( + ops, domInfo, + [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); }); +} + +Operation *findNearestCommonPostDominator(ArrayRef ops, + PostDominanceInfo &domInfo) { + return findNearestCommonDominatorImpl( + ops, domInfo, + [](Operation *a, Operation *b) { return b->isBeforeInBlock(a); }); +} + +void visitNestedOperands(Operation *op, + function_ref visitor) { + op->walk([&](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + if (operand.get().getParentBlock()->getParentOp()->isProperAncestor(op)) + visitor(operand); + } + }); +} + +void visitNestedOperands(Operation *op, function_ref visitor) { + visitNestedOperands(op, [&](OpOperand &operand) { visitor(operand.get()); }); +} + +SetVector getNestedOperands(Operation *op) { + SetVector result; + visitNestedOperands(op, [&](Value operand) { result.insert(operand); }); + return result; +} + +void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) { + // Pad the indices in case new arguments were added. + while (indices.size() != loop.getInitArgs().size()) + indices.push_back(false); + + loop.getBody()->getTerminator()->eraseOperands(indices); + loop.getBody()->eraseArguments([&](BlockArgument arg) { + int idx = arg.getArgNumber(); + return idx != 0 && indices.test(idx - 1); + }); + + llvm::BitVector loopOperandIndices(loop->getNumOperands()); + for (auto [i, operand] : llvm::enumerate(loop.getInitArgsMutable())) { + if (indices.test(i)) + loopOperandIndices.set(operand.getOperandNumber()); + } + loop->eraseOperands(loopOperandIndices); + + // Rewrite the loop to erase results. + OperationState state(loop.getLoc(), loop->getName(), loop->getOperands(), + loop.getInitArgs().getTypes(), loop->getAttrs()); + state.addRegion()->takeBody(loop.getBodyRegion()); + + OpBuilder b(loop); + auto newLoop = cast(b.create(state)); + + // Replace uses of the old loop with the new loop. + unsigned newResultIdx = 0; + for (auto [i, result] : llvm::enumerate(loop.getResults())) { + if (indices.test(i)) { + assert(result.use_empty() && "loop carried value still has uses"); + continue; + } + result.replaceAllUsesWith(newLoop.getResult(newResultIdx++)); + } + + loop.erase(); + loop = newLoop; +} + +} // namespace mlir + +namespace mlir::triton { +void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, + Value val) { + SmallVector opsToDelete; + SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? + for (OpOperand &use : oldUse->getUses()) { + // Propagate through `ttg.warp_specialize`. + if (auto wsOp = dyn_cast(use.getOwner())) { + for (Region *region : wsOp.getPartitionRegions()) + region->getArgument(use.getOperandNumber()).setType(val.getType()); + } + + // Non-subview/trans ops will be replaced by `val`. + if (!use.getOwner()->hasTrait()) { + operandsToReplace.push_back(&use); + continue; + } + + Operation *user = use.getOwner(); + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(user); + Value newVal; + if (auto subview = dyn_cast(user)) { + ttg::MemDescType oldType = subview.getType(); + bool isMutable = cast(val.getType()).getMutableMemory(); + Type newDstType = ttg::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable); + newVal = builder.create( + subview.getLoc(), newDstType, val, subview.getOffsets()); + } else if (auto trans = dyn_cast(user)) { + newVal = builder.create(trans.getLoc(), val, + trans.getOrder()); + } else if (auto reshape = dyn_cast(user)) { + newVal = builder.create(reshape.getLoc(), + reshape.getType(), val); + } + assert(newVal && "unhandled memdesc view"); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + replaceUsesAndPropagateType(builder, user, newVal); + opsToDelete.push_back(use.getOwner()); + } + + // Perform late replacement. + for (OpOperand *operand : operandsToReplace) { + if (auto wait = dyn_cast(operand->getOwner())) { + // Need to update the return type on the wait op as well + builder.setInsertionPointAfter(wait); + auto operands = llvm::to_vector(wait.getOperands()); + operands[operand->getOperandNumber()] = val; + auto newWait = builder.create( + wait.getLoc(), operands, wait.getPendings()); + wait.replaceAllUsesWith(newWait.getResults()); + wait.erase(); + } else { + Operation *op = operand->getOwner(); + operand->set(val); + } + } + + // Perform late op erasure. + for (Operation *op : opsToDelete) + op->erase(); +} + +void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old, + TypedValue alloc, + TypedValue token) { + // Remove redundant local_load -> local_alloc + auto allocTy = alloc.getType(); + SmallVector allocsToErase; + for (Operation *user : old.getUsers()) { + if (auto userAlloc = dyn_cast(user)) { + if (allocTy.getEncoding() == userAlloc.getType().getEncoding()) { + replaceUsesAndPropagateType(builder, userAlloc, alloc); + allocsToErase.push_back(userAlloc); + } + } + } + + // If there are some uses that were not local_allocs, we need to create a + // local_load for them. + if (std::distance(old.getUsers().begin(), old.getUsers().end()) > + allocsToErase.size()) { + auto loc = old.getOwner()->getLoc(); + auto sharedLoad = builder.template create( + loc, old.getType(), alloc, token); + old.replaceAllUsesWith(sharedLoad.getResult()); + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } +} + +bool comesFromLoadOrBlockArg(Value v) { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + Operation *def = v.getDefiningOp(); + if (!def) + break; + if (auto cvtOp = dyn_cast(def)) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = dyn_cast(def)) { + v = transOp.getSrc(); + continue; + } + if (def->hasTrait()) { + v = def->getOperand(0); + continue; + } + break; + } + // We also accept block arguments as they appear in many MLIR tests + // If this is problematic we can totally drop them + return isa(v) || + (v.getDefiningOp() && + isa(v.getDefiningOp())); +} + +} // namespace mlir::triton + +std::string getEnvStr(const std::string& env_name) { + const char* val = getenv(env_name.c_str()); + if(val == nullptr) { + return ""; + } + return std::string(val); +} \ No newline at end of file diff --git a/third_party/sunrise/python/src/gluon_ir.cc b/third_party/sunrise/python/src/gluon_ir.cc new file mode 100644 index 000000000..aa8bc5ad8 --- /dev/null +++ b/third_party/sunrise/python/src/gluon_ir.cc @@ -0,0 +1,458 @@ +#include "ir.h" +#include "pybind11/pybind11.h" +#include + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" +#include "triton/Tools/LinearLayout.h" + +using namespace mlir; +namespace py = pybind11; +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +// Helper to check if an MLIR type or attribute has a verifier method. +template +static constexpr auto hasVerifier(AttrOrType t) + -> decltype(t.verifyInvariants, true) { + return true; +} +static constexpr auto hasVerifier(...) { return false; } + +// Print a diagnostic without its location. The frontend will attach the AST +// location to the error message. +static void printDiagStr(llvm::raw_ostream &os, const Diagnostic &diag) { + for (const DiagnosticArgument &arg : diag.getArguments()) + arg.print(os); + os << "\n"; + for (const Diagnostic ¬e : diag.getNotes()) + printDiagStr(os, note); +} + +struct GluonOpBuilder : public TritonOpBuilder { + // Construct an attribute or type while calling its verifier. Error messages + // are intercepted and sent back to Python via a C++ exception. + template + std::enable_if_t + getChecked(ArgTs &&...args) { + // Set up a scoped handler to intercept errors. + std::string msg; + llvm::raw_string_ostream os(msg); + ScopedDiagnosticHandler handler( + getContext(), [&](Diagnostic &diag) { printDiagStr(os, diag); }); + + auto result = + AttrOrType::getChecked([&] { return mlir::emitError(getLastLoc()); }, + std::forward(args)...); + if (!result) + throw std::runtime_error(os.str()); + return result; + } + + // A variant of the above due to issues with C++ overload resolution and how + // MLIR sets up the default `getChecked` implementation. + template + std::enable_if_t + getChecked(MLIRContext *ctx, ArgTs &&...args) { + // Set up a scoped handler to intercept errors. + std::string msg; + llvm::raw_string_ostream os(msg); + ScopedDiagnosticHandler handler( + getContext(), [&](Diagnostic &diag) { printDiagStr(os, diag); }); + + if (failed(AttrOrType::verifyInvariants( + [&] { return mlir::emitError(getLastLoc()); }, + std::forward(args)...))) + throw std::runtime_error(os.str()); + + return AttrOrType::get(ctx, std::forward(args)...); + } + + // Fallback method for types or attributes that do not have a verifier. + template + std::enable_if_t + getChecked(ArgTs &&...args) { + return AttrOrType::get(std::forward(args)...); + } +}; + +struct GluonLayouts { + py::handle BlockedLayout; + py::handle SliceLayout; + py::handle DistributedLinearLayout; + py::handle NVMMASharedLayout; + py::handle SwizzledSharedLayout; + + GluonLayouts() { + auto layouts = + py::module::import("triton.experimental.gluon.language._layouts"); + BlockedLayout = py::object(layouts.attr("BlockedLayout")).release(); + SliceLayout = py::object(layouts.attr("SliceLayout")).release(); + DistributedLinearLayout = + py::object(layouts.attr("DistributedLinearLayout")).release(); + NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release(); + SwizzledSharedLayout = + py::object(layouts.attr("SwizzledSharedLayout")).release(); + } +}; + +template std::vector toStdVector(llvm::ArrayRef array) { + return std::vector(array.begin(), array.end()); +} + +py::object layoutToGluon(Attribute layout) { + static GluonLayouts layouts; + if (auto blocked = dyn_cast(layout)) { + auto ctaLayout = blocked.getCTALayout(); + return layouts.BlockedLayout(toStdVector(blocked.getSizePerThread()), + toStdVector(blocked.getThreadsPerWarp()), + toStdVector(blocked.getWarpsPerCTA()), + toStdVector(blocked.getOrder()), + toStdVector(ctaLayout.getCTAsPerCGA()), + toStdVector(ctaLayout.getCTASplitNum()), + toStdVector(ctaLayout.getCTAOrder())); + } else if (auto sliced = dyn_cast(layout)) { + return layouts.SliceLayout(sliced.getDim(), + layoutToGluon(sliced.getParent())); + } else if (auto linear = dyn_cast(layout)) { + auto ll = linear.getLinearLayout(); + auto ctx = layout.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + return layouts.DistributedLinearLayout( + ll.getBases().lookup(kReg), ll.getBases().lookup(kLane), + ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock), + ll.getOutDimSizes()); + } else if (auto nvmma = dyn_cast(layout)) { + auto ctaLayout = nvmma.getCTALayout(); + return layouts.NVMMASharedLayout( + nvmma.getSwizzlingByteWidth(), nvmma.getElementBitWidth(), + ctaLayout.getRank(), nvmma.getTransposed(), nvmma.getFp4Padded(), + toStdVector(ctaLayout.getCTAsPerCGA()), + toStdVector(ctaLayout.getCTASplitNum()), + toStdVector(ctaLayout.getCTAOrder())); + } else if (auto swizzled = + dyn_cast(layout)) { + auto ctaLayout = nvmma.getCTALayout(); + return layouts.SwizzledSharedLayout( + swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(), + swizzled.getOrder(), toStdVector(ctaLayout.getCTAsPerCGA()), + toStdVector(ctaLayout.getCTASplitNum()), + toStdVector(ctaLayout.getCTAOrder())); + } + throw py::value_error("Unhandled encoding encountered"); +} + +void init_gluon_ir(py::module &&m) { + using ret = py::return_value_policy; + + py::class_( + m, "GluonOpBuilder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + .def("get_distributed_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout) -> Type { + return self.getChecked(shape, elementType, + layout); + }) + .def("get_shared_mem_desc_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout, + std::vector &allocShape) -> Type { + auto ctx = self.getContext(); + return self.getChecked( + shape, elementType, layout, + ttg::SharedMemorySpaceAttr::get(ctx), /*mutableMemory=*/true, + /*allocShape=*/allocShape); + }) + .def("get_tensor_mem_desc_ty", + [](GluonOpBuilder &self, Type &elementType, + std::vector &shape, Attribute layout, + std::vector &allocShape) -> Type { + auto ctx = self.getContext(); + return self.getChecked( + shape, elementType, layout, + ttng::TensorMemorySpaceAttr::get(ctx), /*mutableMemory=*/true, + /*allocShape=*/allocShape); + }) + .def("get_blocked_layout", + [](GluonOpBuilder &self, std::vector &sizePerThread, + std::vector &threadsPerWarp, + std::vector &warpsPerCta, std::vector &order, + std::vector &ctasPerCga, + std::vector &ctaSplitNum, + std::vector &ctaOrder) -> Attribute { + auto ctx = self.getContext(); + auto ctaLayout = self.getChecked( + ctx, ctasPerCga, ctaSplitNum, ctaOrder); + return self.getChecked( + ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, + ctaLayout); + }) + .def("get_slice_layout", + [](GluonOpBuilder &self, unsigned dim, + Attribute parent) -> Attribute { + auto ctx = self.getContext(); + auto dist = cast(parent); + return self.getChecked(ctx, dim, dist); + }) + .def("get_distributed_linear_layout", + [](GluonOpBuilder &self, std::vector> regBases, + std::vector> laneBases, + std::vector> warpBases, + std::vector> blockBases, + std::vector shape) -> Attribute { + auto ctx = self.getContext(); + auto kReg = mlir::StringAttr::get(ctx, "register"); + auto kLane = mlir::StringAttr::get(ctx, "lane"); + auto kWarp = mlir::StringAttr::get(ctx, "warp"); + auto kBlock = mlir::StringAttr::get(ctx, "block"); + auto outDims = tt::standardOutDimPairs(ctx, shape); + auto ll = tt::LinearLayout({{kReg, regBases}, + {kLane, laneBases}, + {kWarp, warpBases}, + {kBlock, blockBases}}, + outDims, + /*requiresSurjective=*/true); + return ttg::LinearEncodingAttr::get(ctx, ll); + }) + .def("get_nvmma_shared_layout", + [](GluonOpBuilder &self, unsigned swizzleByteWidth, + unsigned elementBitwidth, bool transposed, bool fp4Padded, + std::vector &ctasPerCga, + std::vector &ctaSplitNum, + std::vector &ctaOrder) -> Attribute { + auto ctx = self.getContext(); + auto ctaLayout = self.getChecked( + ctx, ctasPerCga, ctaSplitNum, ctaOrder); + return self.getChecked( + ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded, + ctaLayout); + }) + .def("get_swizzled_shared_layout", + [](GluonOpBuilder &self, int vec, int perPhase, int maxPhase, + std::vector &order, std::vector &ctasPerCga, + std::vector &ctaSplitNum, + std::vector &ctaOrder) -> Attribute { + auto ctx = self.getContext(); + auto ctaLayout = self.getChecked( + ctx, ctasPerCga, ctaSplitNum, ctaOrder); + return self.getChecked( + ctx, vec, perPhase, maxPhase, order, ctaLayout); + }) + .def("get_tensor_memory_layout", + [](GluonOpBuilder &self, std::vector &block, bool unpacked, + std::vector &ctaSplitNum) -> Attribute { + auto ctx = self.getContext(); + assert(block.size() == 2); + assert(ctaSplitNum.size() == 2); + return self.getChecked( + ctx, block[0], block[1], unpacked, ctaSplitNum[0], + ctaSplitNum[1]); + }) + .def("get_gluon_layout_from_tensor", + [](GluonOpBuilder &self, Value tensor) -> py::object { + auto ty = dyn_cast(tensor.getType()); + assert(ty.getEncoding()); + return layoutToGluon(ty.getEncoding()); + }) + .def("get_gluon_layout_from_memdesc", + [](GluonOpBuilder &self, Value memdesc) -> py::object { + auto ty = dyn_cast(memdesc.getType()); + assert(ty.getEncoding()); + return layoutToGluon(ty.getEncoding()); + }) + .def("get_tensor_descriptor_layout_type", + [](GluonOpBuilder &self, Type blockType, bool isSigned, + Attribute layout) -> Type { + auto ctx = self.getContext(); + auto blockTy = cast(blockType); + auto blockTyLayout = RankedTensorType::get( + blockTy.getShape(), blockTy.getElementType(), layout); + return triton::TensorDescType::get(ctx, blockTyLayout, isSigned); + }) + .def("create_convert_layout", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_local_alloc", + [](GluonOpBuilder &self, Type resultTy) -> Value { + return self.create(resultTy); + }) + .def("create_local_alloc", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_local_store", + [](GluonOpBuilder &self, Value memDesc, Value value) { + self.create(value, memDesc); + }) + .def("create_local_load", + [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value { + return self.create(resultTy, memDesc); + }) + .def("create_local_dealloc", + [](GluonOpBuilder &self, Value memDesc) -> Operation * { + return self.create(memDesc); + }) + + .def("create_memdesc_subview", + [](GluonOpBuilder &self, Type resultType, Value src, + std::vector &offsets) -> Value { + return self.create(resultType, src, + offsets); + }) + .def("create_memdesc_trans", + [](GluonOpBuilder &self, Value src, + std::vector &order) -> Value { + return self.create(src, order); + }) + .def("create_memdesc_reshape", + [](GluonOpBuilder &self, Type resultType, Value src) -> Value { + return self.create(resultType, src); + }) + .def("create_memdesc_reinterpret", + [](GluonOpBuilder &self, Type resultType, Value src) -> Value { + return self.create(resultType, src); + }) + + .def("create_tmem_alloc", + [](GluonOpBuilder &self, Type resultTy, Value value) -> Value { + return self.create(resultTy, value); + }) + .def("create_tmem_alloc", + [](GluonOpBuilder &self, Type resultTy, py::none value) -> Value { + return self.create(resultTy, Value{}); + }) + .def("create_tmem_store", + [](GluonOpBuilder &self, Value memDesc, Value value, Value pred) { + self.create(memDesc, value, pred); + }) + .def("create_tmem_load", + [](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value { + return self.create(resultTy, memDesc); + }) + .def("create_tmem_subslice", + [](GluonOpBuilder &self, Type resultTy, Value memDesc, + int N) -> Value { + return self.create(resultTy, memDesc, N); + }) + .def("create_mbarrier_init", + [](GluonOpBuilder &self, Value memDesc, int count) { + self.create(memDesc, count); + }) + .def("create_mbarrier_inval", + [](GluonOpBuilder &self, Value memDesc) { + self.create(memDesc); + }) + .def("create_mbarrier_expect", + [](GluonOpBuilder &self, Value memDesc, int bytes, Value pred) { + self.create(memDesc, bytes, pred); + }) + .def("create_mbarrier_wait", + [](GluonOpBuilder &self, Value memDesc, Value phase, Value pred, + std::vector &deps) { + self.create(memDesc, phase, pred, deps); + }) + .def("create_mbarrier_arrive", + [](GluonOpBuilder &self, Value memDesc, int count, Value pred) { + self.create(memDesc, count, pred); + }) + .def("create_tcgen05_mma", + [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc, + Value pred, std::vector &mbarriers, + std::vector &mbarrier_preds) { + Value accDep; + bool two_ctas = false; + auto tokType = self.getBuilder().getType(); + self.create(tokType, a, b, acc, accDep, useAcc, + pred, two_ctas, mbarriers, + mbarrier_preds); + }) + + .def("create_async_tma_copy_global_to_local", + [](GluonOpBuilder &self, Value descPtr, std::vector &coord, + Value barrier, Value result, Value pred) { + self.create( + descPtr, coord, barrier, result, pred); + }) + .def("create_async_tma_copy_local_to_global", + [](GluonOpBuilder &self, Value descPtr, std::vector &coord, + Value src) { + self.create(descPtr, coord, + src); + }) + .def("create_async_tma_reduce", + [](GluonOpBuilder &self, triton::DescriptorReduceKind kind, + Value descPtr, std::vector &coord, Value src) { + self.create(kind, descPtr, coord, src); + }) + .def("create_async_tma_store_wait", + [](GluonOpBuilder &self, int pendings) { + self.create(pendings); + }) + .def("create_async_tma_gather", + [](GluonOpBuilder &self, Value descPtr, Value xOffsets, + Value yOffset, Value barrier, Value result, Value pred) { + self.create(descPtr, xOffsets, yOffset, + barrier, result, pred); + }) + .def("create_async_tma_scatter", + [](GluonOpBuilder &self, Value descPtr, Value xOffsets, + Value yOffset, Value src) { + self.create(descPtr, xOffsets, yOffset, + src); + }) + .def("create_fence_async_shared", + [](GluonOpBuilder &self, bool bCluster) -> OpState { + return self.create(bCluster); + }) + + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value { + return self.create(retTy, arg); + }) + .def( + "create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis, Type retTy) -> Value { + return self.create(retTy, arg, axis); + }) + .def("create_warp_return", + [](GluonOpBuilder &self) -> Operation * { + return self.create(); + }) + .def("create_warp_yield", + [](GluonOpBuilder &self, std::vector &values) -> Operation * { + return self.create(values); + }) + .def("create_warp_specialize_partitions", + [](GluonOpBuilder &self, int numPartitions) -> Operation * { + return self.create(numPartitions); + }) + .def("create_warp_specialize", [](GluonOpBuilder &self, + std::vector &resultTypes, + std::vector &explicitCaptures, + std::vector &partitionNumWarps) { + return self.create(resultTypes, explicitCaptures, + partitionNumWarps); + }); + + py::class_(m, "WarpSpecializeOp", + py::module_local()) + .def("get_default_region", &ttg::WarpSpecializeOp::getDefaultRegion, + ret::reference) + .def("get_partition_op_holder", + &ttg::WarpSpecializeOp::getPartitionOpHolder, ret::reference) + .def("set_requested_registers", [](ttg::WarpSpecializeOp &self, + std::vector &requestedRegisters) { + self.setRequestedRegisters(requestedRegisters); + }); +} diff --git a/third_party/sunrise/python/src/interpreter.cc b/third_party/sunrise/python/src/interpreter.cc new file mode 100644 index 000000000..747a0cc17 --- /dev/null +++ b/third_party/sunrise/python/src/interpreter.cc @@ -0,0 +1,740 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +struct npy_half { + uint16_t value; +}; + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +std::mutex atomic_op_guard; + +template +constexpr bool is_reinterpret_cast_to_atomic_safe = + std::is_trivially_copyable_v && + std::is_trivially_copyable_v> && + std::is_standard_layout_v && std::is_standard_layout_v> && + sizeof(T) == sizeof(std::atomic) && + alignof(T) == alignof(std::atomic); + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, std::memory_order_acq_rel}, + {MemSemantic::ACQUIRE, std::memory_order_acquire}, + {MemSemantic::RELEASE, std::memory_order_release}, + {MemSemantic::RELAXED, std::memory_order_relaxed}, +}; + +template +T atomic_cmp(T *ptr, T val, std::memory_order order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + + T old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_ptr = reinterpret_cast *>(ptr); + old_val = atomic_ptr->load(order); + while (cmp(old_val, val)) { + if (atomic_ptr->compare_exchange_weak(old_val, val, order, order)) { + break; + } + } + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *ptr; + if (cmp(old_val, val)) { + *ptr = val; + } + } + return old_val; +} + +template T atomic_fadd(T *loc, T value, std::memory_order order) { + static_assert(std::is_floating_point::value, + "T must be a floating-point type"); + T old_value; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + T new_value; + std::atomic *atomic_loc = reinterpret_cast *>(loc); + old_value = atomic_loc->load(order); + do { + new_value = old_value + value; + } while ( + !atomic_loc->compare_exchange_weak(old_value, new_value, order, order)); + } else { + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = old_value + value; + } + + return old_value; +} + +/** Create a value of type `To` from the bits of `from`. + * + * similar to `std::bit_cast` but compatible with C++17, + * should perform similar to `*reinterpret_cast(&from)` + * or through punning without expecting any undefined behaviors. + * + * Note: taken from + * https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/utils.hpp#L32 + * with simplification. + */ +template +inline To BitCast(const From &from) noexcept { + static_assert(sizeof(To) == sizeof(From), + "both data types must have the same size"); + + static_assert(std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + "both data types must be trivially copyable"); + + To to; + memcpy(&to, &from, sizeof(from)); + return to; +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L14 +template +inline uint16_t FromFloatBits(uint32_t f) { + uint32_t f_exp, f_sig; + uint16_t h_sgn, h_exp, h_sig; + + h_sgn = (uint16_t)((f & 0x80000000u) >> 16); + f_exp = (f & 0x7f800000u); + + /* Exponent overflow/NaN converts to signed inf/NaN */ + if (f_exp >= 0x47800000u) { + if (f_exp == 0x7f800000u) { + /* Inf or NaN */ + f_sig = (f & 0x007fffffu); + if (f_sig != 0) { + /* NaN - propagate the flag in the significand... */ + uint16_t ret = (uint16_t)(0x7c00u + (f_sig >> 13)); + /* ...but make sure it stays a NaN */ + if (ret == 0x7c00u) { + ret++; + } + return h_sgn + ret; + } else { + /* signed inf */ + return (uint16_t)(h_sgn + 0x7c00u); + } + } else { + if constexpr (gen_overflow) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error("overflow to signed inf"); + } + return (uint16_t)(h_sgn + 0x7c00u); + } + } + + /* Exponent underflow converts to a subnormal half or signed zero */ + if (f_exp <= 0x38000000u) { + /* + * Signed zeros, subnormal floats, and floats with small + * exponents all convert to signed zero half-floats. + */ + if (f_exp < 0x33000000u) { + if constexpr (gen_underflow) { + /* If f != 0, it underflowed to 0 */ + if ((f & 0x7fffffff) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + return h_sgn; + } + /* Make the subnormal significand */ + f_exp >>= 23; + f_sig = (0x00800000u + (f & 0x007fffffu)); + if constexpr (gen_underflow) { + /* If it's not exactly represented, it underflowed */ + if ((f_sig & (((uint32_t)1 << (126 - f_exp)) - 1)) != 0) { + // FloatStatus::RaiseUnderflow(); + throw std::underflow_error(""); + } + } + /* + * Usually the significand is shifted by 13. For subnormals an + * additional shift needs to occur. This shift is one for the largest + * exponent giving a subnormal `f_exp = 0x38000000 >> 23 = 112`, which + * offsets the new first bit. At most the shift can be 1+10 bits. + */ + f_sig >>= (113 - f_exp); + /* Handle rounding by adding 1 to the bit beyond half precision */ + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. However, the (113 - f_exp) + * shift can lose up to 11 bits, so the || checks them in the original. + * In all other cases, we can just add one. + */ + if (((f_sig & 0x00003fffu) != 0x00001000u) || (f & 0x000007ffu)) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp from zero to one and h_sig will be zero. + * This is the correct result. + */ + return (uint16_t)(h_sgn + h_sig); + } + + /* Regular case with no overflow or underflow */ + h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13); + /* Handle rounding by adding 1 to the bit beyond half precision */ + f_sig = (f & 0x007fffffu); + if constexpr (round_even) { + /* + * If the last bit in the half significand is 0 (already even), and + * the remaining bit pattern is 1000...0, then we do not add one + * to the bit after the half significand. In all other cases, we do. + */ + if ((f_sig & 0x00003fffu) != 0x00001000u) { + f_sig += 0x00001000u; + } + } else { + f_sig += 0x00001000u; + } + h_sig = (uint16_t)(f_sig >> 13); + /* + * If the rounding causes a bit to spill into h_exp, it will + * increment h_exp by one and h_sig will be zero. This is the + * correct result. h_exp may increment to 15, at greatest, in + * which case the result overflows to a signed inf. + */ + if constexpr (gen_overflow) { + h_sig += h_exp; + if (h_sig == 0x7c00u) { + // FloatStatus::RaiseOverflow(); + throw std::overflow_error(""); + } + return h_sgn + h_sig; + } else { + return h_sgn + h_exp + h_sig; + } +} + +// Taken from +// https://github.com/numpy/numpy/blob/70fde29fdd4d8fcc6098df7ef8a34c84844e347f/numpy/_core/src/common/half_private.hpp#L269 +constexpr uint32_t ToFloatBits(uint16_t h) { + uint16_t h_exp = (h & 0x7c00u); + uint32_t f_sgn = ((uint32_t)h & 0x8000u) << 16; + switch (h_exp) { + case 0x0000u: { // 0 or subnormal + uint16_t h_sig = (h & 0x03ffu); + // Signed zero + if (h_sig == 0) { + return f_sgn; + } + // Subnormal + h_sig <<= 1; + while ((h_sig & 0x0400u) == 0) { + h_sig <<= 1; + h_exp++; + } + uint32_t f_exp = ((uint32_t)(127 - 15 - h_exp)) << 23; + uint32_t f_sig = ((uint32_t)(h_sig & 0x03ffu)) << 13; + return f_sgn + f_exp + f_sig; + } + case 0x7c00u: // inf or NaN + // All-ones exponent and a copy of the significand + return f_sgn + 0x7f800000u + (((uint32_t)(h & 0x03ffu)) << 13); + default: // normalized + // Just need to adjust the exponent and shift + return f_sgn + (((uint32_t)(h & 0x7fffu) + 0x1c000u) << 13); + } +} + +npy_half npy_float_to_half(float f) { + return {FromFloatBits(BitCast(f))}; +} + +float npy_half_to_float(npy_half h) { + return BitCast(ToFloatBits(h.value)); +} + +template <> +npy_half atomic_fadd(npy_half *loc, npy_half value, + std::memory_order order) { + npy_half old_value; + + const std::lock_guard lock(atomic_op_guard); + old_value = *loc; + *loc = npy_float_to_half(npy_half_to_float(old_value) + + npy_half_to_float(value)); + + return old_value; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, std::memory_order order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + std::memory_order order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + DType *ptr = static_cast(loc); + *(static_cast(ret) + i) = + applyAtMasked(ptr, *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_add_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc + value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_and_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc & value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_or_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc | value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = std::atomic_fetch_xor_explicit(atomic_loc, value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = *loc ^ value; + } + return old_val; + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, + std::memory_order order) override { + DType old_val; + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = + reinterpret_cast *>(loc); + old_val = atomic_loc->exchange(value, order); + } else { + const std::lock_guard lock(atomic_op_guard); + old_val = *loc; + *loc = value; + } + return old_val; + } +}; + +template +void atomic_compare_exchange_strong(void *loc, void *expected, + const void *desired, size_t i, + std::memory_order order) { + T desired_val = *(static_cast(desired) + i); + T *expected_uint = static_cast(expected) + i; + + if constexpr (is_reinterpret_cast_to_atomic_safe) { + std::atomic *atomic_loc = reinterpret_cast *>(loc); + atomic_loc->compare_exchange_strong(*expected_uint, desired_val, order, + order); + } else { + const std::lock_guard lock(atomic_op_guard); + T *atomic_loc = static_cast(loc); + if (*atomic_loc == *expected_uint) { + *atomic_loc = desired_val; + } else { + *expected_uint = *atomic_loc; + } + } +} + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, std::memory_order order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + atomic_compare_exchange_strong(loc, expected, desired, i, order); + } else if (itemsize == 2) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 4) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else if (itemsize == 8) { + atomic_compare_exchange_strong(loc, expected, desired, i, + order); + } else { + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + std::memory_order order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template <> template <> void OpCreator::create() { + if (!atomic_op && dtype.char_() == 'e') { // float16 + // workaround until https://github.com/pybind/pybind11/issues/4061 is + // implemented + atomic_op = std::make_unique>( + ptr, val, ret, mask, numel, order); + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, + std::memory_order order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, npy_half, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + std::memory_order order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/sunrise/python/src/ir.cc b/third_party/sunrise/python/src/ir.cc new file mode 100644 index 000000000..0aa6ca313 --- /dev/null +++ b/third_party/sunrise/python/src/ir.cc @@ -0,0 +1,1888 @@ +#include "ir.h" + +#include +#include +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/SourceMgr.h" + +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" + +#include "llvm/ADT/SmallVector.h" + +void setAsyncTaskIds(mlir::Operation *op, + llvm::ArrayRef asyncTaskIds) { + llvm::SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), + asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", + DenseI32ArrayAttr::get(op->getContext(), sortedAsyncTaskIds)); +} + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; +namespace tt = triton; +namespace ttg = triton::gpu; +namespace ttng = triton::nvidia_gpu; + +llvm::raw_fd_ostream &mlir_dumps() { + std::error_code EC; + static llvm::raw_fd_ostream S(::triton::tools::getStrEnv("MLIR_DUMP_PATH"), + EC, llvm::sys::fs::CD_CreateAlways); + assert(!EC); + return S; +} + +llvm::raw_ostream &mlir_dumps_or_dbgs() { + if (!::triton::tools::getStrEnv("MLIR_DUMP_PATH").empty()) { + return mlir_dumps(); + } else { + return llvm::dbgs(); + } +} + +// Run the pass manager under a source manager diagnostic handler, which +// enables emitted MLIR diagnostics to directly reference Python source +// code. This diagnostic handler supports filtering diagnostic info by +// severity levels. +struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler { + TritonSourceMgrDiagnosticHandler(MLIRContext *ctx, + DiagnosticSeverity minSeverity) + : SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) { + setHandler([this, minSeverity](Diagnostic &diag) { + auto severity = diag.getSeverity(); + switch (severity) { + case DiagnosticSeverity::Error: + break; + case DiagnosticSeverity::Warning: + if (minSeverity == DiagnosticSeverity::Error) + return success(); + break; + case DiagnosticSeverity::Remark: + if (minSeverity == DiagnosticSeverity::Error || + minSeverity == DiagnosticSeverity::Warning) + return success(); + break; + case DiagnosticSeverity::Note: + // notes are handled somewhere else. + return failure(); + default: + llvm_unreachable("Unknown diagnostic severity"); + } + emitDiagnostic(diag); + return success(); + }); + } + + llvm::SourceMgr sourceMgr; +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +// Function to parse a comma-separated string into a vector of C-style strings +llvm::SmallVector +parseCommaSeparatedValues(const std::string &input, + llvm::SmallVector &storage) { + llvm::SmallVector split; + llvm::SmallVector result; + StringRef(input.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + return result; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +// Allow dump a reproducer in the console on crash. +struct ConsoleReproducerStream : public mlir::ReproducerStream { + ~ConsoleReproducerStream() override {} + + StringRef description() override { + return "std::errs, please share the reproducer above with Triton project."; + } + raw_ostream &os() override { return llvm::errs(); } +}; + +ReproducerStreamFactory makeConsoleReproducer() { + return [](std::string &error) -> std::unique_ptr { + return std::make_unique(); + }; +} + +OpPrintingFlags getOpPrintingFlags() { + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + return printingFlags; +} + +py::list getTensorDescMetadata(ModuleOp &mod) { + py::list result; + triton::FuncOp kernelFunc; + mod.walk([&](triton::FuncOp func) { + if (triton::isKernel(func)) { + kernelFunc = func; + return WalkResult::interrupt(); + } + return WalkResult::skip(); + }); + assert(kernelFunc); + + for (auto [i, argTy] : llvm::enumerate(kernelFunc.getArgumentTypes())) { + auto descTy = dyn_cast(argTy); + if (!descTy) + continue; + + auto blockType = descTy.getBlockType(); + auto encoding = blockType.getEncoding(); + auto mmaEncoding = dyn_cast(encoding); + auto swizzle = ttng::getTMASwizzleMode(nullptr, descTy); + auto elemType = ttng::getTMAElementType(nullptr, descTy); + assert(swizzle.has_value()); + assert(elemType.has_value()); + auto blockSize = ttng::getTMABlockShape(blockType, /*packedSize=*/false); + py::dict metadata; + metadata["swizzle"] = *swizzle; + metadata["elem_size"] = descTy.getBlockType().getElementTypeBitWidth() / 8; + metadata["elem_type"] = *elemType; + metadata["block_size"] = + std::vector(blockSize.begin(), blockSize.end()); + metadata["fp4_padded"] = mmaEncoding && mmaEncoding.getFp4Padded(); + result.append(std::move(metadata)); + } + return result; +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "DESCRIPTOR_REDUCE_KIND", + py::module_local()) + .value("ADD", DescriptorReduceKind::ADD) + .value("AND", DescriptorReduceKind::AND) + .value("OR", DescriptorReduceKind::OR) + .value("XOR", DescriptorReduceKind::XOR) + .value("MAX", DescriptorReduceKind::MAX) + .value("MIN", DescriptorReduceKind::MIN) + .value("INC", DescriptorReduceKind::INC) + .value("DEC", DescriptorReduceKind::DEC); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::enum_(m, "ScaleDotElemTypeTY", py::module_local()) + .value("E4M3", ScaleDotElemType::E4M3) + .value("E5M2", ScaleDotElemType::E5M2) + .value("E2M3", ScaleDotElemType::E2M3) + .value("E3M2", ScaleDotElemType::E3M2) + .value("E2M1", ScaleDotElemType::E2M1) + .value("BF16", ScaleDotElemType::BF16) + .value("FP16", ScaleDotElemType::FP16) + .export_values(); + + py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + .def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + mlir::LLVM::registerInlinerInterface(registry); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__eq__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty != nullptr) && (*other_ty == self); + }) + .def("__ne__", + [](Type &self, py::object &other) { + Type *other_ty = py::cast(other); + return (other_ty == nullptr) || (*other_ty != self); + }) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }) + .def("push_back", + [](Region &self, Block *block) { self.push_back(block); }) + .def("push_front", + [](Region &self, Block *block) { self.push_front(block); }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "unit_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = getOpPrintingFlags(); + self->print(os, printingFlags); + return str; + }) + .def("str_nodebug", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + self->print(os); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", + [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }) + .def("get_operation", [](OpState &self) { return self.getOperation(); }); + + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = getOpPrintingFlags(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("get_entry_func_name", + [](ModuleOp &self) -> std::string { + for (auto &op : self.getOps()) { + if (auto func = dyn_cast(op)) { + if (triton::isKernel(func)) + return func.getName().str(); + } + } + return ""; + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + /* + * def ty_to_cpp(ty) is the consumer of this function. + * If the type is a ptr it expects ty[0] == '*', else the type itself. + */ + + .def("get_function_signature", + [](ModuleOp &self, FuncOp &func) -> std::vector { + std::vector strVec; + + auto type = func.getFunctionType(); + unsigned numArgs = type.getNumInputs(); + for (unsigned i = 0; i != numArgs; ++i) { + std::string tempType; + llvm::raw_string_ostream os(tempType); + + auto ty = type.getInput(i); + if (auto attributes = func.getCallableArgAttrs()) { + Attribute attr = attributes[i]; + // Check for tt.nv_tma_desc = 1 + if (auto dAttr = dyn_cast(attr)) { + if (dAttr.contains("tt.nv_tma_desc")) { + strVec.push_back("nvTmaDesc"); + continue; + } + } + } + if (auto ptrType = dyn_cast(ty)) { + auto pType = ptrType.getPointeeType(); + os << "*"; + pType.print(os); + } else { + ty.print(os); + } + strVec.push_back(tempType); + } + return strVec; + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("get_tensordesc_metadata", getTensorDescMetadata) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + auto printingFlags = getOpPrintingFlags(); + if (failed(generateLocationsFromIR(fileName, self, printingFlags))) + throw std::runtime_error("Failed to create location snapshot"); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def("get_num_args", &FuncOp::getNumArguments) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", [](FuncOp &self) -> void {}) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + .def("set_async_task_ids", + [](TritonOpBuilder &self, std::vector v) { + self.setAsyncTaskIds(v); + }) + .def("unset_async_task_ids", + [](TritonOpBuilder &self) { self.unsetAsyncTaskIds(); }) + // Attr + .def( + "get_unit_attr", + [](TritonOpBuilder &self) { return self.getBuilder().getUnitAttr(); }) + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + .def("get_string_attr", + [](TritonOpBuilder &self, std::string value) -> Attribute { + return self.getBuilder().getStringAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, Type retTy, int start, int end) -> Value { + return self.create(retTy, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_tensor_descriptor_type", + [](TritonOpBuilder &self, Type blockTy, bool isSigned) -> Type { + auto ctx = self.getContext(); + return triton::TensorDescType::get( + ctx, cast(blockTy), isSigned); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value desc, std::vector &indices, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + auto descTy = cast(desc.getType()); + auto resTy = descTy.getSignlessBlockType(); + return self.create( + resTy, desc, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_gather", + [](TritonOpBuilder &self, Value desc, Value x_indices, Value y_index, + Type type) -> Value { + return self.create(type, desc, x_indices, + y_index); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value desc, Value value, + std::vector &indices) -> void { + self.create(desc, value, indices); + }) + .def("create_descriptor_reduce", + [](TritonOpBuilder &self, DescriptorReduceKind kind, Value desc, + Value value, std::vector &indices) -> void { + self.create(kind, desc, value, indices); + }) + .def("create_descriptor_scatter", + [](TritonOpBuilder &self, Value desc, Value value, Value x_indices, + Value y_index) -> void { + self.create(desc, x_indices, y_index, value); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + return self.create(shape, arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + return self.create(arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, std::vector &order) + -> Value { return self.create(arg, order); }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Type &retTy, Value &arg) -> Value { + return self.createOrFold(retTy, arg); + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create(axis); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, + std::optional &lhs_scale, + ScaleDotElemType lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, + ScaleDotElemType rhs_format, bool fast_math, bool lhs_k_pack, + bool rhs_k_pack, mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()), + rhs_scale.value_or(Value()), lhs_format, rhs_format, fast_math, + lhs_k_pack, rhs_k_pack); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins, + std::optional mask) -> Value { + if (!mask) { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + } else { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand, *mask); + } + }) + .def("create_gather", + [](TritonOpBuilder &self, Value src, Value indices, int axis) + -> Value { return self.create(src, indices, axis); }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }) + // Make a tensor descriptor + .def("create_make_tensor_descriptor", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &tensorShape, + bool isSignedInteger) -> Value { + return self.create(base, shape, strides, + tensorShape, isSignedInteger); + }) + // Proton Ops + .def("create_proton_record", + [](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void { + self.create(isStart, regionId); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) -> bool { + auto *context = self.getContext(); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + bool isEnvValueBool = + triton::tools::isEnvValueBool(funcToDump).has_value(); + if (!funcToDump.empty() && !isEnvValueBool) + haveDump = true; + } + if (haveDump) { + context->disableMultithreading(); + auto printingFlags = getOpPrintingFlags(); + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return false; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, mlir_dumps_or_dbgs(), + printingFlags); + } + return haveDump; + }) + .def("get_pipeline_str", + [](PassManager &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.printAsTextualPipeline(os); + return str; + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto *context = mod.getContext(); + if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING")) + context->disableMultithreading(); + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + // Save a reproducer for the current pass manager invocation + // immediately. + makeReproducer(anchorName, passes, op, reproducerPath); + // But if the pass manager crashes, attempt to generate a local + // reproducer instead. + context->disableMultithreading(); + self.enableCrashReproducerGeneration(reproducerPath, + /*genLocalReproducer=*/true); + } else { + self.enableCrashReproducerGeneration(makeConsoleReproducer()); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector storage; + llvm::SmallVector debugTypes = + parseCommaSeparatedValues(debugOnly, storage); + ::llvm::DebugFlag = true; + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + // setting up diagnostics + bool showOperations = false, showStacktraces = false, + showRemarks = false, showWarnings = false; + + if (auto enableDiagnostics = + triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS"); + !enableDiagnostics.empty()) { + llvm::SmallVector storage; + parseCommaSeparatedValues(enableDiagnostics, storage); + for (auto &str : storage) { + if (str == "warnings") { + showWarnings = true; + } else if (str == "remarks") { + showRemarks = true; + } else if (str == "stacktraces") { + showStacktraces = true; + } else if (str == "operations") { + showOperations = true; + } + // we show errors by default, so no need to set it + } + } + + DiagnosticSeverity minSeverity = showWarnings + ? DiagnosticSeverity::Warning + : DiagnosticSeverity::Error; + minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity; + + TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity); + + context->printOpOnDiagnostic(showOperations); + context->printStackTraceOnDiagnostic(showStacktraces); + if (showStacktraces) { + context->disableMultithreading(); + } + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/sunrise/python/src/ir.h b/third_party/sunrise/python/src/ir.h new file mode 100644 index 000000000..e1f9ce848 --- /dev/null +++ b/third_party/sunrise/python/src/ir.h @@ -0,0 +1,105 @@ +#pragma once +#include "mlir/IR/Builders.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/ArrayRef.h" +#include + +typedef int AsyncTaskId; +void setAsyncTaskIds(mlir::Operation *op, + llvm::ArrayRef asyncTaskIds); + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(mlir::MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + mlir::OpBuilder &getBuilder() { return *builder; } + mlir::MLIRContext *getContext() { return builder->getContext(); } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(mlir::Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(mlir::FileLineColLoc::get(context, fileName, line, column)); + } + + mlir::Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(mlir::Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + auto ret = builder->create(loc, std::forward(args)...); + if (asyncTaskIds) + ::setAsyncTaskIds(ret, *asyncTaskIds); + return ret; + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), + mlir::Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + void setAsyncTaskIds(std::vector taskIds) { + this->asyncTaskIds = taskIds; + } + + void unsetAsyncTaskIds() { this->asyncTaskIds = std::nullopt; } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + std::optional> asyncTaskIds; + bool lineInfoEnabled = + !mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; diff --git a/third_party/sunrise/python/src/llvm.cc b/third_party/sunrise/python/src/llvm.cc new file mode 100644 index 000000000..016958c18 --- /dev/null +++ b/third_party/sunrise/python/src/llvm.cc @@ -0,0 +1,530 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizer.h" +#include "llvm/Transforms/Instrumentation/AddressSanitizerOptions.h" +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +std::unique_ptr +createTargetMachine(llvm::Module *module, std::string proc, + bool enable_fp_fusion, const std::string &features) { + std::string error; + auto target = llvm::TargetRegistry::lookupTarget( + module->getTargetTriple().str(), error); + llvm::TargetOptions opt; + bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; + std::unique_ptr machine{target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + return machine; +} + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + auto isInteger=[&](const std::string& str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); + }; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto pos = flag.find("="); + if(pos == std::string::npos){ + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + }else{ + auto optIt = options.find(flag.substr(0,pos)); + assert(optIt != options.end() && "Option not found in cl::opt!"); + + if (isInteger(flag.substr(pos+1))) { + if(auto *unsignedOpt = static_cast*>(options[flag.substr(0,pos)])) { + unsignedOpt->setValue(std::stoi(flag.substr(pos+1))); + } + } + else { + if (auto *stringOpt = static_cast*>(options[flag.substr(0,pos)])) { + stringOpt->setValue(flag.substr(pos+1)); + } + } + } + } + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(Triple(triple)); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + py::class_(m, "source_mgr", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + .def("add_fn_asan_attr", + [](llvm::Function *fn) { + fn->addFnAttr(llvm::Attribute::SanitizeAddress); + }) + .def("add_fn_target_feature", + [](llvm::Function *fn, std::string &val) { + fn->addFnAttr("target-features", val); + }) + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + std::unique_ptr llvmMod = + mlir::translateModuleToLLVMIR(mod, ctx); + if (!llvmMod) { + throw std::runtime_error("failed to translate module to LLVM IR"); + } + return llvmMod; + }, + py::keep_alive<0, 2>()); + + m.def("attach_datalayout", [](llvm::Module *mod, const std::string triple, + const std::string proc, + const std::string features) { + std::string error; + auto target = llvm::TargetRegistry::lookupTarget(triple, error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + llvm::TargetOptions opt; + // Target machine is only used to create the data layout. + std::unique_ptr machine{target->createTargetMachine( + llvm::Triple(triple), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, llvm::CodeGenOptLevel::None)}; + // set data layout + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + std::string arch, std::string features, std::vector flags, + bool enable_fp_fusion) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + if (arch.empty()) { + llvm::TargetLibraryInfoImpl TLII; + TLII.disableAllFunctions(); + fam.registerPass([TLII = std::move(TLII)] { + return llvm::TargetLibraryAnalysis(TLII); + }); + } + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + tuningOptions.SLPVectorization = true; + + std::string pluginFile = + mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH"); + + // We don't pass the targetMachine to the LLVM-IR pass builder, unless + // `arch` is specified. + // + // Don't set target machine in LLVM pass builder when using LLVM IR + // level plugins. LLVM IR level plugin passes typically want to insert + // calls to externally generated code (i.e. precompile a Cuda/Hip kernel + // with Clang and then insert a call to it within an instrumentation + // pass) setting the targetMachine value here can can cause a mismatch + // in the target machine between the MLIR and Clang generated kernels + // and break the lowering of some target specific intrinsics. + std::unique_ptr targetMachine = nullptr; + if (!arch.empty() && pluginFile.empty()) + targetMachine = + createTargetMachine(mod, arch, enable_fp_fusion, features); + PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions, + std::nullopt, instrCbPtr); + + if (!pluginFile.empty()) { + // TODO: Add some logging here that we inserted a pass into the LLVM + // pass pipeline + auto passPlugin = llvm::PassPlugin::Load(pluginFile); + if (!passPlugin) { + llvm::Error Err = passPlugin.takeError(); + std::string ErrMsg = + "Pass Plugin Error: " + llvm::toString(std::move(Err)); + throw std::runtime_error(ErrMsg); + } + passPlugin->registerPassBuilderCallbacks(pb); + } + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + bool enableAddressSanitizer = + mlir::triton::tools::getBoolEnv("TRITON_ENABLE_ASAN"); + if (enableAddressSanitizer) { + AddressSanitizerOptions Opts; + mpm.addPass(AddressSanitizerPass(Opts)); + } + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + // Mandatory parameters + py::arg("mod"), py::arg("opt"), + // If we want to specify the target machine, we require additional + // (optional) parameters + py::arg("arch") = "", py::arg("features") = "", + py::arg("flags") = std::vector{}, + py::arg("enable_fp_fusion") = false); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(Triple(dstMod->getTargetTriple())); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} + +void triton_stacktrace_signal_handler(void *) { + llvm::sys::PrintStackTrace(llvm::errs()); + raise(SIGABRT); +} + +void init_triton_stacktrace_hook(pybind11::module &m) { + if (mlir::triton::tools::getBoolEnv("TRITON_ENABLE_PYTHON_STACKTRACE")) { + llvm::sys::AddSignalHandler(triton_stacktrace_signal_handler, nullptr); + } +} diff --git a/third_party/sunrise/python/src/main.cc b/third_party/sunrise/python/src/main.cc new file mode 100644 index 000000000..815a187fb --- /dev/null +++ b/third_party/sunrise/python/src/main.cc @@ -0,0 +1,57 @@ +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Signals.h" +#include + +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); +void init_gluon_ir(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_stacktrace_hook(m); + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + init_gluon_ir(m.def_submodule("gluon_ir")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/sunrise/python/src/passes.cc b/third_party/sunrise/python/src/passes.cc new file mode 100644 index 000000000..ebac33f43 --- /dev/null +++ b/third_party/sunrise/python/src/passes.cc @@ -0,0 +1,117 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); + ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createTritonCombineOps); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createTritonReorderBroadcast); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createTritonRewriteTensorPointer); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_descriptor_to_pointer", + createTritonRewriteTensorDescriptorToPointer); + ADD_PASS_WRAPPER_0("add_loop_unroll", createTritonLoopUnroll); + ADD_PASS_WRAPPER_0("add_triton_licm", createTritonLoopInvariantCodeMotion); + ADD_PASS_WRAPPER_0("add_loop_aware_cse", createTritonLoopAwareCSE); + ADD_PASS_OPTION_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPU, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir; + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_WRAPPER_0("add_hoist_tmem_alloc", createTritonGPUHoistTMEMAlloc); + ADD_PASS_OPTION_WRAPPER_1("add_assign_latencies", + createTritonGPUAssignLatencies, int); + ADD_PASS_WRAPPER_0("add_schedule_loops", createTritonGPUScheduleLoops); + ADD_PASS_OPTION_WRAPPER_2("add_pipeline", createTritonGPUPipeline, int, bool); + ADD_PASS_OPTION_WRAPPER_1("add_warp_specialize", + createTritonGPUAutomaticWarpSpecialization, int); + ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_warp_groups", + createTritonGPUAllocateWarpGroups); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", createAllocateSharedMemory); + ADD_PASS_WRAPPER_0("add_allocate_global_scratch_memory", + createTritonGPUGlobalScratchAllocationPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_WRAPPER_0("add_fuse_nested_loops", createTritonGPUFuseNestedLoops); + ADD_PASS_WRAPPER_0("add_coalesce_async_copy", + createTritonGPUCoalesceAsyncCopy); + ADD_PASS_WRAPPER_0("add_canonicalizer", createTritonGPUCanonicalize); + ADD_PASS_WRAPPER_0("add_inliner", [] { + return createInlinerPass(/*opPipelines=*/{}, [](OpPassManager &pm) { + pm.addPass(createTritonGPUCanonicalize()); + }); + }); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", mlir::createLLVMDIScope); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/sunrise/python/src/passes.h b/third_party/sunrise/python/src/passes.h new file mode 100644 index 000000000..629fe362d --- /dev/null +++ b/third_party/sunrise/python/src/passes.h @@ -0,0 +1,38 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder({val0, val1, val2, val3})); }) diff --git a/third_party/sunrise/python/test_examples/01-vector-add.py b/third_party/sunrise/python/test_examples/01-vector-add.py new file mode 100644 index 000000000..3619d7ec5 --- /dev/null +++ b/third_party/sunrise/python/test_examples/01-vector-add.py @@ -0,0 +1,135 @@ +""" +Vector Addition +=============== + +In this tutorial, you will write a simple vector addition using Triton. + +In doing so, you will learn about: + +* The basic programming model of Triton. + +* The `triton.jit` decorator, which is used to define Triton kernels. + +* The best practices for validating and benchmarking your custom ops against native reference implementations. + +""" + +# %% +# Compute Kernel +# -------------- + +import torch +import torch_ptpu +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. + # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) + + +# %% +# Let's also declare a helper function to (1) allocate the `z` tensor +# and (2) enqueue the above kernel with appropriate grid/block sizes: + + +def add(x: torch.Tensor, y: torch.Tensor): + # We need to preallocate the output. + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + # The SPMD launch grid denotes the number of kernel instances that run in parallel. + # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks: + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. + return output + + +# %% +# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) +output_torch = x + y +output_triton = add(x, y) +print(output_torch) +print(output_triton) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- +# +# We can now benchmark our custom op on vectors of increasing sizes to get a sense of how it does relative to PyTorch. +# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops. +# for different problem sizes. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['size'], # Argument names to use as an x-axis for the plot. + x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. + x_log=True, # x axis is logarithmic. + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=['triton', 'torch'], # Possible values for `line_arg`. + line_names=['Triton', 'Torch'], # Label name for the lines. + styles=[('blue', '-'), ('green', '-')], # Line styles. + ylabel='GB/s', # Label name for the y-axis. + plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. + args={}, # Values for function arguments not in `x_names` and `y_name`. + )) +def benchmark(size, provider): + x = torch.rand(size, device=DEVICE, dtype=torch.float32) + y = torch.rand(size, device=DEVICE, dtype=torch.float32) + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/third_party/sunrise/python/triton/__init__.py b/third_party/sunrise/python/triton/__init__.py new file mode 100644 index 000000000..bd93b642c --- /dev/null +++ b/third_party/sunrise/python/triton/__init__.py @@ -0,0 +1,74 @@ +"""isort:skip_file""" +__version__ = '3.4.0.1' + +# --------------------------------------- +# Note: import order is significant here. +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError +from .runtime._allocation import set_allocator + +from . import language +from . import testing + +must_use_result = language.core.must_use_result + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "must_use_result", + "next_power_of_2", + "OutOfResources", + "reinterpret", + "runtime", + "set_allocator", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/sunrise/python/triton/_filecheck.py b/third_party/sunrise/python/triton/_filecheck.py new file mode 100644 index 000000000..2a857d8f1 --- /dev/null +++ b/third_party/sunrise/python/triton/_filecheck.py @@ -0,0 +1,87 @@ +import os +import inspect +import subprocess +import tempfile + +import triton +from triton.compiler import ASTSource, make_backend +from triton.backends.compiler import GPUTarget +from triton.experimental.gluon._runtime import GluonASTSource +from triton._C.libtriton import ir + +# ===-----------------------------------------------------------------------===# +# filecheck_test +# ===-----------------------------------------------------------------------===# + +# Stub target for testing the frontend. +stub_target = GPUTarget("cuda", 100, 32) +stub_backend = make_backend(stub_target) + +triton_dir = os.path.dirname(__file__) +filecheck_path = os.path.join(triton_dir, "FileCheck") + + +class MatchError(ValueError): + + def __init__(self, message, module_str): + super().__init__(message) + self.module_str = module_str + + def __str__(self): + return f"{super().__str__()}\n{self.module_str}" + + +def run_filecheck(name, module_str, check_template): + with tempfile.TemporaryDirectory() as tempdir: + temp_module = os.path.join(tempdir, "module") + with open(temp_module, "w") as temp: + temp.write(module_str) + + temp_expected = os.path.join(tempdir, "expected") + with open(temp_expected, "w") as temp: + temp.write(check_template) + + try: + subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module], + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as error: + decoded = error.output.decode('unicode_escape') + raise ValueError(decoded) + + +def run_parser(kernel_fn): + sigkeys = [x.name for x in kernel_fn.params] + sigvals = [f"arg{i}" for i in range(len(sigkeys))] + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource + src = source_cls(fn=kernel_fn, signature=signature) + + context = ir.context() + ir.load_dialects(context) + stub_backend.load_dialects(context) + + extra_options = src.parse_options() + options = stub_backend.parse_options(dict(**extra_options)) + codegen_fns = stub_backend.get_codegen_implementation(options) + module_map = stub_backend.get_module_map() + module = src.make_ir(options, codegen_fns, module_map, context) + assert module.verify() + return module + + +def run_filecheck_test(kernel_fn): + assert isinstance(kernel_fn, triton.runtime.JITFunction) + check_template = inspect.getsource(kernel_fn.fn) + if check_template is None: + raise ValueError("kernel function must have a docstring with FileCheck template") + mlir_module = run_parser(kernel_fn) + + run_filecheck("placeholder", mlir_module.str_nodebug(), check_template) + + +def filecheck_test(fn): + + def test_fn(): + run_filecheck_test(fn) + + return test_fn diff --git a/third_party/sunrise/python/triton/_utils.py b/third_party/sunrise/python/triton/_utils.py new file mode 100644 index 000000000..1da8d692b --- /dev/null +++ b/third_party/sunrise/python/triton/_utils.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from functools import reduce +from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict + +if TYPE_CHECKING: + from .language import core + IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] + ObjPath = tuple[int, ...] + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any: + return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index] + + +def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any): + assert len(path) != 0 + prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1]) + prev[path[-1]] = val # type: ignore[index] + + +def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]: + from .language import core + is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + # We need to use dict so that ordering is maintained, while set doesn't guarantee order + ret: dict[ObjPath, None] = {} + + def _impl(path: tuple[int, ...], current: Any): + if is_iterable(current): + for idx, item in enumerate(current): + _impl((*path, idx), item) + elif pred(path, current): + ret[path] = None + + _impl((), iterable) + + return list(ret.keys()) + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel + + +type_canonicalisation_dict = { + # we canonicalise all bools to be unsigned: + "bool": "u1", + "int1": "u1", + "uint1": "u1", + "i1": "u1", + # floating-point dtypes: + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "half": "fp16", + "float16": "fp16", + "bfloat16": "bf16", + "float": "fp32", + "float32": "fp32", + "double": "fp64", + "float64": "fp64", + # signed integers: + "int8": "i8", + "int16": "i16", + "int": "i32", + "int32": "i32", + "int64": "i64", + # unsigned integers: + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + "void": "void", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +def canonicalize_dtype(dtype): + dtype_str = str(dtype).split(".")[-1] + return type_canonicalisation_dict[dtype_str] + + +BITWIDTH_DICT: Dict[str, int] = { + **{f"u{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"i{n}": n + for n in (1, 8, 16, 32, 64)}, + **{f"fp{n}": n + for n in (16, 32, 64)}, + **{f"fp8{suffix}": 8 + for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")}, + "bf16": 16, + "void": 0, +} + +for k, v in type_canonicalisation_dict.items(): + BITWIDTH_DICT[k] = BITWIDTH_DICT[v] + + +def get_primitive_bitwidth(dtype: str) -> int: + return BITWIDTH_DICT[dtype] diff --git a/third_party/sunrise/python/triton/compiler/__init__.py b/third_party/sunrise/python/triton/compiler/__init__.py new file mode 100644 index 000000000..f055926fa --- /dev/null +++ b/third_party/sunrise/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/sunrise/python/triton/compiler/code_generator.py b/third_party/sunrise/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..7b2818858 --- /dev/null +++ b/third_party/sunrise/python/triton/compiler/code_generator.py @@ -0,0 +1,1507 @@ +import ast +import copy +import inspect +import re +import warnings +import textwrap +import itertools +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List + +from .. import knobs, language +from .._C.libtriton import ir, gluon_ir +from ..language import constexpr, str_to_ty, tensor +from ..language.core import _unwrap_if_constexpr, base_value, base_type +from ..runtime.jit import get_jit_fn_file_line, get_full_name +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .._utils import find_paths_if, get_iterable_path, set_iterable_path + +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) + + +def check_identifier_legality(name, type): + pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$' + if not re.match(pattern, name): + raise CompilationError(f"invalid {type} identifier: {name}", name) + return name + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, base_value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return o is None or isinstance(o, (constexpr, language.core.dtype, JITFunction)) + + +def _is_non_scalar_tensor(o: Any) -> bool: + return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and _is_non_scalar_tensor(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _is_namedtuple(val): + return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") + + +def _apply_to_tuple_values(value, fn): + if _is_namedtuple(type(value)): + fields = value._fields + elif isinstance(value, language.tuple): + fields = value.type.fields + else: + assert False, f"Unsupported type {type(value)}" + + vals = [fn(v) for v in value] + vals = [constexpr(v) if v is None else v for v in vals] + types = [v.type for v in vals] + return language.tuple(vals, language.tuple_type(types, fields)) + + +def flatten_values_to_ir(values: Iterable[base_value]): + handles = [] + for v in values: + v._flatten_ir(handles) + return handles + + +def unflatten_ir_values(handles: List[ir.value], types: List[base_type]): + cursor = 0 + for ty in types: + value, cursor = ty._unflatten_ir(handles, cursor) + yield value + assert cursor == len(handles) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # no need to check within the function as it won't cause an early return. + # If the function itself has unstructured control flow we may not be able to inline it causing poor performance. + # We should check for this and fail or emit a warning. + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class ASTFunction: + + def __init__(self, ret_types, arg_types, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constants = constants + self.attrs = attrs + + def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]: + ir_types = [] + for ty in types: + if ty is None: + continue + ty._flatten_ir_types(builder, ir_types) + return ir_types + + def return_types_ir(self, builder: ir.builder) -> List[ir.type]: + return self.flatten_ir_types(builder, self.ret_types) + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constants and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val)) + arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths] + arg_types_ir = self.flatten_ir_types(builder, arg_types) + ret_types_ir = self.return_types_ir(builder) + return builder.get_function_ty(arg_types_ir, ret_types_ir) + + def deserialize(self, fn): + # create "template" + def make_template(ty): + if isinstance(ty, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in ty], ty) + return language.constexpr(None) + + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constants and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val)) + # > add IR values to the template + cursor = 0 + handles = [fn.args(i) for i in range(fn.get_num_args())] + for path in val_paths: + ty = get_iterable_path(self.arg_types, path) + # > set attributes + attr_specs = self.attrs.get(path, []) + for attr_name, attr_val in attr_specs: + fn.set_arg_attr(cursor, attr_name, attr_val) + # > build frontend value + val, cursor = ty._unflatten_ir(handles, cursor) + set_iterable_path(vals, path, val) + # > add constexpr values to the template + constants = self.constants + for path, val in constants.items(): + set_iterable_path(vals, path, language.constexpr(val)) + return vals + + +@dataclass(frozen=True) +class BoundJITMethod: + __self__: base_value + __func__: JITFunction + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, + module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, + file_name: Optional[str] = None, begin_line=0): + self.context = context + if jit_fn.is_gluon(): + from triton.experimental.gluon.language._semantic import GluonSemantic + self.builder = gluon_ir.GluonOpBuilder(context) + self.semantic = GluonSemantic(self.builder) + else: + from triton.language.semantic import TritonSemantic + self.builder = ir.builder(context) + self.semantic = TritonSemantic(self.builder) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.jit_fn = jit_fn + # TODO: we currently generate illegal names for non-kernel functions involving constexprs! + if is_kernel: + function_name = function_name[function_name.rfind('.') + 1:] + function_name = check_identifier_legality(function_name, "function") + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = { + _.__name__: _ + for _ in (len, list, range, float, int, isinstance, getattr, hasattr) + } + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, + name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__triton_aggregate__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), # + isinstance(val, language.dtype), # + _is_namedtuple(val), + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + knobs.compilation.allow_non_constexpr_globals, + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from + annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[base_value, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = language.tuple([self.visit(elt) for elt in node.elts]) + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + handles = [] + + def decay(value): + if isinstance(value, language.tuple): + return _apply_to_tuple_values(value, decay) + elif isinstance(value, (language.constexpr, int, float)): + return self.semantic.to_tensor(value) + return value + + ret_value = decay(ret_value) + + if ret_value is None: + ret_ty = language.void + else: + assert isinstance(ret_value, language.core.base_value) + ret_value._flatten_ir(handles) + ret_ty = ret_value.type + self.builder.ret(handles) + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. + post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def visit_Starred(self, node) -> Any: + args = self.visit(node.value) + assert isinstance(args, language.core.tuple) + return args.values + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + assert not self.builder.get_insertion_block().has_terminator() + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) + self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)]) + self.fn.finalize() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def assignTarget(self, target, value): + assert isinstance(target.ctx, ast.Store) + if isinstance(target, ast.Subscript): + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + for i, target in enumerate(target.elts): + self.assignTarget(target, value.values[i]) + return + if isinstance(target, ast.Attribute): + base = self.visit(target.value) + setattr(base, target.attr, value) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + + def visit_Assign(self, node): + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return _apply_to_tuple_values(value, _sanitize_value) + native_nontensor_types = (language.dtype, language.tuple) + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = self.semantic.to_tensor(value) + return value + + values = _sanitize_value(self.visit(node.value)) + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + self.assignTarget(targets[0], values) + + def visit_AugAssign(self, node): + lhs = copy.deepcopy(node.target) + lhs.ctx = ast.Load() + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.visit(lhs) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return language.tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _semantic=self.semantic) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic) + if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr): + lhs = constexpr(lhs) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + type_equal = type(defs[name]) == type(liveins[name]) # noqa: E721 + assert type_equal and defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name]}, '\ + f'but the {block_name} block redefines it as {defs[name]}' + if name in then_defs or name in else_defs: + names.append(name) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_val = then_defs[name] + then_ty = then_val.type + else_val = else_defs[name] + else_ty = else_val.type + type_equal = type(then_val) == type(else_val) # noqa: E721 + assert type_equal and then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + + return then_defs, else_defs, then_block, else_block, names + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + self.builder.create_branch(endif_block, then_handles) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + self.builder.create_branch(endif_block, else_handles) + assert len(then_handles) == len(else_handles) + for then_h, else_h in zip(then_handles, else_handles): + ty = then_h.get_type() + assert ty == else_h.get_type() + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + res_handles = [endif_block.arg(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + then_handles = flatten_values_to_ir(then_defs[name] for name in names) + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op(then_handles) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + else_handles = flatten_values_to_ir(else_defs[name] for name in names) + self.builder.create_yield_op(else_handles) + # update values + res_handles = [if_op.get_result(i) for i in range(len(then_handles))] + types = [then_defs[name].type for name in names] + new_values = unflatten_ir_values(res_handles, types) + for name, new_value in zip(names, new_values): + self.set_value(name, new_value) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + if _is_non_scalar_tensor(cond): + raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous") + if cond.type.is_block(): + warnings.warn( + "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead" + % ast.unparse(node.test)) + cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self) + cond = cond.to(language.int1, _semantic=self.semantic) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _semantic=self.semantic) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = self.semantic.to_tensor(self.visit(node.body)) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = self.semantic.to_tensor(self.visit(node.orelse)) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) is ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) is ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_semantic=self.semantic) + try: + return getattr(operand, fn)() + except AttributeError: + if fn == "__not__": + return constexpr(not operand) + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def _verify_loop_carried_variable(self, name, loop_val, live_val): + assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop' + assert type(loop_val) is type(live_val), f'Loop carried variable {name} changed type' + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + + def visit_withitem(self, node): + return self.visit(node.context_expr) + + def visit_With(self, node): + assert len(node.items) == 1 + context = node.items[0].context_expr + withitemClass = self.visit(context.func) + if withitemClass == language.async_task: + args = [self.visit(arg) for arg in context.args] + with withitemClass(*args, _builder=self.builder): + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.body) + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + # these are loop-carried values + names.append(name) + init_args.append(live_val) + + init_handles = flatten_values_to_ir(init_args) + init_tys = [h.get_type() for h in init_handles] + init_fe_tys = [a.type for a in init_args] + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op(init_tys, init_handles) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys) + self.builder.set_insertion_point_to_start(before_block) + block_args = [before_block.arg(i) for i in range(len(init_handles))] + condition_args = unflatten_ir_values(block_args, init_fe_tys) + for name, val in zip(names, condition_args): + self.lscope[name] = val + self.local_defs[name] = val + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, block_args) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + body_handles = [after_block.arg(i) for i in range(len(init_handles))] + body_args = unflatten_ir_values(body_handles, init_fe_tys) + for name, val in zip(names, body_args): + self.lscope[name] = val + self.local_defs[name] = val + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + loop_defs[name]._flatten_ir(yields) + + self.builder.create_yield_op(yields) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + result_handles = [while_op.get_result(i) for i in range(len(init_handles))] + result_vals = unflatten_ir_values(result_handles, init_fe_tys) + for name, new_def in zip(names, result_vals): + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript_Load(self, node): + assert isinstance(node.ctx, ast.Load) + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _semantic=self.semantic) + return lhs[slices] + + def visit_Subscript_Store(self, node, value): + assert isinstance(node.ctx, ast.Store) + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + disallow_acc_multi_buffer = False + flatten = False + warp_specialize = False + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer + flatten = iterator.flatten + warp_specialize = iterator.warp_specialize + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = self.semantic.to_tensor(lb) + ub = self.semantic.to_tensor(ub) + step = self.semantic.to_tensor(step) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + names.append(name) + init_args.append(live_val) + yields.append(loop_val) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + init_handles = flatten_values_to_ir(init_args) + init_tys = [v.type for v in init_args] + for_op = self.builder.create_for_op(lb, ub, step, init_handles) + if _unwrap_if_constexpr(num_stages) is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if _unwrap_if_constexpr(loop_unroll_factor) is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if disallow_acc_multi_buffer: + for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr()) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) + if warp_specialize: + for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) + + self.scf_stack.append(node) + for_op_body = for_op.get_body(0) + self.builder.set_insertion_point_to_start(for_op_body) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))] + block_args = unflatten_ir_values(block_handles, init_tys) + for name, val in zip(names, block_args): + self.set_value(name, val) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + local = self.local_defs[name] + if isinstance(local, constexpr): + local = self.semantic.to_tensor(local) + yields.append(local) + + # create YieldOp + if len(yields) > 0: + yield_handles = flatten_values_to_ir(yields) + self.builder.create_yield_op(yield_handles) + for_op_region = for_op_body.get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op_body) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + result_handles = [for_op.get_result(i) for i in range(len(init_handles))] + result_values = unflatten_ir_values(result_handles, init_tys) + for name, val in zip(names, result_values): + self.set_value(name, val) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return language.slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _semantic=self.semantic) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + for i, arg in enumerate(args): + if isinstance(arg, (language.dtype, float, int, bool, JITFunction)): + args[i] = language.core.constexpr(arg) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_cst = {path: get_iterable_path(args, path) for path in args_cst} + args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x)) + args_val = [get_iterable_path(args, path) for path in args_path] + # mangle + fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst) + # generate function def if necessary + if not self.module.has_function(fn_name): + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + arg_types = [ + language.core.constexpr if arg is None or isinstance(arg, + (bool, int, language.core.dtype)) else arg.type + for arg in args + ] + prototype = ASTFunction([], arg_types, args_cst, dict()) + generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + if knobs.compilation.front_end_debugging: + raise + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + args_val = flatten_values_to_ir(args_val) + call_op = self.builder.call(symbol, args_val) + if callee_ret_type == language.void: + return None + handles = [call_op.get_result(i) for i in range(call_op.get_num_results())] + return next(unflatten_ir_values(handles, [callee_ret_type])) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + if not isinstance(fn, BoundJITMethod): + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + mur = getattr(fn, '_must_use_result', False) + if mur and getattr(node, '_is_unused', False): + error_message = ["The result of %s is not being used." % ast.unparse(node.func)] + if isinstance(mur, str): + error_message.append(mur) + raise CompilationError(self.jit_fn.src, node, " ".join(error_message)) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args)) + if isinstance(fn, BoundJITMethod): + args.insert(0, fn.__self__) + fn = fn.__func__ + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = {"_semantic": self.semantic} + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret + except Exception as e: + if knobs.compilation.front_end_debugging: + raise + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + ret = fn(*args, **kws) + return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + + nontrivial_values = [] + + for subnode in node.values: + # we visit the values in order, executing their side-effects + # and possibly early-exiting: + value = self.visit(subnode) + if not _is_triton_tensor(value): + # this is a constexpr, so we might be able to short-circuit: + bv = bool(value) + if (bv is False) and (method_name == "logical_and"): + # value is falsey so return that: + return value + if (bv is True) and (method_name == "logical_or"): + # value is truthy so return that: + return value + # otherwise, our constexpr has no effect on the output of the + # expression so we do not append it to nontrivial_values. + else: + if value.type.is_block(): + lineno = getattr(node, "lineno", None) + if lineno is not None: + lineno += self.begin_line + warnings.warn_explicit( + "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead", + category=UserWarning, + filename=self.file_name, + lineno=lineno, + source=ast.unparse(node), + ) + # not a constexpr so we must append it: + nontrivial_values.append(value) + + if len(nontrivial_values) == 0: + # the semantics of a disjunction of falsey values or conjunction + # of truthy values is to return the final value: + nontrivial_values.append(value) + + while len(nontrivial_values) >= 2: + rhs = nontrivial_values.pop() + lhs = nontrivial_values.pop() + res = self._apply_binary_method(method_name, lhs, rhs) + nontrivial_values.append(res) + + assert len(nontrivial_values) == 1 + return nontrivial_values[0] + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs) and node.attr == "T": + return self.semantic.permute(lhs, (1, 0)) + # NOTE: special case ".value" for BC + if isinstance(lhs, constexpr) and node.attr != "value": + lhs = lhs.value + attr = getattr(lhs, node.attr) + if _is_triton_value(lhs) and isinstance(attr, JITFunction): + return BoundJITMethod(lhs, attr) + return attr + + def visit_Expr(self, node): + node.value._is_unused = True + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + if knobs.compilation.front_end_debugging: + raise + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + from ..experimental.gluon import language as ttgl + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + ttgl.static_assert: execute_static_assert, + ttgl.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None): + arg_types = [None] * len(fn.arg_names) + for k, v in src.signature.items(): + idx = fn.arg_names.index(k) + arg_types[idx] = str_to_ty(v) + prototype = ASTFunction([], arg_types, src.constants, src.attrs) + file_name, begin_line = get_jit_fn_file_line(fn) + # query function representation + from collections import namedtuple + leaves = filter(lambda v: len(v) == 1, src.constants) + constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves} + signature = src.signature + proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature) + generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map, module=module) + generator.visit(fn.parse()) + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/sunrise/python/triton/compiler/compiler.py b/third_party/sunrise/python/triton/compiler/compiler.py new file mode 100644 index 000000000..06b651f33 --- /dev/null +++ b/third_party/sunrise/python/triton/compiler/compiler.py @@ -0,0 +1,526 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import Language +from ..backends.compiler import BaseBackend, GPUTarget +from .. import __version__, knobs +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +from ..tools.disasm import get_sass +from pathlib import Path +import re +import functools +import os +import sysconfig +import time + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ptx": ptx_prototype_pattern, +} + +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +class ASTSource: + + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + self.fn = fn + self.language = Language.TRITON + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = dict() + if constexprs is not None: + for k, v in constexprs.items(): + k = (fn.arg_names.index(k), ) if isinstance(k, str) else k + assert isinstance(k, tuple) + self.constants[k] = v + self.attrs = attrs or dict() + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x) + constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())]) + key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + from .code_generator import ast_to_ttir + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path, context, backend): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.language = Language.TRITON + self.src = path.read_text() + ir.load_dialects(context) + backend.load_dialects(context) + + # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. + # TODO - replace with a proper parser + if self.ext == "ptx": + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + else: + self.module = ir.parse_mlir_module(self.path, context) + fn_name = self.module.get_entry_func_name() + self.name = "@" + fn_name + funcOp = self.module.get_function(fn_name) + func_ty = self.module.get_function_signature(funcOp) + self.signature = {k: ty for k, ty in enumerate(func_ty)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + self.module.context = context + return self.module + + def parse_options(self): + if self.ext == "ttgir": + num_warps = self.module.get_int_attr("ttg.num-warps") + assert num_warps is not None, "Unable to parse ttg.num-warps attribute" + return {'num_warps': num_warps} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] + with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +@functools.lru_cache() +def max_shared_mem(device): + return driver.active.utils.get_device_properties(device)["max_shared_mem"] + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx" or ext == "amdgcn": + return Path(full_name).read_text() + if ext == "cubin" or ext == "hsaco": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if knobs.compilation.front_end_debugging: + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +class CompileTimer: + + def __init__(self) -> None: + self.start: float = time.time() + self.ir_initialization_end: float | None = None + self.lowering_stage_ends: list[tuple[str, float]] = [] + self.store_results_end: float | None = None + + def finished_ir_initialization(self) -> None: + self.ir_initialization_end = time.time() + + def stage_finished(self, stage_name: str) -> None: + self.lowering_stage_ends.append((stage_name, time.time())) + + def end(self) -> knobs.CompileTimes: + timestamp = time.time() + if self.ir_initialization_end is None: + self.ir_initialization_end = timestamp + else: + self.store_results_end = timestamp + + def delta(start: float, end: float | None) -> int: + if end is None: + return 0 + return int((end - start) * 1000000) + + lowering_stage_durations = [] + stage_start = self.ir_initialization_end + for stage_name, stage_end in self.lowering_stage_ends: + lowering_stage_durations.append((stage_name, delta(stage_start, stage_end))) + stage_start = stage_end + + return knobs.CompileTimes( + ir_initialization=delta(self.start, self.ir_initialization_end), + lowering_stages=lowering_stage_durations, + store_results=delta(stage_start, self.store_results_end), + ) + + +def compile(src, target=None, options=None): + compilation_listener = knobs.compilation.listener + if compilation_listener: + timer = CompileTimer() + + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + context = ir.context() + src = IRSource(src, context, backend) + + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = knobs.compilation.override + enable_ir_dump = knobs.compilation.dump_ir + store_only_binary = knobs.compilation.store_binary_only + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = knobs.compilation.always_compile + if not always_compile and metadata_path is not None: + # cache hit! + res = CompiledKernel(src, metadata_group, hash) + if compilation_listener: + compilation_listener( + src=src, + metadata=res.metadata._asdict(), + metadata_group=metadata_group, + times=timer.end(), + cache_hit=True, + ) + return res + + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + metadata["triton_version"] = __version__ + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options, src.language) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. + if not isinstance(src, IRSource): + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + codegen_fns = backend.get_codegen_implementation(options) + module_map = backend.get_module_map() + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + + if ir_source: + ir_filename = f"{file_name}.{src.ext}" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + else: + ir_filename = f"{file_name}.source" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + + use_ir_loc = knobs.compilation.use_ir_loc + if ir_source and use_ir_loc: + module.create_location_snapshot(src.path) + print(f"Creating new locations for {src.path}") + + if compilation_listener: + timer.finished_ir_initialization() + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if fn_override_manager is None: + # Users can override kernels at scale by setting `ir_override` in autotune config + # without TRITON_KERNEL_OVERRIDE + if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"): + next_module = parse(ir_override, ext, context) + elif full_name := fn_override_manager.get_file(ir_filename): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json + if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")): + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + if compilation_listener: + timer.stage_finished(ext) + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + # + # However disabling multithreading causes the code to hang if the ASAN pass is enabled + # this is likely due to the llvm-symbolizer forking a process + # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling + # multithreading in the MLIR context + if not knobs.compilation.enable_asan: + context.disable_multithreading() + + # notify any listener + if compilation_listener: + compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(), + cache_hit=False) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target: GPUTarget) -> BaseBackend: + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self): + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +class CompiledKernel: + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = max_shared_mem(device) + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None: + # Use blackwell max tmem size for now, this should be moved in device properties + max_tmem_size = 512 # tmem size in number of columns + if self.metadata.tmem_size > max_tmem_size: + raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + warp_size = driver.active.get_current_target().warp_size + if self.metadata.num_warps * warp_size > self.n_max_threads: + raise OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads") + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if knobs.runtime.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args) + + return runner diff --git a/third_party/sunrise/python/triton/compiler/errors.py b/third_party/sunrise/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/sunrise/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/sunrise/python/triton/compiler/make_launcher.py b/third_party/sunrise/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/python/triton/errors.py b/third_party/sunrise/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/sunrise/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/sunrise/python/triton/knobs.py b/third_party/sunrise/python/triton/knobs.py new file mode 100644 index 000000000..ab0572b5d --- /dev/null +++ b/third_party/sunrise/python/triton/knobs.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import importlib +import os +import re +import subprocess +import sysconfig + +from dataclasses import dataclass +from contextlib import contextmanager +from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union + +if TYPE_CHECKING: + from .runtime.cache import CacheManager, RemoteCacheBackend + from .runtime.jit import JitFunctionInfo, KernelParam + from .compiler.compiler import ASTSource, LazyDict, IRSource + + +class Env: + pass + + +env = Env() + +propagate_env: bool = True + + +def getenv(key: str) -> Optional[str]: + res = os.getenv(key) + return res.strip() if res is not None else res + + +def setenv(key: str, value: Optional[str]) -> None: + if not propagate_env: + return + + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + + +def toenv(val: Any) -> Union[None, tuple[Optional[str]]]: + if val is None: + return (None, ) + + t = type(val) + if t is bool: + return ("1" if val else "0", ) + + if t is str: + return (val, ) + + if t is int: + return (str(val), ) + + return None + + +# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with a +# a string but return an NvidiaTool. +SetType = TypeVar("SetType") +GetType = TypeVar("GetType") + + +class env_base(Generic[SetType, GetType]): + + def __init__(self, key: str, default: Union[SetType, Callable[[], SetType]]) -> None: + self.key = key + self.default: Callable[[], SetType] = default if callable(default) else lambda: default + + def __set_name__(self, objclass: Type[object], name: str) -> None: + self.name = name + + def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType: + if obj is None: + raise AttributeError(f"Cannot access {type(self)} on non-instance") + + if self.name in obj.__dict__: + return self.transform(obj.__dict__[self.name]) + else: + return self.get() + + @property + def env_val(self) -> str | None: + return getenv(self.key) + + def get(self) -> GetType: + env = self.env_val + return self.transform(self.default() if env is None else self.from_env(env)) + + def __set__(self, obj: object, value: Union[SetType, Env]) -> None: + if isinstance(value, Env): + obj.__dict__.pop(self.name, None) + else: + obj.__dict__[self.name] = value + if env_val := toenv(value): + setenv(self.key, env_val[0]) + + def __delete__(self, obj: object) -> None: + obj.__dict__.pop(self.name, None) + + def transform(self, val: SetType) -> GetType: + # See comment about GetType/SetType in their definition above. Only needed + # if GetType != SetType. + return cast(GetType, val) + + def from_env(self, val: str) -> SetType: + raise NotImplementedError() + + +class env_str(env_base[str, str]): + + def from_env(self, val: str) -> str: + return val + + +class env_bool(env_base[bool, bool]): + + def __init__(self, key: str, default: Union[bool, Callable[[], bool]] = False) -> None: + super().__init__(key, default) + + def from_env(self, val: str) -> bool: + return val.lower() in ("1", "true", "yes", "on", "y") + + +class env_int(env_base[int, int]): + + def __init__(self, key: str, default: Union[int, Callable[[], int]] = 0) -> None: + super().__init__(key, default) + + def from_env(self, val: str) -> int: + try: + return int(val) + except ValueError as exc: + raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc + + +class env_opt_base(Generic[GetType, SetType], env_base[Optional[GetType], Optional[SetType]]): + + def __init__(self, key: str) -> None: + super().__init__(key, None) + + +ClassType = TypeVar("ClassType") + + +class env_class(Generic[ClassType], env_opt_base[Type[ClassType], Type[ClassType]]): + + def __init__(self, key: str, type: str) -> None: + super().__init__(key) + # We can't pass the type directly to avoid import cycles + self.type = type + + def from_env(self, val: str) -> Type[ClassType]: + comps = val.split(":", 1) + if len(comps) != 2: + raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS") + cls = getattr(importlib.import_module(comps[0]), comps[1]) + + if not any((c.__name__ == self.type for c in cls.mro())): + raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'") + + return cast(Type[ClassType], cls) + + +@dataclass +class NvidiaTool: + path: str + version: str + + @staticmethod + def from_path(path: str) -> Optional[NvidiaTool]: + try: + result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT) + if result is None: + return None + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is None: + return None + return NvidiaTool(path, version.group(1)) + except subprocess.CalledProcessError: + return None + + +class env_nvidia_tool(env_base[str, NvidiaTool]): + + def __init__(self, binary: str) -> None: + binary += sysconfig.get_config_var("EXE") + self.binary = binary + super().__init__(f"TRITON_{binary.upper()}_PATH", lambda: os.path.join( + os.path.dirname(__file__), + "backends", + "nvidia", + "bin", + self.binary, + )) + + def transform(self, path: str) -> NvidiaTool: + paths = [ + path, + # We still add default as fallback in case the pointed binary isn't + # accessible. + self.default(), + ] + for path in paths: + if not path or not os.access(path, os.X_OK): + continue + if tool := NvidiaTool.from_path(path): + return tool + + raise RuntimeError(f"Cannot find {self.binary}") + + def from_env(self, val: str) -> str: + return val + + +# Separate classes so that types are correct +class env_opt_str(env_opt_base[str, str], env_str): + pass + + +class env_opt_bool(env_opt_base[bool, bool], env_bool): + pass + + +@dataclass(frozen=True) +class CompileTimes: + """ + Model holding timing information for an invocation of the compiler. + + All times in microseconds. + """ + + # Duration of make_ir + ir_initialization: int + + # Ordered mapping from lowering stage to duration spent in that stage. + # Keyed by stage extension, e.g. ttir, ttgir + lowering_stages: list[tuple[str, int]] + + # Duration of saving artifacts/metadata to cache + store_results: int + + @property + def total_lowering(self) -> int: + return sum((stage[1] for stage in self.lowering_stages)) + + @property + def total(self) -> int: + return self.ir_initialization + self.total_lowering + self.store_results + + +class CompilationListener(Protocol): + + def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str], + times: CompileTimes, cache_hit: bool) -> None: + ... + + +knobs_type = TypeVar("knobs_type", bound='base_knobs') + + +class base_knobs: + + @property + def knob_descriptors(self) -> dict[str, env_base]: + return { + k: v + # data descriptors live on the class object + for k, v in type(self).__dict__.items() + if isinstance(v, env_base) + } + + @property + def knobs(self) -> dict[str, Any]: + return {k: getattr(self, k) for k in self.knob_descriptors.keys()} + + def copy(self: knobs_type) -> knobs_type: + res = type(self)() + res.__dict__.update(self.__dict__) + return res + + def reset(self: knobs_type) -> knobs_type: + for knob in self.knob_descriptors.keys(): + delattr(self, knob) + return self + + @contextmanager + def scope(self) -> Generator[None, None, None]: + try: + initial_env = {knob.key: knob.env_val for knob in self.knob_descriptors.values()} + orig = dict(self.__dict__) + yield + finally: + self.__dict__.clear() + self.__dict__.update(orig) + + for k, v in initial_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k] + + +class BuildImpl(Protocol): + + def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], + libraries: list[str], /) -> str: + ... + + +class build_knobs(base_knobs): + """Configuration controlling how the native compiler is invoked""" + cc: env_opt_str = env_opt_str("CC") + + cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH") + cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH") + + impl: Optional[BuildImpl] = None + + @property + def backend_dirs(self) -> set[str]: + return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None} + + +class redis_knobs(base_knobs): + key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + host: env_str = env_str("TRITON_REDIS_HOST", "localhost") + port: env_int = env_int("TRITON_REDIS_PORT", 6379) + + +cache: cache_knobs + + +class cache_knobs(base_knobs): + home_dir: env_str = env_str("TRITON_HOME", lambda: os.path.expanduser("~/")) + + dump_dir: env_str = env_str("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump")) + override_dir: env_str = env_str("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override")) + dir: env_str = env_str("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache")) + + manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager") + remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend") + + def get_triton_dir(self, dirname: str) -> str: + return os.path.join(self.home_dir, ".triton", dirname) + + +class compilation_knobs(base_knobs): + override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE") + dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP") + store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY") + always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE") + # TODO: Use enum to constrain / 'typecheck' the values + use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC") + enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN") + disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO") + front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING") + allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS") + listener: Union[CompilationListener, None] = None + + +class autotuning_knobs(base_knobs): + cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING") + print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING") + + +class LaunchHook(Protocol): + + def __call__(self, metadata: LazyDict) -> None: + ... + + +# This is of the form [attr_name, attr_val] +# TODO: Use tuple instead of list for better typing. +KernelAttr = list[Union[str, int]] + + +class JITHookCompileInfo(TypedDict): + key: str + signature: dict[KernelParam, str] + device: int + constants: None + num_warps: int + num_ctas: int + num_stages: int + enable_fp_fusion: bool + launch_cooperative_grid: bool + extern_libs: tuple[tuple[str, str], ...] + configs: list[dict[tuple[int, ...], list[KernelAttr]]] + specialization_data: str + is_warmup: bool + + +class JITHook(Protocol): + + def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool, + already_compiled: bool) -> Optional[bool]: + ... + + +class runtime_knobs(base_knobs): + interpret: env_bool = env_bool("TRITON_INTERPRET") + debug: env_bool = env_bool("TRITON_DEBUG") + override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH") + + launch_enter_hook: Optional[LaunchHook] = None + launch_exit_hook: Optional[LaunchHook] = None + + # Hook for inspecting compiled functions and modules + jit_cache_hook: Optional[JITHook] = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # jit_cache_hook will always be called before compilation and jit_post_compile_hook after. + jit_post_compile_hook: Optional[JITHook] = None + + +class language_knobs(base_knobs): + fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT") + default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True) + + +class nvidia_knobs(base_knobs): + cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump") + nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm") + ptxas: env_nvidia_tool = env_nvidia_tool("ptxas") + + dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP") + disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT") + mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION") + + libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH") + libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH") + + +class amd_knobs(base_knobs): + use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS") + dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP") + libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH") + lld_path: env_opt_str = env_opt_str("TRITON_HIP_LLD_PATH") + + # We use strs so that we can have a default value based on other runtime info + use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG") + use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE") + + global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH") + local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH") + use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY") + scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS") + +class sunrise_knobs(base_knobs): + lld_path: env_opt_str = env_opt_str("TRITON_SUNRISE_LLD_PATH") + triple: env_opt_str = env_opt_str("TRITON_SUNRISE_TRANSLATE_TRIPLE") + flag: env_opt_str = env_opt_str("TRITON_SUNRISE_OPTION_FLAG") + libtang_path: env_opt_str = env_opt_str("TRITON_LIBTANG_PATH") + opt_level: env_int = env_int("TRITON_SUNRISE_OPTIMIZATION_LEVEL", 3) + dump_stcu: env_bool = env_bool("TRITON_SUNRISE_ENABLE_DUMP") + +class proton_knobs(base_knobs): + cupti_dir: env_opt_str = env_opt_str("TRITON_CUPTI_LIB_PATH") + + +build = build_knobs() +redis = redis_knobs() +cache = cache_knobs() +compilation = compilation_knobs() +autotuning = autotuning_knobs() +runtime = runtime_knobs() +language = language_knobs() +nvidia = nvidia_knobs() +amd = amd_knobs() +sunrise = sunrise_knobs() +proton = proton_knobs() diff --git a/third_party/sunrise/python/triton/language/__init__.py b/third_party/sunrise/python/triton/language/__init__.py new file mode 100644 index 000000000..7625e700b --- /dev/null +++ b/third_party/sunrise/python/triton/language/__init__.py @@ -0,0 +1,336 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + bitonic_merge, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + reduce_or, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + topk, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + load_tensor_descriptor, + store_tensor_descriptor, + make_tensor_descriptor, + tensor_descriptor, + tensor_descriptor_type, + add, + advance, + arange, + associative_scan, + assume, + async_task, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + constexpr, + constexpr_function, + debug_barrier, + device_assert, + device_print, + dot, + dot_scaled, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + gather, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + slice, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + tuple, + tuple_type, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "load_tensor_descriptor", + "store_tensor_descriptor", + "make_tensor_descriptor", + "tensor_descriptor", + "abs", + "add", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "assume", + "async_task", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "bitonic_merge", + "block_type", + "broadcast", + "broadcast_to", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "constexpr", + "constexpr_function", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dot_scaled", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "gather", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reduce_or", + "reshape", + "rsqrt", + "slice", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "topk", + "trans", + "tuple", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + from builtins import tuple + + if isinstance(name, tuple): + fields = type(name).__dict__.get("_fields", None) + return tuple_type([str_to_ty(x) for x in name], fields) + + if name[0] == "*": + name = name[1:] + const = False + if name[0] == "k": + name = name[1:] + const = True + ty = str_to_ty(name) + return pointer_type(element_ty=ty, const=const) + + if name.startswith("tensordesc"): + inner = name.split("<")[1].rstrip(">") + dtype, rest = inner.split("[", maxsplit=2) + block_shape, rest = rest.split("]", maxsplit=2) + block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")] + layout = rest.lstrip(",") + is_gluon = len(layout) + dtype = str_to_ty(dtype) + ndim = len(block_shape) + shape_type = tuple_type([int32] * ndim) + # FIXME: Last dim stride should be constexpr(1) + stride_type = tuple_type(([int64] * ndim)) + block = block_type(dtype, block_shape) + if is_gluon: + from triton.experimental.gluon.language._layouts import NVMMASharedLayout + from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as gluon_tensor_descriptor_type + layout = eval(layout, dict(NVMMASharedLayout=NVMMASharedLayout)) + assert isinstance(layout, NVMMASharedLayout) + return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout) + return tensor_descriptor_type(block, shape_type, stride_type) + + if name == "constexpr": + return constexpr + + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/sunrise/python/triton/language/core.py b/third_party/sunrise/python/triton/language/core.py new file mode 100644 index 000000000..ddf0e0fd0 --- /dev/null +++ b/third_party/sunrise/python/triton/language/core.py @@ -0,0 +1,3325 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple +from dataclasses import dataclass +import builtins +from .. import knobs +from ..runtime.jit import jit, JITFunction +import inspect + +from .._C.libtriton import ir +from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth + +T = TypeVar('T') + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def must_use_result(x, s=True): + """If the result of this function is unused, throw an error.""" + if isinstance(x, str): + return (lambda fn: must_use_result(fn, x)) + x._must_use_result = s + return x + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_semantic" not in kwargs or kwargs["_semantic"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_semantic` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _semantic, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, fn if isinstance(fn, JITFunction) else wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _semantic=None): + return _semantic.to_tensor(x) + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class base_value: + """Base class of values that exist in the triton IR (i.e. not constexprs). + """ + type: base_type + + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError + + +class base_type: + + def __eq__(self, other): + raise NotImplementedError("Types must implement __eq__") + + def __ne__(self, other): + return not (self == other) + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. + """ + raise NotImplementedError + + def mangle(self) -> str: + raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}") + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + raise NotImplementedError + + +class constexpr_type(base_type): + + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == other.value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def mangle(self) -> str: + return repr(self) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + return + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return constexpr(self.value), cursor + + +class constexpr(base_value): + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + self.type = constexpr_type(value) + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def _flatten_ir(self, handles: List[ir.value]) -> None: + return + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _unwrap_if_constexpr + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _unwrap_if_constexpr(other)) + + def __radd__(self, other): + return constexpr(_unwrap_if_constexpr(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _unwrap_if_constexpr(other)) + + def __rsub__(self, other): + return constexpr(_unwrap_if_constexpr(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _unwrap_if_constexpr(other)) + + def __mod__(self, other): + return constexpr(self.value % _unwrap_if_constexpr(other)) + + def __rmul__(self, other): + return constexpr(_unwrap_if_constexpr(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _unwrap_if_constexpr(other)) + + def __rtruediv__(self, other): + return constexpr(_unwrap_if_constexpr(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _unwrap_if_constexpr(other)) + + def __rfloordiv__(self, other): + return constexpr(_unwrap_if_constexpr(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _unwrap_if_constexpr(other)) + + def __rgt__(self, other): + return constexpr(_unwrap_if_constexpr(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _unwrap_if_constexpr(other)) + + def __rge__(self, other): + return constexpr(_unwrap_if_constexpr(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _unwrap_if_constexpr(other)) + + def __rlt__(self, other): + return constexpr(_unwrap_if_constexpr(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _unwrap_if_constexpr(other)) + + def __rle__(self, other): + return constexpr(_unwrap_if_constexpr(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _unwrap_if_constexpr(other)) + + def __ne__(self, other): + return constexpr(self.value != _unwrap_if_constexpr(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _unwrap_if_constexpr(other)) + + def logical_and(self, other): + return constexpr(self.value and _unwrap_if_constexpr(other)) + + def __or__(self, other): + return constexpr(self.value | _unwrap_if_constexpr(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _unwrap_if_constexpr(other)) + + def logical_or(self, other): + return constexpr(self.value or _unwrap_if_constexpr(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_unwrap_if_constexpr(other)) + + def __rpow__(self, other): + return constexpr(_unwrap_if_constexpr(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _unwrap_if_constexpr(other)) + + def __lshift__(self, other): + return constexpr(self.value << _unwrap_if_constexpr(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + def __getitem__(self, *args): + args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args)) + return self.value.__getitem__(*args) + + +def constexpr_function(f): + """ + Wraps an arbitrary Python function so that it can be called at + compile-time on constexpr arguments in a Triton function and + returns a constexpr result. + """ + + @wraps(f) + def wrapper(*args, _semantic=None, **kwargs): + # de-constexpr arguments and discard the _semantic keyword argument: + args = [_unwrap_if_constexpr(x) for x in args] + kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()} + + # call the raw Python function f: + res = f(*args, **kwargs) + + # convert result back to a Triton constexpr: + return constexpr(res) + + # disguise the function as a Triton builtin to avoid raising an error + # that we're calling a non-JIT function from within a Triton kernel: + wrapper.__triton_builtin__ = True + wrapper.__module__ = constexpr_function.__module__ + return wrapper + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr(x) for x in o) + return o.value if isinstance(o, constexpr) else o + + +def _normalize_tuple(t): + normalized_tuple = _unwrap_if_constexpr(t) + if isinstance(normalized_tuple, (list, builtins.tuple)): + normalized_tuple = tuple(normalized_tuple) + return normalized_tuple + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +# ----------------------- +# dtype +# ----------------------- + + +class dtype(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + self.primitive_bitwidth = get_primitive_bitwidth(name) + self.itemsize = self.primitive_bitwidth // 8 + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 + + def mangle(self) -> str: + if self.is_int(): + SIGNED = dtype.SIGNEDNESS.SIGNED + prefix = 'i' if self.int_signedness == SIGNED else 'u' + return prefix + str(self.int_bitwidth) + if self.is_floating(): + return str(self) + if self.is_void(): + return 'V' + return super().mangle() + + def with_element_ty(self, element_ty: dtype): + assert not self.is_block() + return element_ty + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): + element_ty = _unwrap_if_constexpr(element_ty) + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') + self.element_ty = element_ty + self.address_space = address_space + self.const = const + self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def is_const(self): + return self.const + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const + + @property + def scalar(self): + return self + + def mangle(self) -> str: + return f"P{self.element_ty.mangle()}" + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = tuple(_unwrap_shape(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> Tuple[int]: + return self.shape + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return block_type(scalar_ty, self.shape) + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = '_'.join(map(str, self.shape)) + return f'{elt}S{shape}S' + + +class tuple_type(base_type): + + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + + def __str__(self): + return self.name + + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# tensor +# ----------------------- + + +class tensor(base_value): + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + super().__init__() + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = tuple([constexpr(s) for s in self.shape]) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _semantic=None): + return add(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __radd__(self, other, _semantic=None): + return add(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __sub__(self, other, _semantic=None): + return sub(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __rsub__(self, other, _semantic=None): + return sub(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __mul__(self, other, _semantic=None): + return mul(self, other, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __rmul__(self, other, _semantic=None): + return mul(other, self, sanitize_overflow=True, _semantic=_semantic) + + @builtin + def __truediv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.truediv(self, other) + + @builtin + def __rtruediv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.truediv(other, self) + + @builtin + def __floordiv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.floordiv(self, other) + + @builtin + def __rfloordiv__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.floordiv(other, self) + + @builtin + def __mod__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.mod(self, other) + + @builtin + def __rmod__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.mod(other, self) + + # unary operators + @builtin + def __neg__(self, _semantic=None): + return _semantic.minus(self) + + @builtin + def __invert__(self, _semantic=None): + return _semantic.invert(self) + + # bitwise operators + + @builtin + def __and__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.and_(self, other) + + @builtin + def __rand__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.and_(other, self) + + @builtin + def __or__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.or_(self, other) + + @builtin + def __ror__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.or_(other, self) + + @builtin + def __xor__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.xor_(self, other) + + @builtin + def __rxor__(self, other, _semantic=None): + other = _unwrap_if_constexpr(other) + return _semantic.xor_(other, self) + + @builtin + def __lshift__(self, other, _semantic=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + return _semantic.shl(self, other) + + @builtin + def __rlshift__(self, other, _semantic=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + return _semantic.shl(other, self) + + @builtin + def __rshift__(self, other, _semantic=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return _semantic.ashr(self, other) + else: + return _semantic.lshr(self, other) + + @builtin + def __rrshift__(self, other, _semantic=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return _semantic.ashr(other, self) + else: + return _semantic.lshr(other, self) + + # > + @builtin + def __gt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_than(self, other) + + @builtin + def __rgt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_than(other, self) + + # >= + @builtin + def __ge__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_equal(self, other) + + @builtin + def __rge__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.greater_equal(other, self) + + # < + @builtin + def __lt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_than(self, other) + + @builtin + def __rlt__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_than(other, self) + + # <= + @builtin + def __le__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_equal(self, other) + + @builtin + def __rle__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.less_equal(other, self) + + # == + @builtin + def __eq__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.equal(self, other) + + @builtin + def __req__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.equal(other, self) + + @builtin + def __ne__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.not_equal(self, other) + + @builtin + def __rne__(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.not_equal(other, self) + + @builtin + def logical_and(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.logical_and(self, other) + + @builtin + def logical_or(self, other, _semantic=None): + other = _semantic.to_tensor(other) + return _semantic.logical_or(self, other) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _semantic=None): + return _semantic.not_(self) + + @builtin + def __getitem__(self, slices, _semantic=None): + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: + slices = [slices] + if isinstance(slices, tuple): + slices = slices.values + ret = self + for dim, sl in enumerate(slices): + if _unwrap_if_constexpr(sl) is None: + ret = _semantic.expand_dims(ret, dim) + elif isinstance(sl, (builtins.slice, slice)) and all( + _unwrap_if_constexpr(arg) is None for arg in (sl.start, sl.stop, sl.step)): + pass # an unsqueeze + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None): + """ + Alias for :py:func:`tensor.cast`. + """ + return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def gather(self, indices, axis) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def reduce_or(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +class tuple(base_value): + + def __init__(self, args: Sequence, type: tuple_type = None): + self.values = [i for i in args] + + def get_type(x): + if isinstance(x, dtype): + return dtype + if isinstance(x, (int, float)): + return constexpr + return x.type + + self.type = type or tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + def __getattr__(self, name): + return self.values[self.type.fields.index(name)] + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + other = _normalize_tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + other = _normalize_tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + def _flatten_ir(self, handles: List[ir.value]): + for v in self.values: + v._flatten_ir(handles) + + def __repr__(self): + return f"({' ,'.join(repr(x) for x in self.values)})" + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + value = tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) + + def __str__(self) -> str: + # ex. "tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}" + + +class tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides + """ + + def __init__(self, handle, block_type: block_type): + """Not called by user code.""" + super().__init__() + + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return _semantic.descriptor_load(self, offsets, "", "") + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return _semantic.descriptor_store(self, value, offsets) + + @builtin + def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_add(self, value, offsets) + + @builtin + def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_min(self, value, offsets) + + @builtin + def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_max(self, value, offsets) + + @builtin + def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_and(self, value, offsets) + + @builtin + def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_or(self, value, offsets) + + @builtin + def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor: + return _semantic.descriptor_atomic_xor(self, value, offsets) + + @builtin + def gather(self, *args, _semantic=None) -> tensor: + """Gather multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "") + + @builtin + def scatter(self, value, *args, _semantic=None) -> tensor: + """Scatter multiple descriptors worth of data""" + assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}" + x_offsets = args[0] + y_offset = args[1] + return _semantic.descriptor_scatter(self, value, x_offsets, y_offset) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + super()._flatten_ir_types(builder, out) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class tensor_descriptor(tensor_descriptor_base): + """A descriptor representing a tensor in global memory. + """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + # Global shape + self.shape = tuple(shape) + self.strides = tuple(strides) + self.type = tensor_descriptor_type( + block_type, + shape_type=self.shape.type, + strides_type=self.strides.type, + ) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + +# ----------------------- +# aggregate +# ----------------------- + + +@dataclass(frozen=True) +class _aggregate_type(base_type): + """A generic base type for all Triton aggregate types. + + This class contains a reference to the original user-defined Python class + and a list of class fields with their Triton types. + """ + + base_cls: type + fields: List[Tuple[str, base_type]] + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]: + instance = self.base_cls._get_instance() + for name, ty in self.fields: + value, cursor = ty._unflatten_ir(handles, cursor) + setattr(instance, name, value) + return instance, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + for name, ty in self.fields: + ty._flatten_ir_types(builder, out) + + def mangle(self) -> str: + name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}" + fields = [ty.mangle() for (name, ty) in self.fields] + return f"{name}<{', '.join(fields)}>" + + +def _aggregate(cls): + + # Define the wrapped Triton value type. + class aggregate_value(base_value): + __triton_builtin__ = True + __triton_aggregate__ = True + + @classmethod + def _get_instance(this_cls): + return super().__new__(this_cls) + + def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs): + # Call into the user-defined constructor. + instance = this_cls._get_instance() + if isinstance(cls.__init__, JITFunction): + raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function") + extra_kwargs = {} + if "_semantic" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_semantic"] = _semantic + if "_generator" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_generator"] = _generator + cls.__init__(instance, *args, **extra_kwargs, **kwargs) + + # Require that the user-defined constructor initialized all fields. + for name in cls.__annotations__.keys(): + if not hasattr(instance, name): + raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'") + + return instance + + # Only allow setting attributes defined in the class annotations. + def __setattr__(self, name, value): + if name not in cls.__annotations__: + raise AttributeError(f"{cls.__name__} has no attribute '{name}'") + if not isinstance(value, cls.__annotations__[name]): + raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}") + super().__setattr__(name, value) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + for name in cls.__annotations__.keys(): + getattr(self, name)._flatten_ir(handles) + + @property + def type(self): + return _aggregate_type(aggregate_value, + [(name, getattr(self, name).type) for name in cls.__annotations__.keys()]) + + for (name, member) in inspect.getmembers(cls): + if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction): + if name != "__init__": + setattr(aggregate_value, name, member) + + aggregate_value.__name__ = cls.__name__ + aggregate_value.__module__ = cls.__module__ + aggregate_value.__qualname__ = cls.__qualname__ + aggregate_value.__doc__ = cls.__doc__ + + return aggregate_value + + +# ----------------------- +# SPMD Programming Model +# ----------------------- + + +@builtin +def program_id(axis, _semantic=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = _semantic.program_id(0) + # pid1 = _semantic.program_id(1) + # pid2 = _semantic.program_id(2) + # npg0 = _semantic.num_programs(0) + # npg1 = _semantic.num_programs(1) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _unwrap_if_constexpr(axis) + return _semantic.program_id(axis) + + +@builtin +def num_programs(axis, _semantic=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _unwrap_if_constexpr(axis) + return _semantic.num_programs(axis) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _semantic=None): + start = _unwrap_if_constexpr(start) + end = _unwrap_if_constexpr(end) + return _semantic.arange(start, end) + + +arange.__doc__ = f""" + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 +""" + + +def _unwrap_shape(shape): + shape = _unwrap_if_constexpr(shape) + return [_unwrap_if_constexpr(s) for s in shape] + + +def _shape_check_impl(shape): + shape = _unwrap_shape(shape) + validate_block_shape(shape) + return shape + + +@builtin +def full(shape, value, dtype, _semantic=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param value: A scalar value to fill the array with + :type value: scalar + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype + """ + shape = _shape_check_impl(shape) + value = _unwrap_if_constexpr(value) + dtype = _unwrap_if_constexpr(dtype) + return _semantic.full(shape, value, dtype) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _semantic=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return _semantic.broadcast_impl_value(input, other) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _semantic=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return _semantic.broadcast_impl_shape(input, shape) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _semantic=None): + """ + Permutes the dimensions of a tensor. + + If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation, + effectively transposing a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + dims = _unwrap_iterable(dims) + if not dims: + dims = (1, 0) + return _semantic.permute(input, dims) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _semantic=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return _semantic.permute(input, dims) + + +@builtin +def cat(input, other, can_reorder=False, _semantic=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: Tensor + :param other: The second input tensor. + :type other: Tensor + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops). + Current implementation of `cat` supports only can_reorder=True. + """ + return _semantic.cat(input, other, can_reorder) + + +@builtin +def join(a, b, _semantic=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return _semantic.join(a, b) + + +@jit +def _take_first(a, b): + return a + + +def _unsplat(x, _semantic=None, _generator=None): + """ + Convert a single-element tensor to a scalar. + """ + if len(x.shape) == 0: + return x + numel = 1 + for d in x.shape: + numel *= d + assert numel == 1, "can only unsplat single-element tensors" + if len(x.shape) >= 2: + x = _semantic.reshape(x, [1]) + x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator)) + return x + + +@_tensor_member_fn +@builtin +def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But _semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = _semantic.expand_dims(a, 0) + + out_lhs, out_rhs = _semantic.split(a) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator) + out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _semantic=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return _semantic.reshape(input, shape, can_reorder=True) + + +@_tensor_member_fn +@builtin +def item(input, _semantic=None, _generator=None): + """ + Converts a single-element tensor into a scalar. + """ + return _unsplat(input, _semantic=_semantic, _generator=_generator) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + if len(shape) == 0: + return _unsplat(input, _semantic=_semantic, _generator=_generator) + return _semantic.reshape(input, shape, can_reorder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _semantic=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _semantic.to_tensor(input) + axis = _unwrap_if_constexpr(axis) + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = _semantic.expand_dims(ret, a) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: tl.dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + """ + input = _semantic.to_tensor(input) + dtype = _unwrap_if_constexpr(dtype) + fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding) + bitcast = _unwrap_if_constexpr(bitcast) + if bitcast: + return _semantic.bitcast(input, dtype) + return _semantic.cast(input, dtype, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _semantic=None): + """ + Returns the matrix product of two blocks. + + The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions. + For three-dimensional blocks, `tl.dot` performs the batched matrix product, + where the first dimension of each block represents the batch dimension. + + :param input: The first tensor to be multiplied. + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions + input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and + (allow_tf32 or allow_tf32 is None)) else "ieee") + + input_precision = _unwrap_if_constexpr(input_precision) + out_dtype = _unwrap_if_constexpr(out_dtype) + max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc) + acc = _unwrap_if_constexpr(acc) + return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype) + + +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, + rhs_k_pack=True, out_dtype=float32, _semantic=None): + """ + Returns the matrix product of two blocks in microscaling format. + + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + + Software emulation enables targeting hardware architectures without native microscaling + operation support. Right now for such case, microscaled lhs/rhs are upcasted to + :code:`bf16` element type beforehand for dot computation, with one exception: + for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type, + the other input is also upcasted to :code:`fp16` element type instead. + This behavior is experimental and may be subject to change in the future. + + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param lhs_scale: Scale factor for lhs tensor. + :type lhs_scale: e8m0 type represented as an uint8 tensor. + :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type lhs_format: str + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type. + :param rhs_scale: Scale factor for rhs tensor. + :type rhs_scale: e8m0 type represented as an uint8 tensor. + :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}. + :type rhs_format: str + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension. + :type lhs_k_pack: bool, optional + :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension. + :type rhs_k_pack: bool, optional + """ + out_dtype = _unwrap_if_constexpr(out_dtype) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack, + rhs_k_pack, out_dtype) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _semantic=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be `None`, and + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value. + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for + cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1), + and ".cv" means don’t cache and fetch again. see + `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _unwrap_if_constexpr(mask) + other = _unwrap_if_constexpr(other) + if mask is not None: + mask = _semantic.to_tensor(mask) + if other is not None: + other = _semantic.to_tensor(other) + padding_option = _unwrap_if_constexpr(padding_option) + cache_modifier = _unwrap_if_constexpr(cache_modifier) + eviction_policy = _unwrap_if_constexpr(eviction_policy) + volatile = _unwrap_if_constexpr(volatile) + return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile) + + +@builtin +def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], + _semantic=None) -> tensor: + """Load a block of data from a tensor descriptor.""" + return desc.load(offsets, _semantic=_semantic) + + +@builtin +def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor, + _semantic=None) -> tensor: + """Store a block of data to a tensor descriptor.""" + return desc.store(offsets, value, _semantic=_semantic) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for + cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt" + stands for cache write-through, see `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"} + """ + # `value` can be constexpr + value = _semantic.to_tensor(value) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + cache_modifier = _unwrap_if_constexpr(cache_modifier) + eviction_policy = _unwrap_if_constexpr(eviction_policy) + return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order) + + +@must_use_result( + "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable." +) +@_tensor_member_fn +@builtin +def advance(base, offsets, _semantic=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return _semantic.advance(base, offsets) + + +@builtin +def make_tensor_descriptor( + base: tensor, + shape: List[tensor], + strides: List[tensor], + block_shape: List[constexpr], + _semantic=None, +) -> tensor_descriptor: + """Make a tensor descriptor object + + :param base: the base pointer of the tensor, must be 16-byte aligned + :param shape: A list of non-negative integers representing the tensor shape + :param strides: A list of tensor strides. Leading dimensions must be multiples + of 16-byte strides and the last dimension must be contiguous. + :param block_shape: The shape of block to be loaded/stored from global memory + + Notes + ***** + On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object + and loads and stores from the descriptor will be backed by the TMA hardware. + + Currently only 2-5 dimensional tensors are supported. + + Example + ******* + .. code-block:: python + + @triton.jit + def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_out_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M_BLOCK, N_BLOCK], + ) + + moffset = tl.program_id(0) * M_BLOCK + noffset = tl.program_id(1) * N_BLOCK + + value = desc.load([moffset, noffset]) + desc.store([moffset, noffset], tl.abs(value)) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + M, N = 256, 256 + x = torch.randn(M, N, device="cuda") + M_BLOCK, N_BLOCK = 32, 32 + grid = (M / M_BLOCK, N / N_BLOCK) + inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) + + """ + return _semantic.make_tensor_descriptor(base, shape, strides, block_shape) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire", + "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, + the function defaults to using "acq_rel" semantics. + :type sem: str, optional + :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation. + Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + :type scope: str, optional + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None): + cmp = _semantic.to_tensor(cmp) + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + return _semantic.atomic_cas(pointer, cmp, val, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_xchg(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_add(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_max(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_min(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_and(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_or(pointer, val, mask, sem, scope) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None): + val = _semantic.to_tensor(val) + sem = _unwrap_if_constexpr(sem) + scope = _unwrap_if_constexpr(scope) + mask = _unwrap_if_constexpr(mask) + return _semantic.atomic_xor(pointer, val, mask, sem, scope) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _semantic=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _semantic.to_tensor(condition) + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.where(condition, x, y) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.add(x, y, sanitize_overflow) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.sub(x, y, sanitize_overflow) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return _semantic.mul(x, y, sanitize_overflow) + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + y = _promote_bfloat16_to_float32(y, _semantic=_semantic) + propagate_nan = _unwrap_if_constexpr(propagate_nan) + return _semantic.minimum(x, y, propagate_nan) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + y = _promote_bfloat16_to_float32(y, _semantic=_semantic) + propagate_nan = _unwrap_if_constexpr(propagate_nan) + return _semantic.maximum(x, y, propagate_nan) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _semantic.to_tensor(x) + min = _semantic.to_tensor(min) + max = _semantic.to_tensor(max) + x = _promote_bfloat16_to_float32(x, _semantic=_semantic) + min = _promote_bfloat16_to_float32(min, _semantic=_semantic) + max = _promote_bfloat16_to_float32(max, _semantic=_semantic) + + propagate_nan = _unwrap_if_constexpr(propagate_nan) + + return _semantic.clamp(x, min, max, propagate_nan) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None, + dtype_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value + :type {return_indices_arg}: bool""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN + :type {tie_break_arg}: bool""" + if dtype_arg is not None: + docstr += f""" + :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`. + :type {dtype_arg}: tl.dtype""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int | None + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0] + + def make_combine_region(reduce_op): + param_types = [t.type.scalar for t in input] * 2 + region = reduce_op.get_region(0) + builder = _semantic.builder + with _insertion_guard(builder): + to_ir = lambda T: T.to_ir(builder) + block = builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _semantic=_semantic) + return t + + axis = _unwrap_if_constexpr(axis) + keep_dims = _unwrap_if_constexpr(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = _semantic.reduction(input, axis, make_combine_region) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _semantic=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _semantic=_semantic) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None): + axis = _unwrap_if_constexpr(axis) + n = input.shape[axis] + index = arange(0, n, _semantic=_semantic) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _semantic=_semantic) + index = broadcast_to(index, input.shape, _semantic=_semantic) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the scan should be done + :type axis: int + :param reverse: if true, the scan is performed in the reverse direction + :type reverse: bool""" + + if dtype_arg is not None: + docstr += f""" + :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`. + :type {dtype_arg}: tl.dtype""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done + :type axis: int + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param reverse: whether to apply the associative scan in the reverse direction along axis + :type reverse: bool + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0] + + def make_combine_region(scan_op): + param_types = [t.type.scalar for t in input] * 2 + region = scan_op.get_region(0) + builder = _semantic.builder + with _insertion_guard(builder): + to_ir = lambda T: T.to_ir(builder) + block = builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + builder.create_scan_ret(*handles) + + axis = _unwrap_if_constexpr(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return _semantic.associative_scan(input, axis, make_combine_region, reverse) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, mask=None, _semantic=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :type input: Tensor + :param num_bins: number of histogram bins + :type num_bins: int + :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram + :type mask: Block of `triton.int1`, optional + + """ + num_bins = _unwrap_if_constexpr(num_bins) + mask = _unwrap_if_constexpr(mask) + if mask is not None: + mask = _semantic.to_tensor(mask) + return _semantic.histogram(input, num_bins, mask) + + +@_tensor_member_fn +@builtin +def gather(src, index, axis, _semantic=None): + """Gather from a tensor along a given dimension. + + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + + """ + axis = _unwrap_if_constexpr(axis) + return _semantic.gather(src, index, axis) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_semantic=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return _semantic.debug_barrier() + + +@builtin +def multiple_of(input, values, _semantic=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _semantic=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _semantic=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return _semantic.max_constancy(input, values) + + +@builtin +def assume(cond, _semantic=None): + ''' + Allow compiler to assume the :code:`cond` is True. + ''' + return _semantic.assume(_semantic.to_tensor(cond)) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _semantic=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _semantic=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _unwrap_if_constexpr(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_semantic.to_tensor(arg)) + return _semantic.device_print(prefix, new_args, hex) + + +@builtin +def device_assert(cond, msg="", _semantic=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _unwrap_if_constexpr(msg) + return _semantic.device_assert(_semantic.to_tensor(cond), msg) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _semantic=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + `PTX `_ + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + `LLVM format `_ + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _unwrap_if_constexpr(asm) + constraints = _unwrap_if_constexpr(constraints) + pack = _unwrap_if_constexpr(pack) + is_pure = _unwrap_if_constexpr(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_semantic.to_tensor(arg) for arg in args]: + bin_op_type_checking = partial( + _semantic.binary_op_type_checking_impl, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype] + handles = [t.handle for t in dispatch_args] + builder = _semantic.builder + call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr" + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr" + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr" + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class async_task: + """ + Context manager to run code fragments asynchronously. + """ + + def __init__(self, task_ids, _builder=None): + self.task_ids = list({_unwrap_if_constexpr(tid) for tid in task_ids}) + self.builder = _builder + + def __enter__(self): + self.builder.set_async_task_ids(self.task_ids) + + def __exit__(self, exc_type, exc_value, traceback): + self.builder.unset_async_task_ids() + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + :param warp_specialize: Enable automatic warp specialization on the loop. + The compiler will attempt to partition memory, MMA, and vector + operations in the loop into separate async partitions. This will + increase the total number of warps required by the kernel. + + Note that warp specialization is only supported on Blackwell GPUs and + only works on simple matmul loops. Support for arbitrary loops will be + expanded over time. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + self.warp_specialize = warp_specialize + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _semantic): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + builder = _semantic.builder + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _semantic=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _semantic.to_tensor(dispatch_args[i]) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _semantic.builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _semantic) + + +def binary_op_type_legalization(lhs, rhs, semantic): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/sunrise/python/triton/language/extra/__init__.py b/third_party/sunrise/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..3f8c70a71 --- /dev/null +++ b/third_party/sunrise/python/triton/language/extra/__init__.py @@ -0,0 +1,26 @@ +import pkgutil +from importlib.util import module_from_spec +from sys import modules + +_backends = [] +for module_finder, module_name, is_pkg in pkgutil.iter_modules( + __path__, + prefix=__name__ + ".", +): + # skip .py files (like libdevice.py) + if not is_pkg: + continue + + # import backends (like cuda and hip) that are included during setup.py + spec = module_finder.find_spec(module_name) + if spec is None or spec.loader is None: + continue + module = module_from_spec(spec) + spec.loader.exec_module(module) + + _backends.append(module_name) + modules[module_name] = module + +__all__ = _backends + +del _backends diff --git a/third_party/sunrise/python/triton/language/extra/cuda b/third_party/sunrise/python/triton/language/extra/cuda new file mode 120000 index 000000000..fc5f8a28a --- /dev/null +++ b/third_party/sunrise/python/triton/language/extra/cuda @@ -0,0 +1 @@ +/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/nvidia/language/cuda \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/extra/hip b/third_party/sunrise/python/triton/language/extra/hip new file mode 120000 index 000000000..dbeb20d81 --- /dev/null +++ b/third_party/sunrise/python/triton/language/extra/hip @@ -0,0 +1 @@ +/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/amd/language/hip \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/extra/libdevice.py b/third_party/sunrise/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..76627035d --- /dev/null +++ b/third_party/sunrise/python/triton/language/extra/libdevice.py @@ -0,0 +1,786 @@ +def clz(arg0): + ... + + +def popc(arg0): + ... + + +def byte_perm(arg0, arg1, arg2): + ... + + +def mulhi(arg0, arg1): + ... + + +def mul24(arg0, arg1): + ... + + +def brev(arg0): + ... + + +def sad(arg0, arg1, arg2): + ... + + +def abs(arg0): + ... + + +def floor(arg0): + ... + + +def rcp64h(arg0): + ... + + +def rsqrt(arg0): + ... + + +def ceil(arg0): + ... + + +def trunc(arg0): + ... + + +def exp2(arg0): + ... + + +def saturatef(arg0): + ... + + +def fma_rn(arg0, arg1, arg2): + ... + + +def fma_rz(arg0, arg1, arg2): + ... + + +def fma_rd(arg0, arg1, arg2): + ... + + +def fma_ru(arg0, arg1, arg2): + ... + + +def fast_dividef(arg0, arg1): + ... + + +def div_rn(arg0, arg1): + ... + + +def div_rz(arg0, arg1): + ... + + +def div_rd(arg0, arg1): + ... + + +def div_ru(arg0, arg1): + ... + + +def rcp_rn(arg0): + ... + + +def rcp_rz(arg0): + ... + + +def rcp_rd(arg0): + ... + + +def rcp_ru(arg0): + ... + + +def sqrt_rn(arg0): + ... + + +def sqrt_rz(arg0): + ... + + +def sqrt_rd(arg0): + ... + + +def sqrt_ru(arg0): + ... + + +def sqrt(arg0): + ... + + +def add_rn(arg0, arg1): + ... + + +def add_rz(arg0, arg1): + ... + + +def add_rd(arg0, arg1): + ... + + +def add_ru(arg0, arg1): + ... + + +def mul_rn(arg0, arg1): + ... + + +def mul_rz(arg0, arg1): + ... + + +def mul_rd(arg0, arg1): + ... + + +def mul_ru(arg0, arg1): + ... + + +def double2float_rn(arg0): + ... + + +def double2float_rz(arg0): + ... + + +def double2float_rd(arg0): + ... + + +def double2float_ru(arg0): + ... + + +def double2int_rn(arg0): + ... + + +def double2int_rz(arg0): + ... + + +def double2int_rd(arg0): + ... + + +def double2int_ru(arg0): + ... + + +def double2uint_rn(arg0): + ... + + +def double2uint_rz(arg0): + ... + + +def double2uint_rd(arg0): + ... + + +def double2uint_ru(arg0): + ... + + +def int2double_rn(arg0): + ... + + +def uint2double_rn(arg0): + ... + + +def float2int_rn(arg0): + ... + + +def float2int_rz(arg0): + ... + + +def float2int_rd(arg0): + ... + + +def float2int_ru(arg0): + ... + + +def float2uint_rn(arg0): + ... + + +def float2uint_rz(arg0): + ... + + +def float2uint_rd(arg0): + ... + + +def float2uint_ru(arg0): + ... + + +def int2float_rn(arg0): + ... + + +def int2float_rz(arg0): + ... + + +def int2float_rd(arg0): + ... + + +def int2float_ru(arg0): + ... + + +def uint2float_rn(arg0): + ... + + +def uint2float_rz(arg0): + ... + + +def uint2float_rd(arg0): + ... + + +def uint2float_ru(arg0): + ... + + +def hiloint2double(arg0, arg1): + ... + + +def double2loint(arg0): + ... + + +def double2hiint(arg0): + ... + + +def float2ll_rn(arg0): + ... + + +def float2ll_rz(arg0): + ... + + +def float2ll_rd(arg0): + ... + + +def float2ll_ru(arg0): + ... + + +def float2ull_rn(arg0): + ... + + +def float2ull_rz(arg0): + ... + + +def float2ull_rd(arg0): + ... + + +def float2ull_ru(arg0): + ... + + +def double2ll_rn(arg0): + ... + + +def double2ll_rz(arg0): + ... + + +def double2ll_rd(arg0): + ... + + +def double2ll_ru(arg0): + ... + + +def double2ull_rn(arg0): + ... + + +def double2ull_rz(arg0): + ... + + +def double2ull_rd(arg0): + ... + + +def double2ull_ru(arg0): + ... + + +def ll2float_rn(arg0): + ... + + +def ll2float_rz(arg0): + ... + + +def ll2float_rd(arg0): + ... + + +def ll2float_ru(arg0): + ... + + +def ull2float_rn(arg0): + ... + + +def ull2float_rz(arg0): + ... + + +def ull2float_rd(arg0): + ... + + +def ull2float_ru(arg0): + ... + + +def ll2double_rn(arg0): + ... + + +def ll2double_rz(arg0): + ... + + +def ll2double_rd(arg0): + ... + + +def ll2double_ru(arg0): + ... + + +def ull2double_rn(arg0): + ... + + +def ull2double_rz(arg0): + ... + + +def ull2double_rd(arg0): + ... + + +def ull2double_ru(arg0): + ... + + +def int_as_float(arg0): + ... + + +def float_as_int(arg0): + ... + + +def uint_as_float(arg0): + ... + + +def float_as_uint(arg0): + ... + + +def longlong_as_double(arg0): + ... + + +def double_as_longlong(arg0): + ... + + +def fast_sinf(arg0): + ... + + +def fast_cosf(arg0): + ... + + +def fast_log2f(arg0): + ... + + +def fast_logf(arg0): + ... + + +def fast_expf(arg0): + ... + + +def fast_tanf(arg0): + ... + + +def fast_exp10f(arg0): + ... + + +def fast_log10f(arg0): + ... + + +def fast_powf(arg0, arg1): + ... + + +def hadd(arg0, arg1): + ... + + +def rhadd(arg0, arg1): + ... + + +def sub_rn(arg0, arg1): + ... + + +def sub_rz(arg0, arg1): + ... + + +def sub_rd(arg0, arg1): + ... + + +def sub_ru(arg0, arg1): + ... + + +def rsqrt_rn(arg0): + ... + + +def ffs(arg0): + ... + + +def rint(arg0): + ... + + +def llrint(arg0): + ... + + +def nearbyint(arg0): + ... + + +def isnan(arg0): + ... + + +def signbit(arg0): + ... + + +def copysign(arg0, arg1): + ... + + +def finitef(arg0): + ... + + +def isinf(arg0): + ... + + +def nextafter(arg0, arg1): + ... + + +def sin(arg0): + ... + + +def cos(arg0): + ... + + +def sinpi(arg0): + ... + + +def cospi(arg0): + ... + + +def tan(arg0): + ... + + +def log2(arg0): + ... + + +def exp(arg0): + ... + + +def exp10(arg0): + ... + + +def cosh(arg0): + ... + + +def sinh(arg0): + ... + + +def tanh(arg0): + ... + + +def atan2(arg0, arg1): + ... + + +def atan(arg0): + ... + + +def asin(arg0): + ... + + +def acos(arg0): + ... + + +def log(arg0): + ... + + +def log10(arg0): + ... + + +def log1p(arg0): + ... + + +def acosh(arg0): + ... + + +def asinh(arg0): + ... + + +def atanh(arg0): + ... + + +def expm1(arg0): + ... + + +def hypot(arg0, arg1): + ... + + +def rhypot(arg0, arg1): + ... + + +def norm3d(arg0, arg1, arg2): + ... + + +def rnorm3d(arg0, arg1, arg2): + ... + + +def norm4d(arg0, arg1, arg2, arg3): + ... + + +def rnorm4d(arg0, arg1, arg2, arg3): + ... + + +def cbrt(arg0): + ... + + +def rcbrt(arg0): + ... + + +def j0(arg0): + ... + + +def j1(arg0): + ... + + +def y0(arg0): + ... + + +def y1(arg0): + ... + + +def yn(arg0, arg1): + ... + + +def jn(arg0, arg1): + ... + + +def cyl_bessel_i0(arg0): + ... + + +def cyl_bessel_i1(arg0): + ... + + +def erf(arg0): + ... + + +def erfinv(arg0): + ... + + +def erfc(arg0): + ... + + +def erfcx(arg0): + ... + + +def erfcinv(arg0): + ... + + +def normcdfinv(arg0): + ... + + +def normcdf(arg0): + ... + + +def lgamma(arg0): + ... + + +def ldexp(arg0, arg1): + ... + + +def scalbn(arg0, arg1): + ... + + +def fmod(arg0, arg1): + ... + + +def remainder(arg0, arg1): + ... + + +def fma(arg0, arg1, arg2): + ... + + +def pow(arg0, arg1): + ... + + +def tgamma(arg0): + ... + + +def round(arg0): + ... + + +def llround(arg0): + ... + + +def fdim(arg0, arg1): + ... + + +def ilogb(arg0): + ... + + +def logb(arg0): + ... + + +def isfinited(arg0): + ... diff --git a/third_party/sunrise/python/triton/language/extra/tang b/third_party/sunrise/python/triton/language/extra/tang new file mode 120000 index 000000000..16c8cfeaa --- /dev/null +++ b/third_party/sunrise/python/triton/language/extra/tang @@ -0,0 +1 @@ +/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/sunrise/language/tang \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/math.py b/third_party/sunrise/python/triton/language/math.py new file mode 100644 index 000000000..582cd876c --- /dev/null +++ b/third_party/sunrise/python/triton/language/math.py @@ -0,0 +1,249 @@ +from . import core +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x, y = core.binary_op_type_legalization(x, y, _semantic) + return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type) + + +@core._tensor_member_fn +@core.builtin +@_add_math_1arg_docstr("absolute value") +def abs(x, _semantic=None): + x = _semantic.to_tensor(x) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic) + return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_semantic.builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_semantic.builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _semantic=None): + ieee_rounding = core._unwrap_if_constexpr(ieee_rounding) + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + return _semantic.fdiv(x, y, ieee_rounding) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") +def div_rn(x, y, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + x, y = core.binary_op_type_legalization(x, y, _semantic) + return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _semantic=None): + x = _semantic.to_tensor(x) + return core.tensor(_semantic.builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _semantic=None): + x = _semantic.to_tensor(x) + y = _semantic.to_tensor(y) + z = _semantic.to_tensor(z) + x, y = core.binary_op_type_legalization(x, y, _semantic) + z, x = core.binary_op_type_legalization(z, x, _semantic) + z, y = core.binary_op_type_legalization(z, y, _semantic) + return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/sunrise/python/triton/language/random.py b/third_party/sunrise/python/triton/language/random.py new file mode 100644 index 000000000..1f6c192dd --- /dev/null +++ b/third_party/sunrise/python/triton/language/random.py @@ -0,0 +1,218 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = tl.mul(B, _c2, sanitize_overflow=False) + c3 = tl.mul(A, _c0, sanitize_overflow=False) + # raise key + k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False) + k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False) + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + tl.static_assert(seed.dtype.is_int()) + seed = seed.to(tl.uint64) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + + offset_lo = offset.to(tl.uint32) + _0 = offset_lo * 0 + + if tl.constexpr(offset.dtype.primitive_bitwidth) > 32: + offset_hi = (offset >> 32).to(tl.uint32) + else: + offset_hi = _0 + + return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/sunrise/python/triton/language/semantic.py b/third_party/sunrise/python/triton/language/semantic.py new file mode 100644 index 000000000..0da22adef --- /dev/null +++ b/third_party/sunrise/python/triton/language/semantic.py @@ -0,0 +1,1886 @@ +from __future__ import annotations # remove after python 3.11 +import warnings + +from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type +import numbers + +from triton.runtime import driver + +from .._C.libtriton import ir +from . import core as tl + +T = TypeVar('T') +TensorTy = TypeVar('TensorTy') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +class TritonSemantic(Generic[TensorTy]): + tensor: Type[TensorTy] = tl.tensor + lang = tl + + builder: ir.builder + + def __init__(self, builder): + self.builder = builder + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + def program_id(self, axis: int) -> TensorTy: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return self.tensor(self.builder.create_get_program_id(axis), tl.int32) + + def num_programs(self, axis: int) -> TensorTy: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return self.tensor(self.builder.create_get_num_programs(axis), tl.int32) + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, + div_or_mod: bool) -> tl.dtype: + # 0) For scalars we follow semantics similar to PyTorch, namely: + # - If the scalar is of a lower or equal kind (bool < uint < int < fp), + # it doesn't participate in the promotion + if a_is_scalar != b_is_scalar: + scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) + if scalar_ty.kind().value <= tensor_ty.kind().value: + # Upcast because of 3) and 4) below! + if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): + return tl.float32 + return tensor_ty + + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() and b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + else: + return tl.bfloat16 + if a_ty.is_bf16() or b_ty.is_bf16(): + return tl.float32 + # 5) return fp16 if operands are different fp8 + if a_ty.is_fp8() and b_ty.is_fp8(): + return a_ty if a_ty == b_ty else tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 6 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return self.integer_promote_impl(a_ty, b_ty) + + def to_tensor(self, x, check_type: bool = True): + if isinstance(x, bool): + return self.tensor(self.builder.get_int1(x), tl.int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + dtype = tl.int32 + elif 2**31 <= x < 2**32: + dtype = tl.uint32 + elif -2**63 <= x < 2**63: + dtype = tl.int64 + elif 2**63 <= x < 2**64: + dtype = tl.uint64 + else: + raise ValueError(f'Nonrepresentable integer {x}.') + return self.scalar_constant(x, dtype=dtype) + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + dtype = tl.float32 + else: + dtype = tl.float64 + return self.scalar_constant(x, dtype=dtype) + + elif isinstance(x, tl.constexpr): + return self.to_tensor(x.value) + elif isinstance(x, self.tensor): + return x + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number, + allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[TensorTy, TensorTy]: + lhs_is_scalar = isinstance(lhs, numbers.Number) + rhs_is_scalar = isinstance(rhs, numbers.Number) + if lhs_is_scalar: + lhs_scalar = lhs + lhs = self.to_tensor(lhs) + if rhs_is_scalar: + rhs_scalar = rhs + rhs = self.to_tensor(rhs) + + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) + if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() + or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): + raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. " + "Perform a explicit cast on one of them.") + if ret_sca_ty.is_int(): + if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= + ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}") + if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= + ret_sca_ty.get_int_max_value()): + raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}") + lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty) + rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty) + + # implicit broadcasting + lhs, rhs = self.broadcast_impl_value(lhs, rhs) + return lhs, rhs + + def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = self.cast(lhs, tl.int64) + rhs = self.cast(rhs, tl.int64) + ret = binary_op(lhs, rhs, False) + max_value = lhs_sca_ty.get_int_max_value() + max_value = self.scalar_constant(max_value, tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = self.scalar_constant(min_value, tl.int64) + cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value)) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + self.device_assert(cond, msg) + + def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + other_handle = other.handle + if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64: + # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive + i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder) + other_handle = self.builder.create_int_cast(other.handle, i64_ty, False) + return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.add) + return self.tensor(self.builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return self.add(input, self.minus(other), sanitize_overflow=False) + # float - float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.sub) + return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, + sanitize_overflow: bool) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type) + # int * int + elif scalar_ty.is_int(): + if sanitize_overflow: + self.binary_op_sanitize_overflow_impl(input, other, self.mul) + return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = self.cast(other, input_scalar_ty) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = self.cast(input, other_scalar_ty) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = self.cast(input, tl.float32) + other = self.cast(other, tl.float32) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = self.cast(other, input_scalar_ty) + else: + input = self.cast(input, other_scalar_ty) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type) + + def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = self.cast(input, ret_ty) + other = self.cast(other, ret_ty) + if ret_ty.is_int_signed(): + return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type) + else: + return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True) + ret = self.builder.create_fdiv(input.handle, other.handle) + return self.tensor(ret, input.type) + + def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type) + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type) + else: + return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + +############## +# other arithmetic ops +############## + + def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): + x, y = self.binary_op_type_checking_impl(x, y) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan): + x, y = self.binary_op_type_checking_impl(x, y) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan): + min, max = self.binary_op_type_checking_impl(min, max) + x, min = self.binary_op_type_checking_impl(x, min) + x, max = self.binary_op_type_checking_impl(x, max) + + dtype = x.dtype + if dtype.is_floating(): + return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + +############## +# bitwise ops +############## + + def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]: + input, other = self.binary_op_type_checking_impl(input, other) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = self.cast(input, ret_sca_ty) + if ret_sca_ty != other_sca_ty: + other = self.cast(other, ret_sca_ty) + return input, other + + def and_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_and(input.handle, other.handle), input.type) + + def or_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_or(input.handle, other.handle), input.type) + + def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type) + + def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy: + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + if not other.type.is_int1(): + other = self.bitcast(other, tl.int1) + return self.and_(input, other) + + def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy: + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + if not other.type.is_int1(): + other = self.bitcast(other, tl.int1) + return self.or_(input, other) + + def not_(self, input: TensorTy): + if not input.type.is_int1(): + input = self.bitcast(input, tl.int1) + return self.invert(input) + + def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type) + + def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type) + + def shl(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.bitwise_op_type_checking_impl(input, other) + return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type) + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + def plus(self, input: TensorTy) -> TensorTy: + return input + + def minus(self, input: TensorTy) -> TensorTy: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) + return self.sub(_0, input, True) + + def invert(self, input: TensorTy) -> TensorTy: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty) + return self.xor_(input, _1) + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// + + def _bool_like(self, v: TensorTy) -> tl.block_type: + return v.type.with_element_ty(tl.int1) + + def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input)) + else: + return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input)) + # == int + elif scalar_ty.is_int(): + return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy: + input, other = self.binary_op_type_checking_impl(input, other) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input)) + # == int + elif scalar_ty.is_int(): + return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + if ret_ty is None: + ret_ty = tl.block_type(tl.int32, shape) + ret_ty_ir = ret_ty.to_ir(self.builder) + return self.tensor(self.builder.create_make_range(ret_ty_ir, start, end), ret_ty) + + def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = self.builder.get_null_value(dtype.to_ir(self.builder)) + else: + get_value_fn = getattr(self.builder, f"get_{dtype.name}") + value = get_value_fn(value) + return self.tensor(value, dtype) + + def make_scalar(self, value, dtype: tl.dtype) -> TensorTy: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + return self.cast(value, dtype) + # scalar + return self.scalar_constant(value, dtype) + + def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy: + return self.splat(self.make_scalar(value, dtype), shape) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + def splat(self, value: TensorTy, shape: List[int]) -> TensorTy: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty) + + def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return self.tensor(self.builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: + dst_shape = [tl._unwrap_if_constexpr(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return self.splat(input, shape=dst_shape) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return self.tensor(self.builder.create_expand_dims(input.handle, axis), ret_ty) + + def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) + + def join(self, a: TensorTy, b: TensorTy) -> TensorTy: + a, b = self.broadcast_impl_value(a, b) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = self.expand_dims(a, 0) + b = self.expand_dims(b, 0) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = self.tensor(self.builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = self.reshape(ret, [2], can_reorder=False) + + return ret + + def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: + assert (len(a.shape) > 0) + assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = self.builder.create_split(a.handle) + return ( + self.tensor(outLHS, ret_type), + self.tensor(outRHS, ret_type), + ) + + def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return self.tensor(self.builder.create_trans(input.handle, dims), ret_type) + + def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy: + if not input.type.is_block(): + return self.splat(input, shape) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return self.tensor(self.builder.create_broadcast(input.handle, shape), ret_ty) + + def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar) + rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar) + lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = self.tensor(self.builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = self.tensor(self.builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + def _str_to_rounding_mode(self, rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy: + src_ty = input.type + if src_ty.is_block(): + dst_ty = src_ty.with_element_ty(dst_ty.scalar) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return self.cast(input, dst_ty) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy: + src_ty = input.type + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + if src_ty.is_block(): + dst_ty = src_ty.with_element_ty(dst_sca_ty) + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert self.builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _semantic=self) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return self.tensor( + self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return self.cast(self.cast(input, tl.float32), dst_sca_ty) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(self.builder) + _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) + return self.not_equal(input, _0) + else: + return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend), + dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(self.builder) + _0 = self.tensor(self.builder.get_null_value(ty), input.dtype) + return self.not_equal(input, _0) + elif dst_sca_ty.is_int_signed(): + return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + else: + return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + else: + return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + if bitwidth == 1: + return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64)) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + def _str_to_load_cache_modifier(self, cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + def _str_to_store_cache_modifier(self, cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + def _str_to_eviction_policy(self, eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + def _str_to_padding_option(self, padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + def _str_to_sem(self, sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + def _str_to_scope(self, scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + def _canonicalize_boundary_check(self, boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return self.tensor( + self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), + dst_ty) + + def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) + if other is not None: + other = self.broadcast_impl_shape(other, ptr.type.get_block_shapes()) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + is_bool = elt_ty == tl.int1 + if is_bool: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = self.cast(ptr, ptr_ty) + + # Cast `other` into `elt_ty` type + if other is not None: + other = self.cast(other, elt_ty) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + dst_ty = ptr.type.with_element_ty(elt_ty) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + ret = self.tensor( + self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, + eviction, is_volatile), dst_ty) + if is_bool: + ret = self.cast(ret, tl.int1) + return ret + + def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy: + # Cache, eviction and padding options + cache = self._str_to_load_cache_modifier(cache_modifier) + eviction = self._str_to_eviction_policy(eviction_policy) + padding = self._str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile) + + def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str, + eviction_policy: str) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = self._convert_to_ir_values(offsets, require_i64=False) + x = self.builder.create_descriptor_load(desc.handle, offsets, self._str_to_load_cache_modifier(cache_modifier), + self._str_to_eviction_policy(eviction_policy)) + return self.tensor(x, desc.block_type) + + def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.ADD + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def _has_native_tma(self, ): + target = driver.active.get_current_target() + return (target.backend == "cuda" and target.arch >= 90) + + def _descriptor_atomic_min_max_supported(self, dtype): + assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype" + if dtype in {tl.float16, tl.bfloat16}: + assert self._has_native_tma(), "16-bit float types require native tma support" + + def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + self._descriptor_atomic_min_max_supported(desc.dtype) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.MIN + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + self._descriptor_atomic_min_max_supported(desc.dtype) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.MAX + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.AND + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.OR + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: + self.validate_store_like(desc, value, offsets) + assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype" + offsets = self._convert_to_ir_values(offsets, require_i64=False) + kind = ir.DESCRIPTOR_REDUCE_KIND.XOR + return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) + + def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + assert cache_modifier == "", "cache modifier is not supported yet" + assert eviction_policy == "", "eviction policy is not supported yet" + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]]) + y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] + x = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(self.builder)) + return self.tensor(x, type) + + def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy: + assert isinstance(desc, tl.tensor_descriptor_base) + + # Validate descriptor. + assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" + assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}" + + # Validate offsets. + assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}" + + # Validate minimum block size. + assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}" + dtype = desc.dtype + min_cols = 32 // dtype.primitive_bitwidth * 8 + assert desc.block_shape[ + 1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}" + + y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0] + self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset) + return self.tensor(None, tl.void) + + def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = self.broadcast_impl_shape(val, block_shape) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = self.cast(val, elt_ty) + + # Build IR + return self.tensor( + self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void) + + def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes()) + if mask is not None: + mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = self.cast(ptr, ptr_ty) + + # Cast to target data type + val = self.cast(val, elt_ty) + + # Build IR + if mask is None: + return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), + tl.void) + + def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str, + eviction_policy: str) -> TensorTy: + # Cache and eviction options + cache = self._str_to_store_cache_modifier(cache_modifier) + eviction = self._str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction) + +######### +# atomic +######### + + def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy: + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, + op: str) -> Tuple[TensorTy, TensorTy, TensorTy]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty is tl.bfloat16 and op != 'add': + raise ValueError("atomic_" + op + " does not support bf16") + if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes()) + if val is not None: + val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes()) + val = self.cast(val, ptr.type.scalar.element_ty) + if mask is None: + mask_ir = self.builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ty = ptr.type.with_element_ty(tl.int1) + mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir) + mask = self.tensor(mask_ir, mask_ty) + return ptr, val, mask + + def _signbit(self, x: TensorTy) -> TensorTy: + bitwidth = x.dtype.primitive_bitwidth + idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False) + ix = self.bitcast(x, idtype) + signbit = self.lshr(ix, bitwidth - 1) + return self.cast(signbit, tl.int1) + + def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + else: + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = self.bitcast(val, i_type) + i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = self.bitcast(val, ui_type) + ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) + neg = self._signbit(val) + pos = self.not_(neg) + pos_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + self.and_(mask, pos).handle, sem, scope), i_val.type) + neg_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + self.and_(mask, neg).handle, sem, scope), ui_val.type) + ret = self.where(pos, pos_ret, neg_ret) + return self.bitcast(ret, sca_ty) + + def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + else: + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = self.bitcast(val, i_type) + i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1)) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = self.bitcast(val, ui_type) + ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1)) + neg = self._signbit(val) + pos = self.not_(neg) + pos_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + self.and_(mask, pos).handle, sem, scope), i_val.type) + neg_ret = self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + self.and_(mask, neg).handle, sem, scope), ui_ptr.type) + ret = self.where(pos, pos_ret, neg_ret) + return self.bitcast(ret, sca_ty) + + def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy: + ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg') + sem = self._str_to_sem(sem) + scope = self._str_to_scope(scope) + return self.tensor( + self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + def _str_to_dot_input_precision(self, input_precision): + assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], + max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy: + assert lhs.type.is_block() and rhs.type.is_block() + + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes: + warnings.warn( + "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release" + ) + # We upcast because there's no fp8e4b15 type in MLIR + lhs = self.cast(lhs, tl.float16) + rhs = self.cast(rhs, tl.float16) + + if input_precision is None: + input_precision = self.builder.options.default_dot_input_precision + + input_precision = self._str_to_dot_input_precision(input_precision) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert self.builder.codegen_fns.get( + "min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." + min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = self.builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`" + ) + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = self.builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + else: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") + + return self.tensor( + self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) + + def _str_to_fp_type(self, float_format: str): + ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None) + if ty_enum is None: + raise ValueError(f"Invalid float format: {float_format}.") + return ty_enum + + def _bitcast_to_fp_type(self, val: TensorTy, float_format: str): + """ + If float_format is subbyte, make sure it's packed as uint8 and return it. + Otherwise, return a tensor (perhaps bitcasting) of the specified float format. + """ + triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": + tl.float16}.get(float_format) + if triton_ty is None: + assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" + assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}" + return val + if val.dtype == triton_ty: + return val + else: + unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format] + assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}" + return self.bitcast(val, triton_ty) + + def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy, + rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool, + lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy: + assert lhs.type.is_block() and rhs.type.is_block() + #TODO: validate types. + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format: str = lhs_format.value + rhs_format: str = rhs_format.value + lhs_format_enum = self._str_to_fp_type(lhs_format) + rhs_format_enum = self._str_to_fp_type(rhs_format) + allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"} + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + lhs = self._bitcast_to_fp_type(lhs, lhs_format) + rhs = self._bitcast_to_fp_type(rhs, rhs_format) + + assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K" + M, K_LHS = lhs.type.shape[-2:] + K_RHS, N = rhs.type.shape[-2:] + PACKED_A = 2 if lhs_format == "e2m1" else 1 + PACKED_B = 2 if rhs_format == "e2m1" else 1 + PACKED_A_DIM = PACKED_A * K_LHS if lhs_k_pack else K_LHS + PACKED_B_DIM = PACKED_B * K_RHS if rhs_k_pack else K_RHS + assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + B = lhs.type.shape[0] if lhs_rank == 3 else None + if not lhs_k_pack: + M = M * PACKED_A + if not rhs_k_pack: + N = N * PACKED_B + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = self.builder.get_fp32(0) + if acc is None: + acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + return self.tensor( + self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy: + if condition.dtype != tl.int1: + warnings.warn( + f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" + ) + condition = self.cast(condition, tl.int1) + x, y = self.binary_op_type_checking_impl(x, y, True, True) + # x, y are broadcasted + if condition.type.is_block(): + condition, x = self.broadcast_impl_value(condition, x) + x, y = self.broadcast_impl_value(x, y) + else: + condition, _ = self.broadcast_impl_value(condition, x) + ret_ty = x.type + return self.tensor(self.builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + def wrap_tensor(self, x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return self.tensor(x, res_ty) + + def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]: + if axis is None: + inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + assert reduce_op.verify() + + return tuple( + self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn, + reverse: bool) -> Tuple[TensorTy, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + assert scan_op.verify() + + return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + + def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: + assert index.dtype.is_int(), "index must be an integer tensor" + + rank = len(src.type.shape) + assert len(index.type.shape) == rank, "source and index tensors must have the same rank" + + assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})" + if axis < 0: + axis += rank + + for d in range(rank): + if d == axis: + continue + assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim" + + gather = self.builder.create_gather(src.handle, index.handle, axis) + return self.wrap_tensor(gather, src.type.scalar, index.type.shape) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + if mask is not None: + mask = self.broadcast_impl_shape(mask, input.shape) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + mask = mask.handle + return self.tensor(self.builder.create_histogram(input.handle, num_bins, mask), + tl.block_type(tl.int32, [num_bins])) + + def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + def debug_barrier(self) -> TensorTy: + return self.tensor(self.builder.create_barrier(), tl.void) + + def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + is_signed = [arg.dtype.is_int_signed() for arg in args] + return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void) + + def device_assert(self, cond: TensorTy, msg: str) -> TensorTy: + if not self.builder.options.debug: + return + return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void) + + def assume(self, cond) -> TensorTy: + return self.tensor(self.builder.create_assume(cond.handle), tl.void) + + def _convert_elem_to_ir_value(self, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if isinstance(elem.value, bool): + return self.builder.get_int1(elem.value) + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return self.builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return self.builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(), + elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + def _convert_to_ir_values(self, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [self._convert_elem_to_ir_value(elem, require_i64) for elem in list_like] + return [self._convert_elem_to_ir_value(list_like, require_i64)] + + def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = self._convert_to_ir_values(shape) + strides = self._convert_to_ir_values(strides) + offsets = self._convert_to_ir_values(offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space)) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + def advance(self, base: TensorTy, offsets) -> TensorTy: + # Convert dynamic offsets to IR values + offsets = self._convert_to_ir_values(offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return self.tensor(self.builder.create_advance(base.handle, offsets), base.type) + + def make_tensor_descriptor( + self, + base: TensorTy, + shape: List[TensorTy], + strides: List[TensorTy], + block_shape: List[tl.constexpr], + ) -> tl.tensor_descriptor: + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + assert isinstance(base.dtype, tl.pointer_type) + elem_size = base.dtype.element_ty.primitive_bitwidth // 8 + contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + strides[-1] = tl._unwrap_if_constexpr(strides[-1]) + if strides[-1] != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") + + shape = [self.make_scalar(x, tl.int32) for x in shape] + strides = [self.make_scalar(x, tl.int64) for x in strides] + + # Check whether `block_shape` is static + block_shape = tl._unwrap_shape(block_shape) + + assert isinstance(base.type, tl.pointer_type) + type = tl.block_type(base.type.element_ty, block_shape) + base_handle = base.handle + is_signed_int = base.type.element_ty.is_int_signed() + + handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], + [s.handle for s in strides], block_shape, is_signed_int) + return tl.tensor_descriptor(handle, shape, strides, type) diff --git a/third_party/sunrise/python/triton/language/standard.py b/third_party/sunrise/python/triton/language/standard.py new file mode 100644 index 000000000..faac30966 --- /dev/null +++ b/third_party/sunrise/python/triton/language/standard.py @@ -0,0 +1,535 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities + + +def _log2(i: core.constexpr): + log2 = 0 + n = core.constexpr(i).value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :type div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, dim=None, keep_dims=False, ieee_rounding=False): + if dim is None: + _dim: core.constexpr = 0 + else: + _dim: core.constexpr = dim + z = x - max(x, _dim, keep_dims=keep_dims) + num = math.exp(z) + den = sum(num, _dim, keep_dims=keep_dims) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x, can_reorder=False): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=can_reorder) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms the indices of a row-major `size_i * size_j` matrix into + the indices of a column-major matrix for each group of `size_g` rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # linear index with respect to the first element in this group + ij = ij % size_gj + # new row and column indices + new_i = off_i + ij % size_g + new_j = ij // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Returns a tensor of zeros with the same shape and type as a given tensor. + + :param input: input tensor + :type input: Tensor + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr): + dtype = core._unwrap_if_constexpr(dtype) + if dtype is not None: + return dtype + + # For integer bitwidths less than 32, pick int32 with the same sign to + # avoid overflow. + out_dtype = None + if in_dtype.is_int_signed(): + out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None + elif in_dtype.is_int_unsigned(): + out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None + return out_dtype + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum", dtype_arg="dtype") +def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None): + # Pick a default dtype for the reduction if one was not specified. + out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype) + + if out_dtype is not None: + input = input.to(out_dtype) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False): + core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers") + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims) + + +# or reduction + + +@jit +def _or_combine(x, y): + return x | y + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("reduce_of") +def reduce_or(input, axis, keep_dims=False): + core.static_assert(input.type.scalar.is_int(), "reduce_of only supported for integers") + return core.reduce(input, axis, _or_combine, keep_dims=keep_dims) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum", dtype_arg="dtype") +def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None): + # todo rename this to a generic function name + + input = core._promote_bfloat16_to_float32(input) + out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype) + + if out_dtype is not None: + input = input.to(out_dtype) + + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _indicator(n_dims: core.constexpr, j: core.constexpr): + ar = core.arange(0, 2) + ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j) + return ar + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr): + # compare-and-swap on the ith *innermost* dimension + n_dims: core.constexpr = _log2(x.numel) + + # flip along middle dimension (the bitwise XORs will be optimised away): + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ix = x.to(idtype, bitcast=True) + iy = ix ^ xor_sum(ix, n_dims - 1 - i, True) + y = iy.to(x.dtype, bitcast=True) + + # determines whether we are in the right (rather than left) position along the axis: + is_right = _indicator(n_dims, i) + + # conditional swap: + ret = core.where((x > y) != (flip ^ is_right), y, x) + return ret + + +@jit +def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + flip = _indicator(_log2(x.numel), stage) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, stage - 1 - i) + return x + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + h = core.reshape(x, [2] * _log2(x.numel)) + h = _bitonic_merge_hypercube(h, stage, order) + x = core.reshape(h, x.shape) + return x + + +@jit +def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + """ + Sorts a tensor along a specified dimension. + + :param x: The input tensor to be sorted. + :type x: Tensor + :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported. + :type dim: int, optional + :param k: the number of top elements to select. If none, assume k = x.shape[dim] + :type k: int, optional + :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order. + :type descending: bool, optional + """ + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + + log_n: core.constexpr = _log2(x.shape[_dim]) + log_k: core.constexpr = log_n if k is None else _log2(k) + + n_dims: core.constexpr = _log2(x.numel) + + # reshape to hypercube: + h = core.reshape(x, [2] * n_dims) + + # run first log_k bitonic sort iterations: + for i in core.static_range(1, log_k + 1): + h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending) + + # select top k elements using bitonic top-k + # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf + for i in core.static_range(log_k + 1, log_n + 1): + h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k)) + h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending) + + # reshape back: + x = core.reshape(h, x.shape[:-1] + [2**log_k]) + return x + + +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + return sort_impl(x, dim=dim, descending=descending) + + +@jit +def topk(x, k: core.constexpr, dim: core.constexpr = None): + return sort_impl(x, k=k, dim=dim, descending=True) + + +@jit +def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + n_dims: core.constexpr = _log2(x.shape[-1]) + return _bitonic_merge(x, n_dims, descending, n_dims) + + +def _get_flip_dim(dim, shape): + dim = core._unwrap_if_constexpr(dim) + shape = core._unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index + dim += len(shape) + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along + :type dim: int + """ + if dim is not None: # fix bugs for dim=NOne + core.static_assert(-len(x.shape) <= dim and dim < len(x.shape)) + _dim: core.constexpr = _get_flip_dim(dim, x.shape) + core.static_assert(_is_power_of_two(x.shape[_dim])) + steps: core.constexpr = _log2(x.shape[_dim]) + + # reshape the swap dimension to (2, 2, ..., 2) + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:]) + for i in core.static_range(steps): + y = y ^ xor_sum(y, _dim + i, True) + x = core.reshape(y, x.shape).to(x.dtype, bitcast=True) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + c = core.join(a, b) + + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/sunrise/python/triton/runtime/__init__.py b/third_party/sunrise/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/sunrise/python/triton/runtime/_allocation.py b/third_party/sunrise/python/triton/runtime/_allocation.py new file mode 100644 index 000000000..aa8a45488 --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/_allocation.py @@ -0,0 +1,32 @@ +from typing import Optional, Protocol + + +class Buffer(Protocol): + + def data_ptr(self) -> int: + ... + + +class Allocator(Protocol): + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + ... + + +class NullAllocator: + + def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer: + raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " + + "Use triton.set_allocator to specify an allocator.") + + +_allocator: Allocator = NullAllocator() + + +def set_allocator(allocator: Allocator): + """ + The allocator function is called during kernel launch for kernels that + require additional global memory workspace. + """ + global _allocator + _allocator = allocator diff --git a/third_party/sunrise/python/triton/runtime/autotuner.py b/third_party/sunrise/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..2f8878e7e --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/autotuner.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import builtins +import time +import inspect +import hashlib +import json +from functools import cached_property +from typing import Dict, Tuple, List, Optional + +from .. import knobs +from .jit import KernelInterface +from .errors import OutOfResources, PTXASError +from .driver import driver + +import os + + +class Autotuner(KernelInterface): + + def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None, + prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None, + cache_results=False): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)] + else: + self.configs = configs + self.keys = key + self.cache: Dict[Tuple, Config] = {} + self.arg_names = arg_names + self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret) + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self._do_bench = do_bench + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from ..testing import do_bench_cudagraph + self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + @cached_property + def do_bench(self): + if self._do_bench is None: + return driver.active.get_benchmarker() + return self._do_bench + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + verbose = knobs.autotuning.print + if verbose: + print(f"Autotuning kernel {self.base_fn.__name__} with config {config}") + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: + if verbose: + print(f"Autotuning failed with {e}") + return [float("inf"), float("inf"), float("inf")] + + def check_disk_cache(self, tuning_key, configs, bench_fn): + # We can't serialize prehooks, so just give up and run the benchmarks. + if not tuning_key or any(cfg.pre_hook for cfg in configs): + bench_fn() + return False + + from triton._C.libtriton import get_cache_invalidating_env_vars + from triton.compiler.compiler import make_backend, triton_key + from triton.runtime.cache import get_cache_manager + from triton.runtime.jit import JITFunction + + fn = self.fn + while not isinstance(fn, JITFunction): + fn = fn.fn + + env_vars = get_cache_invalidating_env_vars() + cache_key = [ + triton_key(), + make_backend(driver.active.get_current_target()).hash(), + fn.cache_key, + str(sorted(env_vars.items())), + str(tuning_key), + ] + [str(c) for c in configs] + cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest() + cache = get_cache_manager(cache_key) + file_name = f"{fn.__name__[:150]}.autotune.json" + path = cache.get_file(file_name) + if path: + with open(path, "r") as cached_configs: + timings = json.load(cached_configs)["configs_timings"] + timings = {Config(**config): timing for config, timing in timings} + self.cache[tuning_key] = builtins.min(timings, key=timings.get) + self.configs_timings = timings + return True + + bench_fn() + cache.put( + json.dumps({ + "key": + tuning_key, + "configs_timings": + [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook], + }), file_name, binary=False) + return False + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + + def benchmark(): + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + + if self.cache_results: + used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark) + else: + benchmark() + + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if knobs.autotuning.print and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n" + f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs: Dict) -> List[Config]: + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + elif not isinstance(top_k, int): + # Slice index must be an integer + raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int") + + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for autotune_config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **autotune_config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_stages: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type num_ctas: int + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}). + """ + + def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + self.ir_override = ir_override + + def __setstate__(self, state): + self.kwargs = state.get("kwargs", {}) + self.num_warps = state.get("num_warps", 4) + self.num_stages = state.get("num_stages", 3) + self.num_ctas = state.get("num_ctas", 1) + self.maxnreg = state.get("maxnreg", None) + self.pre_hook = state.get("pre_hook", None) + self.ir_override = state.get("ir_override", None) + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ("ir_override", self.ir_override), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + def __hash__(self): + return hash((*self.all_kwargs().items(), self.pre_hook)) + + def __eq__(self, other): + self_tuple = tuple(( + *self.all_kwargs().items(), + self.pre_hook, + )) + other_tuple = tuple(( + *other.all_kwargs().items(), + other.pre_hook, + )) + return self_tuple == other_tuple + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + :param cache_results: whether to cache autotune timings to disk. Defaults to False. + "type cache_results: bool + """ + if os.getenv('TRITON_OFF_AUTOTUNER', '0') == '1': + def empty_decorator(fn): + return fn + return empty_decorator + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + # smallest power-of-two >= x_size + @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])}) + @triton.jit + def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr): + ... + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[dict[str, Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/sunrise/python/triton/runtime/build.py b/third_party/sunrise/python/triton/runtime/build.py new file mode 100644 index 000000000..2d0899603 --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/build.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import functools +import hashlib +import importlib.util +import logging +import os +import shutil +import subprocess +import sysconfig +import tempfile + +from types import ModuleType + +from .cache import get_cache_manager +from .. import knobs + + +def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], + libraries: list[str]) -> str: + if impl := knobs.build.impl: + return impl(name, src, srcdir, library_dirs, include_dirs, libraries) + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError( + "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() # type: ignore + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + custom_backend_dirs = knobs.build.backend_dirs + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + return so + + +@functools.lru_cache +def platform_key() -> str: + from platform import machine, system, architecture + return ",".join([machine(), system(), *architecture()]) + + +def _load_module_from_path(name: str, path: str) -> ModuleType: + spec = importlib.util.spec_from_file_location(name, path) + if not spec or not spec.loader: + raise RuntimeError(f"Failed to load newly compiled {name} from {path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType: + key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + suffix = sysconfig.get_config_var("EXT_SUFFIX") + cache_path = cache.get_file(f"{name}{suffix}") + + if cache_path is not None: + try: + return _load_module_from_path(name, cache_path) + except (RuntimeError, ImportError): + log = logging.getLogger(__name__) + log.warning(f"Triton cache error: compiled module {name}.so could not be loaded") + + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, name + ".c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or []) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True) + + return _load_module_from_path(name, cache_path) diff --git a/third_party/sunrise/python/triton/runtime/cache.py b/third_party/sunrise/python/triton/runtime/cache.py new file mode 100644 index 000000000..7edb75cb8 --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/cache.py @@ -0,0 +1,266 @@ +import json +import os +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List, Optional +import base64 +import hashlib + +from .. import knobs + + +class CacheManager(ABC): + + def __init__(self, key, override=False, dump=False): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = knobs.cache.dump_dir + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = knobs.cache.override_dir + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = knobs.cache.dir + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + os.removedirs(temp_dir) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = knobs.cache.redis.key_format + self._redis = redis.Redis( + host=knobs.cache.redis.host, + port=knobs.cache.redis.port, + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_cls = knobs.cache.remote_manager_class + if not remote_cache_cls: + raise RuntimeError( + "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class") + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +def _base32(key): + # Assume key is a hex string. + return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + +def get_cache_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key)) + + +def get_override_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key), override=True) + + +def get_dump_manager(key) -> CacheManager: + cls = knobs.cache.manager_class or FileCacheManager + return cls(_base32(key), dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return _base32(key) diff --git a/third_party/sunrise/python/triton/runtime/driver.py b/third_party/sunrise/python/triton/runtime/driver.py new file mode 100644 index 000000000..962c1a91c --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/driver.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from ..backends import backends, DriverBase + +from typing import Any, Callable, Generic, TypeVar, Union + + +def _create_driver() -> DriverBase: + active_drivers = [x.driver for x in backends.values() if x.driver.is_active()] + if len(active_drivers) != 1: + raise RuntimeError(f"{len(active_drivers)} active drivers ({active_drivers}). There should only be one.") + return active_drivers[0]() + + +T = TypeVar("T") + + +class LazyProxy(Generic[T]): + + def __init__(self, init_fn: Callable[[], T]) -> None: + self._init_fn = init_fn + self._obj: Union[T, None] = None + + def _initialize_obj(self) -> T: + if self._obj is None: + self._obj = self._init_fn() + return self._obj + + def __getattr__(self, name) -> Any: + return getattr(self._initialize_obj(), name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + setattr(self._initialize_obj(), name, value) + + def __delattr__(self, name: str) -> None: + delattr(self._initialize_obj(), name) + + def __repr__(self) -> str: + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self) -> str: + return str(self._initialize_obj()) + + +class DriverConfig: + + def __init__(self) -> None: + self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver) + self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default + + def set_active(self, driver: DriverBase) -> None: + self.active = driver + + def reset_active(self) -> None: + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/sunrise/python/triton/runtime/errors.py b/third_party/sunrise/python/triton/runtime/errors.py new file mode 100644 index 000000000..1a8046430 --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/errors.py @@ -0,0 +1,36 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) + + +class PTXASError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + error_message = self.error_message or "" + return f"PTXAS error: {error_message}" diff --git a/third_party/sunrise/python/triton/runtime/interpreter.py b/third_party/sunrise/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..370bba1ce --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/interpreter.py @@ -0,0 +1,1406 @@ +from __future__ import annotations +import ast +import textwrap +import inspect +from typing import Tuple, List, Dict + +import math +import numpy as np + +import triton +import triton.language as tl +import dataclasses +from dataclasses import dataclass + +from triton.language.semantic import TritonSemantic +from triton.tools.tensor_descriptor import TensorDescriptor +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +@dataclass +class TensorHandle: + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already available in the data field + attr: a dictionary of attributes + ''' + data: np.array + dtype: tl.dtype + attr: Dict = dataclasses.field(default_factory=dict) + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, block_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.block_shape = block_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + ptrs = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = masks & (off < self.shape[dim].data) & (off >= 0) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +class TensorDescHandle: + + def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle], + block_shape: List[int]): + self.base = base + self.ndim = len(shape) + self.shape = shape + self.strides = strides + self.block_shape = block_shape + + def validate(self): + assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned" + assert len(self.strides) == self.ndim + assert len(self.block_shape) == self.ndim + + for stride in self.strides[:-1]: + assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned" + assert self.strides[-1].data.item() == 1, "last dim must be contiguous" + + def materialize_pointers(self, offsets: List[TensorHandle]): + assert len(offsets) == self.ndim + scalar_ty = self.base.dtype.element_ty + itemsize = scalar_ty.primitive_bitwidth // 8 + assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned" + + ptrs = np.broadcast_to(self.base.data, self.block_shape) + masks = np.ones(self.block_shape, dtype=bool) + for dim in range(len(self.block_shape)): + bcast_dims = [1] * len(self.block_shape) + bcast_dims[dim] = self.block_shape[dim] + off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64) + masks = masks & (0 <= off) & (off < self.shape[dim].data) + assert ptrs.dtype == np.uint64 + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + sanitize_overflow: bool = True + arch: str = None + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") + deprecated_fp8_dot_operand_dtypes: Tuple[str] = () + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + backend_name: str = "interpreter" + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic): + return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1) + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int1_ty(self): + return tl.int1 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + create_int_to_ptr = create_bitcast + create_ptr_to_int = create_bitcast + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, ret_ty, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins, mask): + if mask is None: + mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1) + # force all masked elements to zero + data = np.where(mask.data, data.data, np.zeros_like(data.data)) + histogram = np.histogram(data, bins=bins, range=(0, bins))[0] + # remove overcounted elements + histogram[0] -= np.logical_not(mask.data).sum() + return TensorHandle(histogram, tl.int32) + + def create_gather(self, src, indices, axis): + return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, ret_ty, arg): + shape = ret_ty.shape + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values, isSigned): + # NOTE: the `isSigned` variable is not really used here; because Signness is already known + # by `values` themselves in python interpreter, thus not really needed here; + # it is only used for triton PrintOpToLLVM to correctly construct the format specifier. + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message}" + + def create_assume(self, condition): + assert condition, "Assume failed" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def create_make_tensor_descriptor( + self, + base: TensorHandle, + shape: List[TensorHandle], + strides: List[TensorHandle], + tensor_shape: List[int], + is_signed: bool, + ): + desc = TensorDescHandle(base, shape, strides, tensor_shape) + desc.validate() + return desc + + def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier, + eviction_policy): + assert isinstance(desc, TensorDescHandle) + ptrs, mask = desc.materialize_pointers(indices) + return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier, + eviction_policy=eviction_policy, is_volatile=False) + + def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]): + ptrs, mask = desc.materialize_pointers(indices) + return self.create_masked_store(ptrs, value, mask, None, None) + + def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type): + dtype = desc.base.dtype.element_ty + np_dtype = _get_np_dtype(dtype) + result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype) + cache_modifier = None + eviction_policy = None + for i, x_offset in enumerate(x_offsets.data): + indices = [TensorHandle(x_offset, tl.int32), y_offset] + result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data + return TensorHandle(result, dtype) + + def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle, + y_offset: TensorHandle): + for i, x_offset in enumerate(x_offsets.data): + slice = TensorHandle(value.data[i], value.dtype) + indices = [TensorHandle(x_offset, tl.int32), y_offset] + self.create_descriptor_store(desc, slice, indices) + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + elif np_type == np.bool_: + return TensorHandle(np.full(1, True, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + semantic = TritonSemantic(builder) + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_semantic"}, _semantic=semantic)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype) + assert self.type.is_block() + block_shape = list(self.type.shape) + block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1] + res_ty = tl.core.block_type(self.dtype, block_shape) + return tl.core.tensor(handle, res_ty) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpInterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + np_dtype = _get_np_dtype(dtype) + if hasattr(ret, "shape") and ret.shape: + ret = ret.astype(np_dtype) + ret_type = tl.block_type(dtype, list(ret.shape)) + else: + ret = np.array([ret], dtype=np_dtype) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + return self.apply((input, ))[0] + self.check_tensor(input) + ret = self.apply_impl(input) + return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, ) + + +class ReduceOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpInterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return ret + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisibility") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]] + assert len(langs) >= 1, "triton.language must be visible from within jit'd function" + for lang in langs: + _patch_builtin(lang, interpreter_builder) + _patch_builtin(lang.tensor, interpreter_builder) + if lang == tl: + _patch_builtin(lang.math, interpreter_builder) + _patch_lang_tensor(lang.tensor) + _patch_lang_core(lang) + _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder) + + +def _tuple_create(arg, contents): + # NamedTuples and tuples have different construction semantics. NamedTuple + # has a constructor that takes individual arguments, while tuple takes an + # iterable. Both have type "tuple" making it difficult to distinguish + # between them, but only NamedTuple has "_fields" and apparently this is how + # everyone does the check. + return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg)) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg)) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + elif isinstance(arg, tuple): + return _tuple_create(arg, map(_implicit_cvt, arg)) + elif isinstance(arg, TensorDescriptor): + strides = [_implicit_cvt(s) for s in arg.strides] + assert arg.strides[-1] == 1 + strides[-1] = tl.constexpr(1) + semantic = TritonSemantic(InterpreterBuilder()) + return semantic.make_tensor_descriptor( + base=_implicit_cvt(arg.base), + shape=[_implicit_cvt(s) for s in arg.shape], + strides=strides, + block_shape=[tl.constexpr(b) for b in arg.block_shape], + ) + return arg + + +interpreter_builder = InterpreterBuilder() +interpreter_semantic = TritonSemantic(interpreter_builder) + + +def _unwrap_tensor(t): + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t + + +def _rewrap_tensor(t, original_tensor): + if isinstance(original_tensor, triton.runtime.jit.TensorWrapper): + return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype) + return t + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + storages = {} + + def _to_cpu(arg): + if isinstance(arg, tuple): + return _tuple_create(arg, map(_to_cpu, arg)) + elif isinstance(arg, TensorDescriptor): + return TensorDescriptor( + _to_cpu(arg.base), + arg.shape, + arg.strides, + arg.block_shape, + ) + elif not hasattr(arg, "data_ptr"): + return arg + + unwrapped_arg = _unwrap_tensor(arg) + if unwrapped_arg.untyped_storage().data_ptr() not in storages: + storage = unwrapped_arg.untyped_storage() + storages[storage.data_ptr()] = storage.cpu() + + storage = storages[unwrapped_arg.untyped_storage().data_ptr()] + cpu_arg = unwrapped_arg.new_empty(0, device='cpu') + cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride()) + cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg) + return cpu_arg + + args_hst = [_to_cpu(arg) for arg in args_dev] + + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + kwargs_hst[key] = _to_cpu(value) + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + storages = {} + + def _from_cpu(arg_dev, arg_hst): + if hasattr(arg_dev, "data_ptr"): + # No need to rewrap because this just modifies internal + arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst) + storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage()) + elif isinstance(arg_dev, tuple): + for (arg_dev, arg_hst) in zip(arg_dev, arg_hst): + _from_cpu(arg_dev, arg_hst) + elif isinstance(arg_dev, TensorDescriptor): + _from_cpu(arg_dev.base, arg_hst.base) + + for arg_dev, arg_hst in zip(args_dev, args_hst): + _from_cpu(arg_dev, arg_hst) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + _from_cpu(kwarg_dev, kwarg_hst) + + for (arg_dev, arg_hst) in storages.values(): + arg_dev.copy_(arg_hst) + + def __call__(self, *args_dev, **kwargs): + if kwargs.pop("warmup", False): + return + # Removes not used reserved keywords from kwargs + # Triton doesn't support keyword-only, variable positional or variable keyword arguments + # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) + argspec = inspect.getfullargspec(self.fn) + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + if triton.knobs.compilation.front_end_debugging: + raise + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class ASTTransformer(ast.NodeTransformer): + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # interpreter_semantic.to_tensor(value, False) + node.value = ast.Call( + func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor", + ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[]) + return node + + +class FunctionRewriter: + ast_transformer = ASTTransformer() + + def __init__(self, fn, **kwargs): + self.fn = fn + self.kwargs = kwargs + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + + def rewrite_ast(self): + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function + try: + lines, _ = inspect.getsourcelines(self.fn) + except Exception: + return self.fn + + # truncate lines before def + # @triton.autotune(...) + # ... + # @triton.jit + # ... + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.def_lineno = self._find_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_def(self, lines): + def_lineno = 0 + # Line numbers start from 1 + for i, line in enumerate(lines): + if line.strip().startswith("def "): + def_lineno = i + 1 + return def_lineno + + def _prepare_source(self, lines): + lines = lines[self.def_lineno - 1:] + src = ''.join(lines) + return textwrap.dedent(src) + + def _transform_ast(self, src): + # src is like: + # 1: def foo(...): + # 2: ... + parsed_ast = ast.parse(src) + transformed_ast = self.ast_transformer.visit(parsed_ast) + ast.fix_missing_locations(transformed_ast) + inc_lineno = self.def_file_lineno - 1 + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') + local_namespace = {**self.kwargs} + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + if key not in fn_globals: + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__] + + +class InterpretedFunction: + # Cache all rewritten functions + rewritten_fn = {} + + def __init__(self, fn, **kwargs) -> None: + self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) + + def run(*args, **kwargs): + grid = kwargs["grid"] + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + fn = self.rewrite() + try: + return fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/sunrise/python/triton/runtime/jit.py b/third_party/sunrise/python/triton/runtime/jit.py new file mode 100644 index 000000000..70d8341ac --- /dev/null +++ b/third_party/sunrise/python/triton/runtime/jit.py @@ -0,0 +1,949 @@ +from __future__ import annotations, division +import ast +import copy +import hashlib +import inspect +import itertools +import re +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple + +from triton.tools.tensor_descriptor import TensorDescriptor +from types import ModuleType +from .. import knobs +from ..runtime.driver import driver +from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, canonicalize_dtype + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, nonlocals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + self.nonlocals = nonlocals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + def name_lookup(name): + val = self.globals.get(name, None) + if val is not None: + return val, self.globals + val = self.nonlocals.get(name, None) + if val is not None: + return val, self.nonlocals + return None, None + + val, var_dict = name_lookup(node.id) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) is not ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins): + self.used_global_vals[(node.id, id(var_dict))] = (copy.copy(val), var_dict) + + self._update_hash(val) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + import triton.language.core as core + if isinstance(ty, str): + ty = ty.strip() + if ty.startswith("const "): + ty = ty.removeprefix("const") + ty = _normalize_ty(ty) + assert ty.startswith("*") + return "*k" + ty[1:] + if ty.endswith("*"): + return "*" + _normalize_ty(ty[:-1]) + if ty.startswith("*"): + return "*" + _normalize_ty(ty[1:]) + if ty.startswith("tl."): + return _normalize_ty(ty.removeprefix("tl.")) + elif isinstance(ty, core.pointer_type): + return f"*{_normalize_ty(ty.element_ty)}" + elif isinstance(ty, core.dtype): + ty = ty.name + elif isinstance(ty, type): + ty = ty.__name__ + else: + ty = str(ty) + return type_canonicalisation_dict.get(ty.replace("_t", ""), ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self) -> str: + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self) -> str: + a = self.annotation + if a.startswith("*k"): + a = a[2:] + elif a.startswith("*"): + a = a[1:] + if a in set(type_canonicalisation_dict.values()): + return self.annotation + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + if self.is_constexpr: + return False + return "const" in self.annotation or self.annotation.startswith("*k") + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +dtype2str = {} +specialize_impl_cache = [] + + +def create_specialize_impl(specialize_extra): + + from ..language import constexpr + from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor + + def specialize_impl(arg, is_const=False, specialize_value=True, align=True): + if arg is None: + return ("constexpr", None) + elif isinstance(arg, bool): + return ("u1", None) + elif isinstance(arg, int): + key = specialize_extra(arg, "int", align=align) if specialize_value else None + if arg == 1 and specialize_value: + return ("constexpr", 1) + elif -(2**31) <= arg and arg <= 2**31 - 1: + return ("i32", key) + elif 2**63 <= arg and arg <= 2**64 - 1: + return ("u64", key) + else: + return ("i64", key) + elif isinstance(arg, float): + return ("fp32", None) + elif hasattr(arg, "data_ptr"): + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0]) + dtype2str[dsk] = res + key = specialize_extra(arg, "tensor", align=align) if specialize_value else None + return (res, key) + elif isinstance(arg, JITFunction): + return ("constexpr", arg.cache_key) + elif isinstance(arg, constexpr): + return ("constexpr", arg) + elif hasattr(arg, "tma_desc_cpu_ptr"): + return ("nvTmaDesc", None) + elif isinstance(arg, tuple): + spec = [specialize_impl(x) for x in arg] + make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals) + tys = make_tuple([x[0] for x in spec]) + keys = make_tuple([x[1] for x in spec]) + return (tys, keys) + elif isinstance(arg, TensorDescriptor): + assert hasattr(arg.base, "data_ptr") + inner = canonicalize_dtype(arg.base.dtype) + return (f"tensordesc<{inner}{list(arg.block_shape)}>", None) + elif isinstance(arg, GluonTensorDescriptor): + assert hasattr(arg.base, "data_ptr") + inner = canonicalize_dtype(arg.base.dtype) + return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None) + else: + raise TypeError("Unsupported type: %s" % type(arg)) + + return specialize_impl + + +def mangle_type(arg, specialize=False): + if len(specialize_impl_cache) == 0: + specialize_impl_cache.append(create_specialize_impl(lambda _, **kwargs: None)) + specialize_impl = specialize_impl_cache[0] + return specialize_impl(arg, specialize_value=specialize)[0] + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals': + list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()), + 'options': options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + assert len(sig.parameters) == len(kparams) + # Create the function argument list and the dict entries for the return statement + specialization = [] + # signature + for name, kp in zip(sig.parameters.keys(), kparams): + if kp.is_constexpr: + specialization.append(f'("constexpr", {name})') + else: + is_const = 'True' if kp.is_const else 'False' + specialize = 'False' if kp.do_not_specialize else 'True' + align = 'False' if kp.do_not_specialize_on_alignment else 'True' + ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})" + if kp.annotation_type: + if isinstance(kp.annotation_type, str): + if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]: + # we do not specialize non-constexpr floats and bools: + specialize = False + if kp.annotation_type == "i32": + specialization.append(f"{ret}") + continue + if specialize: + specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]') + else: + # skip runtime specialization: + specialization.append(f'("{kp.annotation_type}", None)') + else: + specialization.append(f"{ret}") + + # compute argument string for a given parameter + arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}" + # Join all arguments into a function definition string + func_body = f""" +def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}): + params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}} + specialization = [{','.join(specialization)}] + return params, specialization, options +""" + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace["JITFunction"] = JITFunction + func_namespace["specialize_impl"] = create_specialize_impl(backend.get_arg_specialization) + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +def get_full_name(fn): + return f"{fn.__module__}.{fn.__qualname__}" + + +@dataclass +class JitFunctionInfo: + module: ModuleType + name: str + jit_function: JITFunction + + +class JITFunction(KernelInterface[T]): + + def is_gluon(self): + return False + + def _call_hook( + self, + hook, + key, + signature, + device, + constants, + options, + configs, + is_warmup, + ) -> bool | None: + if not hook: + return None + + name = self.fn.__qualname__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})" + full_name = get_full_name(self.fn) + + specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'launch_cooperative_grid': options.launch_cooperative_grid, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + target = driver.active.get_current_target() + backend = make_backend(target) + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + binder = create_function_from_signature(self.signature, self.params, backend) + return {}, target, backend, binder + + def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug + + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + kernel_cache, target, backend, binder = self.device_caches[device] + # specialization is list[tuple[str, Any]], where first element of tuple is + # the type and the second parameter is the 'specialization' value. + bound_args, specialization, options = binder(*args, **kwargs) + + # compute cache key + key = str(specialization) + str(options) + kernel = kernel_cache.get(key, None) + + # Kernel is not cached; we have to compile. + if kernel is None: + # options + options = backend.parse_options(kwargs) + # signature + sigkeys = [x.name for x in self.params] + sigvals = [x[0] for x in specialization] + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + # check arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in kwargs: + if k not in options.__dict__ and k not in sigkeys: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + # constexprs + constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr") + constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs} + # attributes + attrvals = [x[1] for x in specialization] + attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) + attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs} + if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], + warmup): + return None + # compile the kernel + src = self.ASTSource(self, signature, constexprs, attrs) + kernel = self.compile(src, target=target, options=options.__dict__) + kernel_cache[key] = kernel + self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs], + warmup) + + # Check that used global values have not changed. + not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values()) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values()) + return kernel + + def repr(self, _): + return self._fn_name if self._repr is None else self._repr(_) + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self.starting_line_number = inspect.getsourcelines(fn)[1] + self._repr = repr + self._fn_name = get_full_name(fn) + self.launch_metadata = launch_metadata + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # function source code (without decorators) + src = textwrap.dedent(inspect.getsource(fn)) + src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():] + self._unsafe_update_src(src) + # cache of just-in-time compiled kernels + self.device_caches = defaultdict(self.create_binder) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__qualname__ = fn.__qualname__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + def get_capture_scope(self): + return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + nonlocals = inspect.getclosurevars(self.fn).nonlocals + dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals, + src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + @property + def type(self): + from triton.language.core import constexpr + return constexpr + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self._fn_name: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}") + constant_keys = map(tuple, deserialized_obj['constant_keys']) + constant_vals = deserialized_obj['constant_vals'] + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in zip(constant_keys, constant_vals) + } + attrs_keys = map(tuple, deserialized_obj['attrs_keys']) + attrs_vals = deserialized_obj['attrs_vals'] + attrs = dict(zip(attrs_keys, attrs_vals)) + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, attrs) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.device_caches[device][0][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + # - when `.src` attribute is set, cache key of all callers need to be re-computed + if name == "src": + raise AttributeError(f"Cannot set attribute '{name}' directly. " + f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers" + f"instead.") + super(JITFunction, self).__setattr__(name, value) + + def _unsafe_update_src(self, new_src): + """ + The only method allowed to modify src. + Bypasses the __setattr__ restriction by calling super().__setattr__ directly. + """ + self.hash = None + super().__setattr__('src', new_src) + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__qualname__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if knobs.runtime.interpret: + from .interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, *args): + return self.base.stride(*args) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + def new_empty(self, sizes): + return TensorWrapper(self.base.new_empty(sizes), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/third_party/sunrise/python/triton/testing.py b/third_party/sunrise/python/triton/testing.py new file mode 100644 index 000000000..df6db3bae --- /dev/null +++ b/third_party/sunrise/python/triton/testing.py @@ -0,0 +1,543 @@ +import functools +import math +import os +import statistics +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from . import runtime + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +# pure Python implementation of np.quantile/torch.quantile +# to avoid unnecessary runtime dependency on numpy/torch + + +def _quantile(a, q): + n = len(a) + a = sorted(a) + + def get_quantile(q): + if not (0 <= q <= 1): + raise ValueError("Quantiles must be in the range [0, 1]") + point = q * (n - 1) + lower = math.floor(point) + upper = math.ceil(point) + t = point - lower + return (1 - t) * a[lower] + t * a[upper] + + return [get_quantile(q) for q in q] + + +def _summarize_statistics(times, quantiles, return_mode): + if quantiles is not None: + ret = _quantile(times, quantiles) + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times + elif return_mode == "min": + return min(times) + elif return_mode == "max": + return max(times) + elif return_mode == "mean": + return statistics.mean(times) + elif return_mode == "median": + return statistics.median(times) + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + # Rewrite to avoid possible division by 0 issues with fast benchmarks + if estimate_ms == 0: + n_repeat = 1000 + else: + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(ret, quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="median"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean". + :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + + di = runtime.driver.active.get_device_interface() + + fn() + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + runtime.driver.active.clear_cache(cache) + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + runtime.driver.active.clear_cache(cache) + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + return _summarize_statistics(times, quantiles, return_mode) + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + try: + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + finally: + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + with open(os.path.join(save_path, "results.html"), "w") as html: + html.write("\n") + for bench in benchmarks[:len(result_dfs)]: + html.write(f"\n") + html.write("\n") + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/sunrise/python/triton/tools/__init__.py b/third_party/sunrise/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/python/triton/tools/build_extern.py b/third_party/sunrise/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/sunrise/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/sunrise/python/triton/tools/compile.py b/third_party/sunrise/python/triton/tools/compile.py new file mode 100644 index 000000000..7eed34389 --- /dev/null +++ b/third_party/sunrise/python/triton/tools/compile.py @@ -0,0 +1,162 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +import triton.backends +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + for key, value in hints.items(): + if value == 1: + constants[kernel.arg_names[key[0]]] = value + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = 'constexpr' + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{k}={v}" for k, v in constants.items()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} + src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + if ccinfo.metadata.global_scratch_size > 0: + raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented") + + arg_names = [] + arg_types = [] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif hints.get((i, ), None) == 1: + arg_names.append(arg_name) + arg_types.append("i32") + + # dump C stub code + suffix = '' + for i, ty in enumerate(signature.values()): + suffix += str(i) + if hints.get((i, ), None) == 1: + suffix += 'c' + if hints.get((i, ), None) == 16: + suffix += 'd' + func_name = '_'.join([out_name, sig_hash, suffix]) + asm = ccinfo.asm["cubin"] # store binary data once + hex_ = str(binascii.hexlify(asm))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(asm), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]), + "num_args": len(arg_names_not_1) + 1, + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/sunrise/python/triton/tools/disasm.py b/third_party/sunrise/python/triton/tools/disasm.py new file mode 100644 index 000000000..c2301fd2e --- /dev/null +++ b/third_party/sunrise/python/triton/tools/disasm.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def path_to_cuobjdump(): + from triton import knobs + return knobs.nvidia.cuobjdump.path + + +def extract(file_path, fun): + cuobjdump = path_to_cuobjdump() + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/sunrise/python/triton/tools/extra/cuda b/third_party/sunrise/python/triton/tools/extra/cuda new file mode 120000 index 000000000..e5812852d --- /dev/null +++ b/third_party/sunrise/python/triton/tools/extra/cuda @@ -0,0 +1 @@ +/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/nvidia/tools/cuda \ No newline at end of file diff --git a/third_party/sunrise/python/triton/tools/link.py b/third_party/sunrise/python/triton/tools/link.py new file mode 100644 index 000000000..75a1157a5 --- /dev/null +++ b/third_party/sunrise/python/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/sunrise/python/triton/tools/mxfp.py b/third_party/sunrise/python/triton/tools/mxfp.py new file mode 100644 index 000000000..1b129c1ae --- /dev/null +++ b/third_party/sunrise/python/triton/tools/mxfp.py @@ -0,0 +1,301 @@ +""" +Helper classes for working with low precision floating point types that +align with the opencompute (OCP) microscaling (MX) specification. + * MXFP4Tensor: 4-bit E2M1 floating point data + * MXScaleTensor: 8-bit E8M0 floating point data +Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +""" + +import torch + + +class MXFP4Tensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with four bit E2M1 floating point data as defined by the + opencompute microscaling specification. + + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self): + S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device) + M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device) + + self.data = ((S << 3) | (E << 1) | M).type(torch.uint8) + return self + + def to(self, dtype): + """ + Convert fp4e2m1 data to float32. + + Returns: + - A torch tensor of type dtype representing the fp4e2m1 data. + """ + assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion" + + data = self.data + S = ((data >> 3) & 0x1).type(dtype) + E = ((data >> 1) & 0x3).type(dtype) + M = (data & 0x1).type(dtype) + + # The MXF4 E2M1 spec defines 0bS000 as zero + value = torch.zeros_like(S) + is_zero = (E == 0) & (M == 0) + non_zero_mask = ~is_zero + if non_zero_mask.any(): + S_nz = S[non_zero_mask] + E_nz = E[non_zero_mask] + M_nz = M[non_zero_mask] + + sign = torch.pow(-1, S_nz) + # Normal and subnormal handling for the exponent and mantissa + exponent = torch.where(E_nz == 0, E_nz, E_nz - 1) + mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5) + value_nz = sign * torch.pow(2, exponent) * mantissa + + value[non_zero_mask] = value_nz + + # For zeros, the values must remain zero with the correct sign + value[is_zero & (S == 1)] *= -1 + return value.type(torch.float32) + + def _from_float(self, values): + """ + Convert float32 numbers to mxf4 e2m1 format. + * No encodings are reserved for Inf or NaN in mxf4. + * Conversion from float supports roundTiesToEven rounding mode. + * If a value exceeds the mxf4 representable range after rounding, + clamps to the maximum mxf4 magnitude, preserving the sign. + * If a value has magnitude less than the minimum subnormal magnitude + in mxf4 after rounding, converts to zero. + + Parameters: + - values: A torch tensor of float32 numbers to convert to fp4 format. + """ + S = torch.signbit(values).type(torch.uint8) + abs_values = torch.abs(values) + + is_zero = (abs_values == 0) + is_invalid = torch.isnan(values) | torch.isinf(values) + + # Enumerate all possible E2M1 exponent and mantissa values. We will + # use these to compare the distance between float32 and all possible + # E2M1 floats to find the nearest E2M1 representable value + E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device) + M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device) + + candidate_values = [] + candidate_E = [] + candidate_M = [] + + for E in E_bits: + if E == 0: + # Subnormals + exponent = 0 + for M in M_bits: + significand = M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + else: + # Normals + exponent = E.item() - 1 + for M in M_bits: + significand = 1.0 + M * 0.5 + value = significand * (2**exponent) + candidate_values.append(value) + candidate_E.append(E) + candidate_M.append(M) + + candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device) + candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device) + candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device) + + abs_values_flat = abs_values.view(-1) + N = abs_values_flat.shape[0] + abs_values_expanded = abs_values_flat.unsqueeze(1) + + # Clamp invalid values to the max e2m1 representable value + max_candidate_value = candidates.max().item() + abs_values_flat[is_invalid.view(-1)] = max_candidate_value + + # Compute distance between all abs_values and candidate e2m1 values + errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0)) + + # To implement roundTiesToEven, we need to break ties by preferring + # even mantissas (M == 0). We do so by adding an epsilon bias to shift + # the closest candidate with an even mantissa closer to the float value + min_errors, _ = torch.min(errors, dim=1, keepdim=True) + is_tie = (errors == min_errors) + # More than one candidate has the min error for some float value + if is_tie.sum() > 1: + M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1) + tie_breaker = (M_bits_expanded == 0).type(torch.int32) + + errors = errors - (tie_breaker * 1e-6) + + best_indices = torch.argmin(errors, dim=1) + + E_selected = candidate_E[best_indices] + M_selected = candidate_M[best_indices] + E = E_selected.view(abs_values.shape) + M = M_selected.view(abs_values.shape) + + E[is_zero] = 0 + M[is_zero] = 0 + + return ((S << 3) | (E << 1) | M).type(torch.uint8) + + def to_packed_tensor(self, dim): + """ + Packs two e2m1 elements into a single uint8 along the specified dimension. + + Parameters: + - dim: The dimension along which to pack the elements. + + Returns: + - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8. + """ + data = self.data + assert 0 <= dim < data.ndim, \ + "The dimension to pack along is not within the range of tensor dimensions" + + size_along_dim = data.size(dim) + new_size_along_dim = (size_along_dim + 1) // 2 + + # If the size is odd, we pad the data along dim with zeros at the end + if size_along_dim % 2 != 0: + pad_sizes = [0] * (2 * data.ndim) + pad_index = (data.ndim - dim - 1) * 2 + 1 + pad_sizes[pad_index] = 1 + data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0) + + new_shape = list(data.shape) + new_shape[dim] = new_size_along_dim + new_shape.insert(dim + 1, 2) # packed dimension of length 2 + data = data.reshape(*new_shape) + + low = data.select(dim + 1, 0) + high = data.select(dim + 1, 1) + packed = (high << 4) | low + + return packed + + def unpack_packed_tensor(self, packed_tensor, dim, original_shape): + """ + Unpacks a tensor where two fp4 elements are packed into a single uint8. + + Parameters: + - packed_tensor: The packed tensor + - dim: The dimension along which the tensor was packed. + - original_shape: The shape of the original tensor before packing. + + Returns: + - A tensor with the original data unpacked into uint8 elements containing one + fp4e2m1 element in the least significant bits. + """ + high = (packed_tensor >> 4) & 0xF + low = packed_tensor & 0xF + + stacked = torch.stack((low, high), dim=dim + 1) + + # Flatten along dim and dim+1 and then merge + shape = list(stacked.shape) + new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:] + data = stacked.reshape(*new_shape) + + # Remove any padding + if original_shape[dim] % 2 != 0: + indices = [slice(None)] * data.ndim + indices[dim] = slice(0, original_shape[dim]) + data = data[tuple(indices)] + + return data.type(torch.uint8) + + +class MXScaleTensor: + + def __init__(self, data=None, size=None, device=None): + """ + Tensor class for working with microscaling E8M0 block scale factors. + + Parameters: + - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format. + - size: The size of the tensor to create. + - device: The device on which to create the tensor. + """ + self.device = device + if data is not None: + assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor" + self.device = data.device + self.data = self._from_float(data) + elif size is not None: + self.size = size if isinstance(size, tuple) else (size, ) + else: + raise ValueError("Either parameter data or size must be provided") + + def random(self, low=None, high=None): + """ + Generate random E8M0 data within a specified range. + * Excludes the NaN encoding (255). + """ + bias = 127 + + min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias) + max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias)) + assert min_exponent <= max_exponent, "Low must be less than or equal to high" + + E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device) + self.data = E + return self + + def to(self, dtype): + assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion" + data = self.data.type(dtype) + is_nan = (data == 255) + e_biased = data.clone() + e_biased[is_nan] = 0 + e = e_biased - 127 + value = torch.pow(2.0, e) + value[is_nan] = torch.nan + return value.type(dtype) + + def _from_float(self, values): + """ + Convert float32 numbers to E8M0 format. + * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255). + * Positive values are converted by computing the floor of log2(value) to get the exponent. + + Parameters: + - values: A torch tensor of float32 numbers to convert to E8M0 format. + """ + result = torch.empty_like(values, dtype=torch.uint8, device=self.device) + + is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0) + result[is_invalid] = 255 + + valid_values = values[~is_invalid] + e = torch.floor(torch.log2(valid_values)) + e_biased = e + 127 + e_biased_int = e_biased.type(torch.int32) + e_biased_clamped = torch.clamp(e_biased_int, 0, 254) + result[~is_invalid] = e_biased_clamped.type(torch.uint8) + + return result diff --git a/third_party/sunrise/python/triton/tools/tensor_descriptor.py b/third_party/sunrise/python/triton/tools/tensor_descriptor.py new file mode 100644 index 000000000..21140b8b6 --- /dev/null +++ b/third_party/sunrise/python/triton/tools/tensor_descriptor.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import List, Any +from triton._utils import validate_block_shape + + +@dataclass +class TensorDescriptor: + base: Any + shape: List[int] + strides: List[int] + block_shape: List[int] + + def __post_init__(self): + rank = len(self.shape) + assert len(self.strides) == rank, f"rank mismatch: {self}" + assert len(self.block_shape) == rank, f"rank mismatch: {self}" + assert rank > 0, "rank must not be zero" + assert rank <= 5, "rank cannot be more than 5" + ty = type(self.base) + type_name = f"{ty.__module__}.{ty.__name__}" + if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"): + assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned" + validate_block_shape(self.block_shape) + elem_bytes = self.base.dtype.itemsize + for stride in self.strides[:-1]: + assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned" + assert self.strides[-1] == 1, "Last dimension must be contiguous" + + @staticmethod + def from_tensor(tensor: Any, block_shape: List[int]): + return TensorDescriptor( + tensor, + tensor.shape, + tensor.stride(), + block_shape, + ) diff --git a/third_party/triton_shared b/third_party/triton_shared new file mode 160000 index 000000000..08684f92a --- /dev/null +++ b/third_party/triton_shared @@ -0,0 +1 @@ +Subproject commit 08684f92ad30696362dce1760a83be889639a3e4 From 29d6c4c1fe32ce2a0c9d61b14f8ceba3e4d8024f Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Fri, 16 Jan 2026 14:03:28 +0800 Subject: [PATCH 2/7] fix bug --- log.log | 64 ++++ setup.py | 3 + setup_tools/setup_helper.py | 4 +- third_party/sunrise/backend/driver.py | 7 +- third_party/sunrise/python/triton/_C | 1 + third_party/sunrise/python/triton/backends | 1 + .../python/triton/experimental/__init__.py | 0 .../triton/experimental/gluon/__init__.py | 4 + .../triton/experimental/gluon/_compiler.py | 0 .../triton/experimental/gluon/_runtime.py | 99 ++++++ .../experimental/gluon/language/__init__.py | 18 + .../experimental/gluon/language/_core.py | 312 ++++++++++++++++++ .../experimental/gluon/language/_layouts.py | 230 +++++++++++++ .../experimental/gluon/language/_math.py | 12 + .../experimental/gluon/language/_semantic.py | 287 ++++++++++++++++ .../experimental/gluon/language/_standard.py | 47 +++ .../gluon/language/nvidia/__init__.py | 4 + .../language/nvidia/blackwell/__init__.py | 202 ++++++++++++ .../gluon/language/nvidia/blackwell/tma.py | 32 ++ .../gluon/language/nvidia/hopper/__init__.py | 11 + .../gluon/language/nvidia/hopper/mbarrier.py | 51 +++ .../gluon/language/nvidia/hopper/tma.py | 96 ++++++ .../experimental/gluon/nvidia/__init__.py | 4 + .../experimental/gluon/nvidia/blackwell.py | 3 + .../experimental/gluon/nvidia/hopper.py | 40 +++ 25 files changed, 1525 insertions(+), 7 deletions(-) create mode 100644 log.log create mode 120000 third_party/sunrise/python/triton/_C create mode 120000 third_party/sunrise/python/triton/backends create mode 100644 third_party/sunrise/python/triton/experimental/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/_compiler.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/_runtime.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/_core.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/_layouts.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/_math.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/_semantic.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/_standard.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/tma.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/nvidia/__init__.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/nvidia/blackwell.py create mode 100644 third_party/sunrise/python/triton/experimental/gluon/nvidia/hopper.py diff --git a/log.log b/log.log new file mode 100644 index 000000000..5eeb35bc3 --- /dev/null +++ b/log.log @@ -0,0 +1,64 @@ +Using pip 25.3 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10) +Obtaining file:///root/WorkSpace/flagtree_close/flagtree + Checking if build backend supports build_editable: started + Running command Checking if build backend supports build_editable + Checking if build backend supports build_editable: finished with status 'done' + Preparing editable metadata (pyproject.toml): started + Running command Preparing editable metadata (pyproject.toml) + fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' + To add an exception for this directory, call: + + git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree + fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' + To add an exception for this directory, call: + + git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree + fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' + To add an exception for this directory, call: + + git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree + /usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py:289: UserWarning: Unknown distribution option: 'test_suite' + warnings.warn(msg) + /usr/local/lib/python3.10/dist-packages/setuptools/dist.py:759: SetuptoolsDeprecationWarning: License classifiers are deprecated. + !! + + ******************************************************************************** + Please consider removing the following classifiers in favor of a SPDX license expression: + + License :: OSI Approved :: MIT License + + See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#license for details. + ******************************************************************************** + + !! + self._finalize_license_expression() + [INFO] FlagTree Offline Build: No offline build for triton origin toolkits + [INFO] FlagtreeBackend is sunrise + !!!================ + {'': 'python', 'triton': './third_party/sunrise/python/triton', 'triton.backends.sunrise': 'third_party/sunrise/backend'} + running dist_info + creating /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info + writing /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/PKG-INFO + writing dependency_links to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/dependency_links.txt + writing entry points to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/entry_points.txt + writing requirements to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/requires.txt + writing top-level names to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/top_level.txt + writing manifest file '/tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/SOURCES.txt' + error: package directory 'third_party/sunrise/python/triton/language/extra/cuda' does not exist + error: subprocess-exited-with-error + + × Preparing editable metadata (pyproject.toml) did not run successfully. + │ exit code: 1 + ╰─> No available output. + + note: This error originates from a subprocess, and is likely not a problem with pip. + full command: /usr/bin/python3 /usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py prepare_metadata_for_build_editable /tmp/tmpf93nemq1 + cwd: /root/WorkSpace/flagtree_close/flagtree + Preparing editable metadata (pyproject.toml): finished with status 'error' +error: metadata-generation-failed + +× Encountered error while generating package metadata. +╰─> from file:///root/WorkSpace/flagtree_close/flagtree + +note: This is an issue with the package mentioned above, not pip. +hint: See above for details. diff --git a/setup.py b/setup.py index 97b4947b0..955d41f06 100644 --- a/setup.py +++ b/setup.py @@ -620,6 +620,9 @@ def download_and_copy_dependencies(): def get_package_dirs(): yield ("", "python") + if helper.flagtree_backend: + yield ('triton', './third_party/sunrise/python/triton') + for backend in backends: # we use symlinks for external plugins diff --git a/setup_tools/setup_helper.py b/setup_tools/setup_helper.py index f79989c93..7ee3dc84b 100644 --- a/setup_tools/setup_helper.py +++ b/setup_tools/setup_helper.py @@ -260,7 +260,7 @@ def get_package_dir(packages): package_dict = {} if flagtree_backend and flagtree_backend not in plugin_backends: connection = [] - backend_triton_path = f"../third_party/{flagtree_backend}/python/" + backend_triton_path = f"./third_party/{flagtree_backend}/python/" for package in packages: if CommonUtils.skip_package_dir(package): continue @@ -280,7 +280,7 @@ def handle_flagtree_backend(): print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") extend_backends.append(flagtree_backend) if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends: - ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" + ext_sourcedir = os.path.abspath(f"./third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" def set_env(env_dict: dict): diff --git a/third_party/sunrise/backend/driver.py b/third_party/sunrise/backend/driver.py index 1008ed4a2..823214d33 100644 --- a/third_party/sunrise/backend/driver.py +++ b/third_party/sunrise/backend/driver.py @@ -4,8 +4,7 @@ import subprocess import re from pathlib import Path -#from triton import knobs -from ..python.triton import knobs +from triton import knobs from triton.runtime.build import compile_module_from_src from triton.runtime import _allocation from triton.backends.compiler import GPUTarget @@ -482,8 +481,6 @@ class SunriseDriver(GPUDriver): def __init__(self): self.utils = SunriseUtils() # TODO: make static self.launcher_cls = SunriseLauncher - from triton.backends.iluvatar import spec - self.spec = spec super().__init__() def get_current_target(self): @@ -520,4 +517,4 @@ def get_empty_cache_for_benchmark(self): return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') def clear_cache(self, cache): - cache.zero_() \ No newline at end of file + cache.zero_() diff --git a/third_party/sunrise/python/triton/_C b/third_party/sunrise/python/triton/_C new file mode 120000 index 000000000..e17821ba1 --- /dev/null +++ b/third_party/sunrise/python/triton/_C @@ -0,0 +1 @@ +../../../../python/triton/_C \ No newline at end of file diff --git a/third_party/sunrise/python/triton/backends b/third_party/sunrise/python/triton/backends new file mode 120000 index 000000000..13a83a85c --- /dev/null +++ b/third_party/sunrise/python/triton/backends @@ -0,0 +1 @@ +../../../../python/triton/backends \ No newline at end of file diff --git a/third_party/sunrise/python/triton/experimental/__init__.py b/third_party/sunrise/python/triton/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/python/triton/experimental/gluon/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/__init__.py new file mode 100644 index 000000000..21fa325d9 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/__init__.py @@ -0,0 +1,4 @@ +from . import nvidia +from ._runtime import jit + +__all__ = ["jit", "nvidia"] diff --git a/third_party/sunrise/python/triton/experimental/gluon/_compiler.py b/third_party/sunrise/python/triton/experimental/gluon/_compiler.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/sunrise/python/triton/experimental/gluon/_runtime.py b/third_party/sunrise/python/triton/experimental/gluon/_runtime.py new file mode 100644 index 000000000..42c7b72bc --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/_runtime.py @@ -0,0 +1,99 @@ +from __future__ import annotations +import triton +from triton.compiler.compiler import ASTSource +from triton.backends.compiler import Language +from triton.runtime.jit import JITFunction +from typing import TypeVar, Optional, Callable, Iterable, Union +from triton._C.libtriton import ir + +T = TypeVar("T") + + +class GluonASTSource(ASTSource): + + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + super().__init__(fn, signature, constexprs, attrs) + self.language = Language.GLUON + self.ext = "ttgir" + + def make_ir(self, options, codegen_fns, module_map, context): + from triton.compiler.compiler import make_backend + from triton.compiler.code_generator import ast_to_ttir + + builder = ir.builder(context) + module = builder.create_module() + + # Assign module attributes eagerly, as they are needed to verify layouts + target = triton.runtime.driver.active.get_current_target() + backend = make_backend(target) + target = backend.get_target_name(options) + module.set_attr("ttg.target", builder.get_string_attr(target)) + module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps)) + module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas)) + module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32)) + if options.maxnreg is not None: + module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg)) + + module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map, module=module) + return module + + +class GluonJITFunction(JITFunction[T]): + + def create_binder(self): + result = super().create_binder() + self.ASTSource = GluonASTSource + return result + + def is_gluon(self): + return True + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int | str]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + return GluonJITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/language/__init__.py new file mode 100644 index 000000000..109093c19 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/__init__.py @@ -0,0 +1,18 @@ +from ._core import * # NOQA: F403 +from ._core import __all__ as __core_all +from ._layouts import * # NOQA: F403 +from ._layouts import __all__ as __layouts_all +from ._math import * # NOQA: F403 +from ._math import __all__ as __math_all +from ._standard import * # NOQA: F403 +from ._standard import __all__ as __standard_all + +from . import nvidia + +__all__ = [ + *__core_all, + *__layouts_all, + *__math_all, + *__standard_all, + "nvidia", +] diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/_core.py b/third_party/sunrise/python/triton/experimental/gluon/language/_core.py new file mode 100644 index 000000000..3ec509eeb --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/_core.py @@ -0,0 +1,312 @@ +from __future__ import annotations +from typing import TypeVar, List, TYPE_CHECKING, Tuple +from functools import wraps + +if TYPE_CHECKING: + from triton._C.libtriton.gluon_ir import GluonOpBuilder + from ._semantic import GluonSemantic + +from ._layouts import SharedLayout, DistributedLayout +from triton._C.libtriton import ir +import triton.language.core as tl_core +from triton.language.core import ( + constexpr, + base_value, + base_type, + dtype, + block_type, # TODO: block type with layout info + pointer_type, + void, + int1, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float8e5, + float8e5b16, + float8e4nv, + float8e4b8, + float8e4b15, + float16, + bfloat16, + float32, + float64, + _unwrap_if_constexpr, + _unwrap_shape, + tensor, + tuple, + tuple_type, +) + +_IMPORT_FROM_TRITON: List[str] = [ + "expand_dims", + "join", + "load", + "maximum", + "minimum", + "permute", + "program_id", + "reduce", + "reshape", + "split", + "static_assert", + "static_print", + "store", + "to_tensor", + "where", + "inline_asm_elementwise", +] + +__all__ = [ + "constexpr", + "base_value", + "base_type", + "dtype", + "block_type", + "pointer_type", + "tuple_type", + "void", + "int1", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float8e5", + "float8e5b16", + "float8e4nv", + "float8e4b8", + "float8e4b8", + "float8e4b15", + "float16", + "bfloat16", + "float32", + "float64", + "_unwrap_if_constexpr", + "tensor", + "tuple", + "tuple_type", + "thread_barrier", + "arange", + "full", + "convert_layout", + "allocate_shared_memory", + "shared_memory_descriptor", + "warp_specialize", + *_IMPORT_FROM_TRITON, +] + +T = TypeVar("T") + +# TODO: split these +GLUON_BUILTIN = "__triton_builtin__" + + +class distributed_type(block_type): + + def __init__(self, element_ty: dtype, shape: List[int], layout): + super().__init__(element_ty, shape) + self.layout = layout + self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>" + assert isinstance(layout, DistributedLayout) + + def to_ir(self, builder: ir.builder) -> ir.type: + elem_ty = self.element_ty.to_ir(builder) + layout = self.layout._to_ir(builder) + return builder.get_distributed_ty(elem_ty, self.shape, layout) + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = "_".join(map(str, self.shape)) + layout = self.layout.mangle() + return f"{elt}S{shape}SL{layout}L" + + def with_element_ty(self, scalar_ty: dtype) -> block_type: + return distributed_type(scalar_ty, self.shape, self.layout) + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_semantic" not in kwargs or kwargs["_semantic"] is None: + raise ValueError("Did you forget to add @triton.gluon.jit ? " + "(`_semantic` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, GLUON_BUILTIN, True) + + return wrapper + + +class shared_memory_descriptor_type(base_type): + + def __init__(self, element_ty, shape, layout, alloc_shape): + self.element_ty = element_ty + self.shape = shape + self.layout = layout + self.alloc_shape = alloc_shape + assert isinstance(layout, SharedLayout) + + def to_ir(self, builder: GluonOpBuilder) -> None: + return builder.get_shared_mem_desc_ty( + self.element_ty.to_ir(builder), + self.shape, + self.layout._to_ir(builder), + self.alloc_shape, + ) + + def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[shared_memory_descriptor, int]: + value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def __str__(self) -> str: + return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>" + + def __eq__(self, other) -> bool: + return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout + and self.alloc_shape == other.alloc_shape) + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + shape_str = "_".join([str(s) for s in self.shape]) + return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD" + + +class shared_memory_descriptor(base_value): + + def __init__(self, handle, element_ty, shape, layout, alloc_shape): + self.handle = handle + self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def dtype(self): + return self.type.element_ty + + @property + def shape(self): + return self.type.shape + + @property + def rank(self): + return len(self.shape) + + @property + def layout(self): + return self.type.layout + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, layout, _semantic: GluonSemantic) -> tensor: + layout = _unwrap_if_constexpr(layout) + return _semantic.shared_load(self, layout) + + @builtin + def store(self, value, _semantic: GluonSemantic) -> None: + return _semantic.shared_store(self, value) + + @builtin + def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + start = _unwrap_if_constexpr(start) + length = _unwrap_if_constexpr(length) + dim = _unwrap_if_constexpr(dim) + return _semantic.memdesc_slice(self, start, length, dim) + + @builtin + def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + index = _unwrap_if_constexpr(index) + return _semantic.memdesc_index(self, index) + + @builtin + def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor: + order = [_unwrap_if_constexpr(o) for o in order] + return _semantic.memdesc_trans(self, order) + + @builtin + def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor: + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + + return _semantic.memdesc_reshape(self, shape, layout) + + @builtin + def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor: + dtype = _unwrap_if_constexpr(dtype) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + + return _semantic.memdesc_reinterpret(self, dtype, shape, layout) + + @builtin + def _keep_alive(self, _semantic: GluonSemantic = None) -> None: + return _semantic.shared_dealloc(self) + + +for name in _IMPORT_FROM_TRITON: + fn = getattr(tl_core, name) + globals()[name] = builtin(fn) + + +@builtin +def arange(start, end, layout, _semantic=None): + start = _unwrap_if_constexpr(start) + end = _unwrap_if_constexpr(end) + layout = _unwrap_if_constexpr(layout) + return _semantic.arange(start, end, layout) + + +@builtin +def convert_layout(value, layout, _semantic=None): + layout = _unwrap_if_constexpr(layout) + return _semantic.convert_layout(value, layout) + + +@builtin +def full(shape, value, dtype, layout, _semantic=None): + shape = _unwrap_shape(shape) + value = _unwrap_if_constexpr(value) + dtype = _unwrap_if_constexpr(dtype) + layout = _unwrap_if_constexpr(layout) + return _semantic.full(shape, value, dtype, layout) + + +@builtin +def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None): + element_ty = _unwrap_if_constexpr(element_ty) + shape = _unwrap_if_constexpr(shape) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + return _semantic.allocate_shared(element_ty, shape, layout, value) + + +@builtin +def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, # + _semantic=None, _generator=None): + worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps] + worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs] + return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, # + worker_num_regs, _generator) + + +@builtin +def thread_barrier(_semantic=None): + return _semantic.debug_barrier() diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/_layouts.py b/third_party/sunrise/python/triton/experimental/gluon/language/_layouts.py new file mode 100644 index 000000000..fd6ed2fd9 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/_layouts.py @@ -0,0 +1,230 @@ +from dataclasses import dataclass +from typing import List, Optional +from triton.language.core import _unwrap_if_constexpr, _unwrap_shape + +__all__ = [ + "BlockedLayout", + "SliceLayout", + "DistributedLinearLayout", + "NVMMASharedLayout", + "SwizzledSharedLayout", +] + + +def _realize_cta_layout(rank, ctas_per_cga, cta_split_num, cta_order): + ctas_per_cga = ctas_per_cga or [1] * rank + cta_split_num = cta_split_num or [1] * rank + cta_order = cta_order or list(reversed(range(rank))) + return ctas_per_cga, cta_split_num, cta_order + + +class DistributedLayout: + pass + + +@dataclass(frozen=True) +class BlockedLayout(DistributedLayout): + size_per_thread: List[int] + threads_per_warp: List[int] + warps_per_cta: List[int] + order: List[int] + ctas_per_cga: Optional[List[int]] = None + cta_split_num: Optional[List[int]] = None + cta_order: Optional[List[int]] = None + + def __post_init__(self): + super().__setattr__("size_per_thread", _unwrap_if_constexpr(self.size_per_thread)) + super().__setattr__("threads_per_warp", _unwrap_if_constexpr(self.threads_per_warp)) + super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) + super().__setattr__("order", _unwrap_if_constexpr(self.order)) + super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) + super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) + + rank = len(self.size_per_thread) + assert len(self.threads_per_warp) == rank + assert len(self.warps_per_cta) == rank + assert len(self.order) == rank + assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank + assert self.cta_split_num is None or len(self.cta_split_num) == rank + assert self.cta_order is None or len(self.cta_order) == rank + + def _to_ir(self, builder): + rank = len(self.size_per_thread) + ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num, + self.cta_order) + return builder.get_blocked_layout( + self.size_per_thread, + self.threads_per_warp, + self.warps_per_cta, + self.order, + ctas_per_cga, + cta_split_num, + cta_order, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + size_per_thread = stringify(self.size_per_thread) + threads_per_warp = stringify(self.threads_per_warp) + warps_per_cta = stringify(self.warps_per_cta) + order = stringify(self.order) + ctas_per_cga = stringify(self.ctas_per_cga) + cta_split_num = stringify(self.cta_split_num) + cta_order = stringify(self.cta_order) + return f"B{size_per_thread}B{threads_per_warp}B{warps_per_cta}B{order}B{ctas_per_cga}B{cta_split_num}B{cta_order}B" + + +@dataclass(frozen=True) +class SliceLayout(DistributedLayout): + dim: int + parent: DistributedLayout + + def __post_init__(self): + super().__setattr__("dim", _unwrap_if_constexpr(self.dim)) + super().__setattr__("parent", _unwrap_if_constexpr(self.parent)) + + def _to_ir(self, builder): + return builder.get_slice_layout( + self.dim, + self.parent._to_ir(builder), + ) + + def mangle(self) -> str: + return f"SL{self.dim}_{self.parent.mangle()}SL" + + +@dataclass(frozen=True) +class DistributedLinearLayout(DistributedLayout): + reg_bases: List[List[int]] + lane_bases: List[List[int]] + warp_bases: List[List[int]] + block_bases: List[List[int]] + shape: List[int] + + def __post_init__(self): + super().__setattr__("reg_bases", _unwrap_shape(self.reg_bases)) + super().__setattr__("lane_bases", _unwrap_shape(self.lane_bases)) + super().__setattr__("warp_bases", _unwrap_shape(self.warp_bases)) + super().__setattr__("block_bases", _unwrap_shape(self.block_bases)) + super().__setattr__("shape", _unwrap_shape(self.shape)) + + rank = len(self.shape) + + for basis in self.reg_bases: + assert len(basis) == rank + for basis in self.lane_bases: + assert len(basis) == rank + for basis in self.warp_bases: + assert len(basis) == rank + for basis in self.block_bases: + assert len(basis) == rank + + def _to_ir(self, builder): + return builder.get_distributed_linear_layout(self.reg_bases, self.lane_bases, self.warp_bases, self.block_bases, + self.shape) + + def mangle(self): + return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL" + + +class SharedLayout: + pass + + +@dataclass(frozen=True) +class NVMMASharedLayout(SharedLayout): + swizzle_byte_width: int + element_bitwidth: int + rank: int + transposed: bool = False + fp4_padded: bool = False + ctas_per_cga: Optional[List[int]] = None + cta_split_num: Optional[List[int]] = None + cta_order: Optional[List[int]] = None + + def __post_init__(self): + super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width)) + super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth)) + super().__setattr__("rank", _unwrap_if_constexpr(self.rank)) + super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) + super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded)) + super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) + super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) + + assert self.element_bitwidth in [8, 16, 32, 64] + assert self.swizzle_byte_width in [0, 32, 64, 128] + rank = self.rank + assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank + assert self.cta_split_num is None or len(self.cta_split_num) == rank + assert self.cta_order is None or len(self.cta_order) == rank + + def _to_ir(self, builder): + ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(self.rank, self.ctas_per_cga, self.cta_split_num, + self.cta_order) + return builder.get_nvmma_shared_layout( + self.swizzle_byte_width, + self.element_bitwidth, + self.transposed, + self.fp4_padded, + ctas_per_cga, + cta_split_num, + cta_order, + ) + + def mangle(self) -> str: + return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA" + + +@dataclass(frozen=True, eq=True) +class SwizzledSharedLayout(SharedLayout): + vec: int + per_phase: int + max_phase: int + order: List[int] + ctas_per_cga: Optional[List[int]] = None + cta_split_num: Optional[List[int]] = None + cta_order: Optional[List[int]] = None + + def __post_init__(self): + super().__setattr__("vec", _unwrap_if_constexpr(self.vec)) + super().__setattr__("per_phase", _unwrap_if_constexpr(self.per_phase)) + super().__setattr__("max_phase", _unwrap_if_constexpr(self.max_phase)) + super().__setattr__("order", _unwrap_if_constexpr(self.order)) + super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) + super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) + super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) + + rank = len(self.order) + assert self.ctas_per_cga is None or len(self.ctas_per_cga) == rank + assert self.cta_split_num is None or len(self.cta_split_num) == rank + assert self.cta_order is None or len(self.cta_order) == rank + + def _to_ir(self, builder): + rank = len(self.order) + ctas_per_cga, cta_split_num, cta_order = _realize_cta_layout(rank, self.ctas_per_cga, self.cta_split_num, + self.cta_order) + return builder.get_swizzled_shared_layout( + _unwrap_if_constexpr(self.vec), + _unwrap_if_constexpr(self.per_phase), + _unwrap_if_constexpr(self.max_phase), + self.order, + ctas_per_cga, + cta_split_num, + cta_order, + ) + + def mangle(self) -> str: + + def stringify(x): + if x is None: + return "" + return "_".join(map(str, x)) + + return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS" diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/_math.py b/third_party/sunrise/python/triton/experimental/gluon/language/_math.py new file mode 100644 index 000000000..55ba1bc59 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/_math.py @@ -0,0 +1,12 @@ +# flake8: noqa +import triton.language.math as tl_math +from ._core import builtin + +__all__ = [ + "umulhi", "exp", "exp2", "fma", "log", "log2", "cos", "rsqrt", "sin", "sqrt", "sqrt_rn", "abs", "fdiv", "div_rn", + "erf", "floor", "ceil" +] + +for name in __all__: + fn = getattr(tl_math, name) + globals()[name] = builtin(fn) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/_semantic.py b/third_party/sunrise/python/triton/experimental/gluon/language/_semantic.py new file mode 100644 index 000000000..b1ba5b80f --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/_semantic.py @@ -0,0 +1,287 @@ +from typing import Sequence, List, TypeVar, Tuple, Callable +from triton.language.semantic import TritonSemantic +from . import _core as ttgl +from ._layouts import SliceLayout +from triton._C.libtriton.gluon_ir import GluonOpBuilder +from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values + +TensorTy = TypeVar("TensorTy") + + +def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError): + if not cond: + raise category(msg_fn()) + + +class GluonSemantic(TritonSemantic[TensorTy]): + tensor = ttgl.tensor + lang = ttgl + + builder: GluonOpBuilder + + def __init__(self, builder: GluonOpBuilder): + self.builder = builder + + def _wrap_tensor_infer_layout(self, tensor): + ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape, + self.builder.get_gluon_layout_from_tensor(tensor.handle)) + return self.tensor(tensor.handle, ty) + + def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]): + if len(lhs_shape) != len(rhs_shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}") + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + return ret_shape + + def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: + dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if axis < 0: + axis += len(input.shape) + + _check(isinstance(input.type, ttgl.distributed_type), + lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}") + layout = input.type.layout + _check(isinstance(layout, SliceLayout), + lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}") + _check(layout.dim == axis, + lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}") + + ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent) + handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder)) + return self.tensor(handle, ret_ty) + + def join(self, a: TensorTy, b: TensorTy) -> TensorTy: + a, b = self.broadcast_impl_value(a, b) + _check(a.shape != [], "Cannot join scalars in gluon") + value = super().join(a, b) + return self._wrap_tensor_infer_layout(value) + + def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: + lhs, rhs = super().split(a) + return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs) + + def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy: + value = super().permute(input, dims) + return self._wrap_tensor_infer_layout(value) + + def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy: + _check(isinstance(input.type, ttgl.distributed_type), + lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}") + src_shape = input.type.get_block_shapes() + _check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout) + handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder)) + return self.tensor(handle, ret_ty) + + def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: + lhs_ty = lhs.type + rhs_ty = rhs.type + + if not lhs_ty.is_block() or not rhs_ty.is_block(): + return super().broadcast_impl_value(lhs, rhs) + + _check(isinstance(lhs_ty, ttgl.distributed_type), + lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}") + _check(isinstance(rhs_ty, ttgl.distributed_type), + lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}") + + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape) + if lhs_ty.layout != rhs_ty.layout: + raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}") + + lhs = self.broadcast_impl_shape(lhs, ret_shape) + rhs = self.broadcast_impl_shape(rhs, ret_shape) + return lhs, rhs + + def arange(self, start, end, layout): + shape = [end - start] + ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout) + return super().arange(start, end, ret_ty=ret_ty) + + def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool): + _check(not can_reorder, "can_reorder is not supported in gluon") + value = super().reshape(input, dst_shape, can_reorder) + return self._wrap_tensor_infer_layout(value) + + def splat(self, value, shape, layout): + ret_ty = ttgl.distributed_type(value.dtype, shape, layout) + handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle) + return ttgl.tensor(handle, ret_ty) + + def full(self, shape, value, dtype, layout): + scalar = self.make_scalar(value, dtype) + return self.splat(scalar, shape, layout) + + def convert_layout(self, value, layout): + ty = value.type + _check(isinstance(ty, ttgl.distributed_type), + lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}") + ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout) + handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle) + return ttgl.tensor(handle, ret_ty) + + def allocate_shared(self, element_ty, shape, layout, value): + ty = ttgl.shared_memory_descriptor_type(element_ty, shape, layout, shape) + if value is not None: + handle = self.builder.create_local_alloc(ty.to_ir(self.builder), value.handle) + else: + handle = self.builder.create_local_alloc(ty.to_ir(self.builder)) + return ttgl.shared_memory_descriptor(handle, element_ty, shape, layout, shape) + + def shared_load(self, mem_desc, layout): + ret_ty = ttgl.distributed_type(mem_desc.dtype, mem_desc.shape, layout) + handle = self.builder.create_local_load(ret_ty.to_ir(self.builder), mem_desc.handle) + return ttgl.tensor(handle, ret_ty) + + def shared_store(self, mem_desc, value): + self.builder.create_local_store(mem_desc.handle, value.handle) + + def shared_dealloc(self, mem_desc): + self.builder.create_local_dealloc(mem_desc.handle) + + def _memdesc_subview(self, mem_desc, offsets, shape): + layout = mem_desc.layout + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) + builder = self.builder + handle = builder.create_memdesc_subview(ty.to_ir(builder), mem_desc.handle, offsets) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def memdesc_slice(self, mem_desc, start, length, dim): + offsets = [self.builder.get_int32(0)] * mem_desc.rank + offsets[dim] = self.to_tensor(start).handle + shape = list(mem_desc.shape) + shape[dim] = length + return self._memdesc_subview(mem_desc, offsets, shape) + + def memdesc_index(self, mem_desc, index): + shape = mem_desc.shape[1:] + offsets = [self.builder.get_int32(0)] * mem_desc.rank + offsets[0] = self.to_tensor(index).handle + return self._memdesc_subview(mem_desc, offsets, shape) + + def memdesc_trans(self, mem_desc, order): + assert len(order) == len( + mem_desc.shape), f"source rank ({mem_desc.rank}) and order length ({len(order)}) must match" + + shape = [mem_desc.shape[i] for i in order] + alloc_shape = mem_desc.type.alloc_shape + new_alloc_shape = alloc_shape[:len(alloc_shape) - mem_desc.rank] + new_alloc_shape += [alloc_shape[len(alloc_shape) - mem_desc.rank:][i] for i in order] + + handle = self.builder.create_memdesc_trans(mem_desc.handle, order) + layout = self.builder.get_gluon_layout_from_memdesc(handle) + return ttgl.shared_memory_descriptor(handle, element_ty=mem_desc.dtype, shape=shape, + alloc_shape=new_alloc_shape, layout=layout) + + def memdesc_reshape(self, mem_desc, shape, layout): + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) + handle = self.builder.create_memdesc_reshape(ty.to_ir(self.builder), mem_desc.handle) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def memdesc_reinterpret(self, mem_desc, dtype, shape, layout): + ty = ttgl.shared_memory_descriptor_type(dtype, shape, layout, shape) + handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle) + return ttgl.shared_memory_descriptor(handle, **ty.__dict__) + + def wrap_tensor(self, x, scalar_ty, ret_shape, layout): + if ret_shape: + res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout) + else: + res_ty = scalar_ty + return self.tensor(x, res_ty) + + @staticmethod + def _check_same_layout(xs): + for x in xs: + _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}") + layouts = [x.type.layout for x in xs] + l0 = layouts[0] + _check(all(l == l0 for l in layouts[1:]), + lambda: f"Expected inputs to have matching layouts, but got: {layouts}") + + def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]: + _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon") + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}") + self._check_same_layout(inputs) + ret_shape = [s for i, s in enumerate(shape) if i != axis] + ret_layout = SliceLayout(axis, inputs[0].type.layout) + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + assert reduce_op.verify() + + return tuple( + self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout) + for i in range(len(inputs))) + + def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int], + worker_num_regs: Sequence[int], generator): + num_partitions = len(worker_partitions) + assert num_partitions == len( + worker_num_warps + ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts" + assert num_partitions == len( + worker_num_regs + ), f"warp specialize got {num_partitions} partitions but {len(worker_num_regs)} register counts" + + builder = self.builder + insert_pt = builder.get_insertion_point() + + # Emit the default partition to get the result types. + default_block = builder.new_block() + builder.set_insertion_point_to_start(default_block) + default_results = generator.call_JitFunction(default_partition, args, kwargs={}) + mlir_results = [] + if default_results is not None: + mlir_results = flatten_values_to_ir(default_results) + builder.create_warp_yield(mlir_results) + result_types = [r.get_type() for r in mlir_results] + + # Create the warp specialize op. + builder.restore_insertion_point(insert_pt) + mlir_args = flatten_values_to_ir(args) + ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps) + ws_op.get_default_region().push_back(default_block) + ws_op.set_requested_registers(worker_num_regs) + + # Emit the partition regions. + builder.create_block_with_parent(ws_op.get_partition_op_holder(), []) + partitions_op = builder.create_warp_specialize_partitions(num_partitions) + arg_types = [arg.get_type() for arg in mlir_args] + for i in range(num_partitions): + block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types) + block_args = [block.get_argument(j) for j in range(len(mlir_args))] + block_args = unflatten_ir_values(block_args, [arg.type for arg in args]) + generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}) + builder.create_warp_return() + + builder.set_insertion_point_after(ws_op.get_operation()) + mlir_results = [ws_op.get_result(i) for i in range(len(result_types))] + if default_results is None: + return + return tuple(unflatten_ir_values(mlir_results, [r.type for r in default_results])) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/_standard.py b/third_party/sunrise/python/triton/experimental/gluon/language/_standard.py new file mode 100644 index 000000000..3da7834d4 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/_standard.py @@ -0,0 +1,47 @@ +# flake8: noqa +import triton +import triton.language.standard as tl_standard +from .._runtime import jit +from triton import knobs +from . import _core as ttgl + +_IMPORT_FROM_TRITON = [ + "sum", + "max", + "min", + "reduce_or", + "xor_sum", +] + +__all__ = [ + "full_like", + "zeros", + "zeros_like", + *_IMPORT_FROM_TRITON, +] + +for name in _IMPORT_FROM_TRITON: + # Convert JITFunction -> GluonJitFunction + fn = getattr(tl_standard, name) + assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction) + globals()[name] = jit(fn.fn) + + +@jit +def zeros(shape, dtype, layout): + return ttgl.full(shape, 0, dtype, layout) + + +@jit +def full_like(input, value, shape=None, dtype=None, layout=None): + return ttgl.full( + input.shape if shape is None else shape, + value, + input.dtype if dtype is None else dtype, + input.type.layout if layout is None else layout, + ) + + +@jit +def zeros_like(input, shape=None, dtype=None, layout=None): + return full_like(input, 0, shape=shape, dtype=dtype, layout=layout) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/__init__.py new file mode 100644 index 000000000..3ecf36d3b --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/__init__.py @@ -0,0 +1,4 @@ +from . import blackwell +from . import hopper + +__all__ = ["blackwell", "hopper"] diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py new file mode 100644 index 000000000..243c87c5b --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -0,0 +1,202 @@ +from __future__ import annotations +from typing import Optional, Tuple, List, TYPE_CHECKING + +from dataclasses import dataclass +from triton.experimental.gluon.language import _core as ttgl +from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr +from triton.experimental.gluon.language._semantic import _check + +from . import tma +from ..hopper import mbarrier, fence_async_shared + +if TYPE_CHECKING: + from triton._C.libtriton.gluon_ir import GluonOpBuilder + from triton._C.libtriton import gluon_ir as ir + from ..._semantic import GluonSemantic + +__all__ = [ + "allocate_tensor_memory", + "fence_async_shared", + "mbarrier", + "tensor_memory_descriptor", + "TensorMemoryLayout", + "tma", +] + + +@dataclass(frozen=True, eq=True) +class TensorMemoryLayout: + block: Tuple[int, int] + unpacked: bool + cta_split_num: Optional[Tuple[int, int]] = None + + def __post_init__(self): + assert len(self.block) == 2 + assert self.cta_split_num is None or len(self.cta_split_num) == 2 + + def _to_ir(self, builder): + cta_split_num = self.cta_split_num or [1, 1] + return builder.get_tensor_memory_layout( + self.block, + self.unpacked, + cta_split_num, + ) + + def mangle(self) -> str: + block_str = f"{self.block[0]}x{self.block[1]}" + unpacked_str = "U" if self.unpacked else "P" + cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "" + return f"TL{block_str}{unpacked_str}{cta_split_str}TL" + + +class tensor_memory_descriptor_type(base_type): + + def __init__(self, element_ty, shape, layout, alloc_shape): + self.element_ty = element_ty + self.shape = shape + self.layout = layout + self.alloc_shape = alloc_shape + assert isinstance(layout, TensorMemoryLayout) + + def to_ir(self, builder: GluonOpBuilder) -> None: + return builder.get_tensor_mem_desc_ty( + self.element_ty.to_ir(builder), + self.shape, + self.layout._to_ir(builder), + self.alloc_shape, + ) + + def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[tensor_memory_descriptor, int]: + value = tensor_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def __str__(self) -> str: + return f"tensor_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}>" + + def __eq__(self, other) -> bool: + return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout + and self.alloc_shape == other.alloc_shape) + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + shape_str = "_".join([str(s) for s in self.shape]) + return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD" + + +class tensor_memory_descriptor(base_value): + + def __init__(self, handle, element_ty, shape, layout, alloc_shape): + self.handle = handle + self.type = tensor_memory_descriptor_type(element_ty, shape, layout, alloc_shape) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + + @property + def dtype(self): + return self.type.element_ty + + @property + def shape(self): + return self.type.shape + + @property + def rank(self): + return len(self.shape) + + @property + def layout(self): + return self.type.layout + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, layout, _semantic: GluonSemantic) -> ttgl.tensor: + layout = _unwrap_if_constexpr(layout) + ret_ty = ttgl.distributed_type(self.dtype, self.shape, layout) + builder = _semantic.builder + handle = builder.create_tmem_load(ret_ty.to_ir(builder), self.handle) + return ttgl.tensor(handle, ret_ty) + + @builtin + def store(self, value, pred=True, _semantic: GluonSemantic = None) -> None: + pred = _unwrap_if_constexpr(pred) + pred = _semantic.to_tensor(pred) + _semantic.builder.create_tmem_store(self.handle, value.handle, pred.handle) + + @builtin + def slice(self, start, length, _semantic: GluonSemantic) -> None: + start = _unwrap_if_constexpr(start) + length = _unwrap_if_constexpr(length) + _check(isinstance(start, int), lambda: "start must be a constant int") + _check(isinstance(length, int), lambda: "length must be a constant int") + shape = self.shape[:-1] + [length] + layout = self.type.layout + layout = TensorMemoryLayout((layout.block[0], min(layout.block[1], length)), layout.unpacked, + layout.cta_split_num) + ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) + builder = _semantic.builder + ret.handle = builder.create_tmem_subslice(ret.type.to_ir(builder), self.handle, start) + return ret + + @builtin + def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descriptor: + index = _semantic.to_tensor(index) + builder = _semantic.builder + offsets = [builder.get_int32(0)] * self.rank + offsets[0] = index.handle + shape = self.shape[1:] + layout = self.layout + ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) + ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets) + return ret + + @builtin + def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor: + dtype = _unwrap_if_constexpr(dtype) + shape = [_unwrap_if_constexpr(s) for s in shape] + layout = _unwrap_if_constexpr(layout) + + ty = tensor_memory_descriptor_type(dtype, shape, layout, shape) + handle = _semantic.builder.create_memdesc_reinterpret(ty.to_ir(_semantic.builder), self.handle) + return tensor_memory_descriptor(handle, **ty.__dict__) + + +@builtin +def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None): + element_ty = _unwrap_if_constexpr(element_ty) + shape = _unwrap_if_constexpr(shape) + layout = _unwrap_if_constexpr(layout) + value = value.handle if value is not None else None + + ty = tensor_memory_descriptor_type(element_ty, shape, layout, shape) + builder = _semantic.builder + handle = builder.create_tmem_alloc(ty.to_ir(builder), value) + return tensor_memory_descriptor(handle, element_ty, shape, layout, shape) + + +@builtin +def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_preds=None, _semantic=None): + use_acc = _semantic.to_tensor(use_acc) + pred = _semantic.to_tensor(pred) + + if mbarriers is None: + assert mbarrier_preds is None + mbarriers = [] + mbarrier_preds = [] + else: + mbarriers = [bar.handle for bar in mbarriers] + if mbarrier_preds is None: + true = _semantic.to_tensor(True) + mbarrier_preds = [true] * len(mbarriers) + else: + mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False) + + _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers, + mbarrier_preds) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py new file mode 100644 index 000000000..c36339b25 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -0,0 +1,32 @@ +from triton.experimental.gluon.language._core import builtin +from triton.experimental.gluon.language.nvidia.hopper.tma import ( + async_copy_global_to_shared, + async_copy_shared_to_global, + store_wait, + tensor_descriptor, + tensor_descriptor_type, +) + +__all__ = [ + "async_gather", + "async_scatter", + "async_copy_global_to_shared", + "async_copy_shared_to_global", + "store_wait", + "tensor_descriptor", + "tensor_descriptor_type", +] + + +@builtin +def async_gather(tensor_desc, x_offsets, y_offset, barrier, result, pred=True, _semantic=None): + pred = _semantic.to_tensor(pred) + y_offset = _semantic.to_tensor(y_offset) + _semantic.builder.create_async_tma_gather(tensor_desc.handle, x_offsets.handle, y_offset.handle, barrier.handle, + result.handle, pred.handle) + + +@builtin +def async_scatter(tensor_desc, x_offsets, y_offset, src, _semantic=None): + y_offset = _semantic.to_tensor(y_offset) + _semantic.builder.create_async_tma_scatter(tensor_desc.handle, x_offsets.handle, y_offset.handle, src.handle) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py new file mode 100644 index 000000000..3d61d8130 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/__init__.py @@ -0,0 +1,11 @@ +from . import mbarrier +from . import tma +from ... import _core + +__all__ = ["fence_async_shared", "mbarrier", "tma"] + + +@_core.builtin +def fence_async_shared(cluster=False, _semantic=None): + cluster = _core._unwrap_if_constexpr(cluster) + _semantic.builder.create_fence_async_shared(cluster) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py new file mode 100644 index 000000000..ab15ac66e --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/mbarrier.py @@ -0,0 +1,51 @@ +from triton.experimental.gluon.language._layouts import SwizzledSharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +__all__ = ["MBarrierLayout", "init", "invalidate", "expect", "wait", "arrive"] + + +class MBarrierLayout(SwizzledSharedLayout): + + def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1): + super().__init__( + vec=1, + per_phase=1, + max_phase=1, + order=[0], + ctas_per_cga=[ctas_per_cga], + cta_split_num=[cta_split_num], + cta_order=[0], + ) + + +@builtin +def init(mbarrier, count, _semantic=None): + count = _unwrap_if_constexpr(count) + _semantic.builder.create_mbarrier_init(mbarrier.handle, count) + + +@builtin +def invalidate(mbarrier, _semantic=None): + _semantic.builder.create_mbarrier_inval(mbarrier.handle) + + +@builtin +def expect(mbarrier, bytes, pred=True, _semantic=None): + bytes = _unwrap_if_constexpr(bytes) + pred = _semantic.to_tensor(pred) + _semantic.builder.create_mbarrier_expect(mbarrier.handle, bytes, pred.handle) + + +@builtin +def wait(mbarrier, phase, pred=True, deps=(), _semantic=None): + phase = _semantic.to_tensor(phase) + pred = _semantic.to_tensor(pred) + deps = [x.handle for x in deps] + _semantic.builder.create_mbarrier_wait(mbarrier.handle, phase.handle, pred.handle, deps) + + +@builtin +def arrive(mbarrier, count, pred=True, _semantic=None): + count = _unwrap_if_constexpr(count) + pred = _semantic.to_tensor(pred) + _semantic.builder.create_mbarrier_arrive(mbarrier.handle, count, pred.handle) diff --git a/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/tma.py new file mode 100644 index 000000000..2914ee0de --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -0,0 +1,96 @@ +from __future__ import annotations +from typing import List, Tuple, TYPE_CHECKING +from dataclasses import dataclass +import triton.experimental.gluon.language._core as ttgl +from triton.experimental.gluon.language._layouts import NVMMASharedLayout +from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr + +if TYPE_CHECKING: + from triton._C import ir + +__all__ = ["async_copy_global_to_shared", "async_copy_shared_to_global", "store_wait"] + + +@dataclass(eq=True) +class tensor_descriptor_type: + block_type: ttgl.block_type + shape_type: ttgl.tuple_type + strides_type: ttgl.tuple_type + layout: NVMMASharedLayout + + def __str__(self) -> str: + return f"tensor_descriptor<{self.block_type}, {self.layout}>" + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + ty = builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + out.append(ty) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD" + + +class tensor_descriptor: + + def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type, + layout: NVMMASharedLayout): + self.handle = handle + self.shape = ttgl.tuple(shape) + self.strides = ttgl.tuple(strides) + self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type, + layout=layout) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) + + @property + def block_type(self): + return self.type.block_type + + @property + def block_shape(self): + return self.type.block_type.shape + + @property + def dtype(self): + return self.type.block_type.element_ty + + @property + def layout(self): + return self.type.layout + + +@builtin +def async_copy_global_to_shared(tensor_desc, coord, barrier, result, pred=True, _semantic=None): + coord = _semantic._convert_to_ir_values(coord, require_i64=False) + pred = _semantic.to_tensor(pred) + _semantic.builder.create_async_tma_copy_global_to_local(tensor_desc.handle, coord, barrier.handle, result.handle, + pred.handle) + + +@builtin +def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): + coord = _semantic._convert_to_ir_values(coord, require_i64=False) + _semantic.builder.create_async_tma_copy_local_to_global(tensor_desc.handle, coord, src.handle) + + +@builtin +def store_wait(pendings, _semantic=None): + pendings = _unwrap_if_constexpr(pendings) + _semantic.builder.create_async_tma_store_wait(pendings) diff --git a/third_party/sunrise/python/triton/experimental/gluon/nvidia/__init__.py b/third_party/sunrise/python/triton/experimental/gluon/nvidia/__init__.py new file mode 100644 index 000000000..8184c7388 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/nvidia/__init__.py @@ -0,0 +1,4 @@ +from . import hopper +from . import blackwell + +__all__ = ["hopper", "blackwell"] diff --git a/third_party/sunrise/python/triton/experimental/gluon/nvidia/blackwell.py b/third_party/sunrise/python/triton/experimental/gluon/nvidia/blackwell.py new file mode 100644 index 000000000..abf919805 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/nvidia/blackwell.py @@ -0,0 +1,3 @@ +from .hopper import TensorDescriptor + +__all__ = ["TensorDescriptor"] diff --git a/third_party/sunrise/python/triton/experimental/gluon/nvidia/hopper.py b/third_party/sunrise/python/triton/experimental/gluon/nvidia/hopper.py new file mode 100644 index 000000000..8a8354933 --- /dev/null +++ b/third_party/sunrise/python/triton/experimental/gluon/nvidia/hopper.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import List, Any +from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth +from triton.experimental.gluon.language._layouts import NVMMASharedLayout + +__all__ = ["TensorDescriptor"] + + +@dataclass +class TensorDescriptor: + base: Any + shape: List[int] + strides: List[int] + block_shape: List[int] + layout: NVMMASharedLayout + + def __post_init__(self): + rank = len(self.shape) + assert len(self.strides) == rank, f"rank mismatch: {self}" + assert len(self.block_shape) == rank, f"rank mismatch: {self}" + assert rank > 0, "rank must not be zero" + assert rank <= 5, "rank cannot be more than 5" + assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned" + validate_block_shape(self.block_shape) + dtype_str = canonicalize_dtype(self.base.dtype) + elem_bytes = get_primitive_bitwidth(dtype_str) // 8 + for stride in self.strides[:-1]: + assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned" + assert self.strides[-1] == 1, "Last dimension must be contiguous" + assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout" + + @staticmethod + def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout): + return TensorDescriptor( + tensor, + tensor.shape, + tensor.stride(), + block_shape, + layout, + ) From c0afeedfea603c1200902a437e0a18d35d06170f Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Mon, 19 Jan 2026 17:58:19 +0800 Subject: [PATCH 3/7] fix bug, add env script --- .gitignore | 1 + log.log | 64 -------- python/triton/backends/driver.py | 44 ++++- third_party/sunrise/backend/driver.py | 2 +- third_party/sunrise/language/tang/__init__.py | 3 + .../sunrise/language/tang/libdevice.py | 154 ++++++++++++++++++ .../{01-vector-add.py => vector-add.py} | 21 ++- third_party/sunrise/python/triton/backends | 1 - .../sunrise/python/triton/language/extra/cuda | 1 - .../sunrise/python/triton/language/extra/hip | 1 - .../sunrise/python/triton/language/extra/tang | 1 - .../sunrise/python/triton/runtime/jit.py | 2 +- .../script/docker_clean_sunrise_flagtree.sh | 18 ++ .../sunrise/script/docker_sunrise_env.sh | 53 ++++++ 14 files changed, 281 insertions(+), 85 deletions(-) delete mode 100644 log.log create mode 100644 third_party/sunrise/language/tang/__init__.py create mode 100644 third_party/sunrise/language/tang/libdevice.py rename third_party/sunrise/python/test_examples/{01-vector-add.py => vector-add.py} (93%) delete mode 120000 third_party/sunrise/python/triton/backends delete mode 120000 third_party/sunrise/python/triton/language/extra/cuda delete mode 120000 third_party/sunrise/python/triton/language/extra/hip delete mode 120000 third_party/sunrise/python/triton/language/extra/tang create mode 100644 third_party/sunrise/script/docker_clean_sunrise_flagtree.sh create mode 100644 third_party/sunrise/script/docker_sunrise_env.sh diff --git a/.gitignore b/.gitignore index bb7ed373f..29d98a797 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,7 @@ ptxas third_party/nvidia/backend/include third_party/nvidia/backend/lib/cupti third_party/sunrise/backend/lib +third_party/sunrise/python/triton/backends/* # Docs docs/_build/ diff --git a/log.log b/log.log deleted file mode 100644 index 5eeb35bc3..000000000 --- a/log.log +++ /dev/null @@ -1,64 +0,0 @@ -Using pip 25.3 from /usr/local/lib/python3.10/dist-packages/pip (python 3.10) -Obtaining file:///root/WorkSpace/flagtree_close/flagtree - Checking if build backend supports build_editable: started - Running command Checking if build backend supports build_editable - Checking if build backend supports build_editable: finished with status 'done' - Preparing editable metadata (pyproject.toml): started - Running command Preparing editable metadata (pyproject.toml) - fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' - To add an exception for this directory, call: - - git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree - fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' - To add an exception for this directory, call: - - git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree - fatal: detected dubious ownership in repository at '/root/WorkSpace/flagtree_close/flagtree' - To add an exception for this directory, call: - - git config --global --add safe.directory /root/WorkSpace/flagtree_close/flagtree - /usr/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py:289: UserWarning: Unknown distribution option: 'test_suite' - warnings.warn(msg) - /usr/local/lib/python3.10/dist-packages/setuptools/dist.py:759: SetuptoolsDeprecationWarning: License classifiers are deprecated. - !! - - ******************************************************************************** - Please consider removing the following classifiers in favor of a SPDX license expression: - - License :: OSI Approved :: MIT License - - See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#license for details. - ******************************************************************************** - - !! - self._finalize_license_expression() - [INFO] FlagTree Offline Build: No offline build for triton origin toolkits - [INFO] FlagtreeBackend is sunrise - !!!================ - {'': 'python', 'triton': './third_party/sunrise/python/triton', 'triton.backends.sunrise': 'third_party/sunrise/backend'} - running dist_info - creating /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info - writing /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/PKG-INFO - writing dependency_links to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/dependency_links.txt - writing entry points to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/entry_points.txt - writing requirements to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/requires.txt - writing top-level names to /tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/top_level.txt - writing manifest file '/tmp/pip-modern-metadata-7yt5bbbz/triton.egg-info/SOURCES.txt' - error: package directory 'third_party/sunrise/python/triton/language/extra/cuda' does not exist - error: subprocess-exited-with-error - - × Preparing editable metadata (pyproject.toml) did not run successfully. - │ exit code: 1 - ╰─> No available output. - - note: This error originates from a subprocess, and is likely not a problem with pip. - full command: /usr/bin/python3 /usr/local/lib/python3.10/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py prepare_metadata_for_build_editable /tmp/tmpf93nemq1 - cwd: /root/WorkSpace/flagtree_close/flagtree - Preparing editable metadata (pyproject.toml): finished with status 'error' -error: metadata-generation-failed - -× Encountered error while generating package metadata. -╰─> from file:///root/WorkSpace/flagtree_close/flagtree - -note: This is an issue with the package mentioned above, not pip. -hint: See above for details. diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 6606b21ca..b18fe9cc7 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -39,14 +39,46 @@ class GPUDriver(DriverBase): def __init__(self): # TODO: support other frameworks than torch import torch - self.get_device_capability = torch.cuda.get_device_capability + try: + import torch_ptpu + _is_ptpu = True + except ImportError as e: + _is_ptpu = False + if _is_ptpu: + self.get_device_capability = torch.ptpu.get_device_capability + self.get_current_stream = lambda dev_idx: torch.ptpu.current_stream(dev_idx).ptpu_stream + self.get_current_device = torch.ptpu.current_device + self.set_current_device = torch.ptpu.set_device + return + try: from torch._C import _cuda_getCurrentRawStream - self.get_current_stream = _cuda_getCurrentRawStream - except ImportError: - self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device + _is_cuda = True + except ImportError as e: + _cuda_getCurrentRawStream = None + _is_cuda = True if torch.version.cuda else False + if _is_cuda: + self.get_device_capability = torch.cuda.get_device_capability + if _cuda_getCurrentRawStream is not None: + self.get_current_stream = _cuda_getCurrentRawStream + else: + self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + return + + try: + import torch_dipu + _is_dipu = True + except ImportError as e: + _is_dipu = False + if _is_dipu: + self.get_device_capability = torch.cuda.get_device_capability + self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).dipu_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + return + # TODO: remove once TMA is cleaned up def assemble_tensormap_to_arg(self, tensormaps_info, args): diff --git a/third_party/sunrise/backend/driver.py b/third_party/sunrise/backend/driver.py index 823214d33..6e7871864 100644 --- a/third_party/sunrise/backend/driver.py +++ b/third_party/sunrise/backend/driver.py @@ -487,7 +487,7 @@ def get_current_target(self): capability = "S2" warp_size = 32 return GPUTarget("tang", capability, warp_size) - + def get_active_torch_device(self): import torch return torch.device("cuda", self.get_current_device()) diff --git a/third_party/sunrise/language/tang/__init__.py b/third_party/sunrise/language/tang/__init__.py new file mode 100644 index 000000000..2ce93570d --- /dev/null +++ b/third_party/sunrise/language/tang/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] \ No newline at end of file diff --git a/third_party/sunrise/language/tang/libdevice.py b/third_party/sunrise/language/tang/libdevice.py new file mode 100644 index 000000000..65715d6e1 --- /dev/null +++ b/third_party/sunrise/language/tang/libdevice.py @@ -0,0 +1,154 @@ +from triton.language import core + +@core.extern +def erf(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def pow(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp64")), + (core.dtype("fp16"), core.dtype("fp16")): ("__ocml_pow_f16", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def tanh(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def atan2(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def asin(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def div_rd(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("llvm.stvm.div.rm.f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("llvm.stvm.div.rm.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def div_rz(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("llvm.stvm.div.rz.f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("llvm.stvm.div.rz.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def rsqrt(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.stvm.rsqrt.f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("llvm.stvm.rsqrt.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def isinf(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.stvm.testp.f32.inf", core.dtype("int32")), + (core.dtype("fp64"), ): ("llvm.stvm.testp.f32.inf", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def isnan(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("llvm.stvm.testp.f32.not", core.dtype("int32")), + (core.dtype("fp64"), ): ("llvm.stvm.testp.f32.not", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def erf(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + + +@core.extern +def exp(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def exp2(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.stvm.exp2.f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("llvm.stvm.exp2.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def div_rn(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("llvm.stvm.div.rn.f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("llvm.stvm.div.rn.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def trunc(arg0, _semantic=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.stvm.trunc.f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("llvm.stvm.trunc.f", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def fmod(arg0, arg1, _semantic=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) + +@core.extern +def isfinited(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__ocml_isfinite_f64", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic) + +@core.extern +def finitef(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isfinite_f32", core.dtype("int32")), + }, is_pure=True, _semantic=_semantic).to(core.int1, _semantic=_semantic) + +@core.extern +def rint(arg0, _semantic=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("llvm.rint.f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("llvm.rint.f32", core.dtype("fp64")), + }, is_pure=True, _semantic=_semantic) \ No newline at end of file diff --git a/third_party/sunrise/python/test_examples/01-vector-add.py b/third_party/sunrise/python/test_examples/vector-add.py similarity index 93% rename from third_party/sunrise/python/test_examples/01-vector-add.py rename to third_party/sunrise/python/test_examples/vector-add.py index 3619d7ec5..0cc507aa9 100644 --- a/third_party/sunrise/python/test_examples/01-vector-add.py +++ b/third_party/sunrise/python/test_examples/vector-add.py @@ -23,8 +23,6 @@ import triton import triton.language as tl -DEVICE = triton.runtime.driver.active.get_active_torch_device() - @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -62,7 +60,6 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def add(x: torch.Tensor, y: torch.Tensor): # We need to preallocate the output. output = torch.empty_like(x) - assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -81,16 +78,22 @@ def add(x: torch.Tensor, y: torch.Tensor): # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: -torch.manual_seed(0) size = 98432 -x = torch.rand(size, device=DEVICE) -y = torch.rand(size, device=DEVICE) +x = torch.rand(size) +y = torch.rand(size) output_torch = x + y -output_triton = add(x, y) +x_dev = x.to(device='ptpu') +y_dev = y.to(device='ptpu') +output_triton_dev = add(x_dev, y_dev) +output_triton_cpu = output_triton_dev.to(device='cpu') +print('output_torch:') print(output_torch) -print(output_triton) +print('output_triton:') +print(output_triton_cpu) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton_cpu))}') + +exit(0) # %% # Seems like we're good to go! diff --git a/third_party/sunrise/python/triton/backends b/third_party/sunrise/python/triton/backends deleted file mode 120000 index 13a83a85c..000000000 --- a/third_party/sunrise/python/triton/backends +++ /dev/null @@ -1 +0,0 @@ -../../../../python/triton/backends \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/extra/cuda b/third_party/sunrise/python/triton/language/extra/cuda deleted file mode 120000 index fc5f8a28a..000000000 --- a/third_party/sunrise/python/triton/language/extra/cuda +++ /dev/null @@ -1 +0,0 @@ -/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/nvidia/language/cuda \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/extra/hip b/third_party/sunrise/python/triton/language/extra/hip deleted file mode 120000 index dbeb20d81..000000000 --- a/third_party/sunrise/python/triton/language/extra/hip +++ /dev/null @@ -1 +0,0 @@ -/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/amd/language/hip \ No newline at end of file diff --git a/third_party/sunrise/python/triton/language/extra/tang b/third_party/sunrise/python/triton/language/extra/tang deleted file mode 120000 index 16c8cfeaa..000000000 --- a/third_party/sunrise/python/triton/language/extra/tang +++ /dev/null @@ -1 +0,0 @@ -/mnt/data/lisirui/TritonDev/Repos/tb/tt34/triton/third_party/sunrise/language/tang \ No newline at end of file diff --git a/third_party/sunrise/python/triton/runtime/jit.py b/third_party/sunrise/python/triton/runtime/jit.py index 70d8341ac..e951f9aea 100644 --- a/third_party/sunrise/python/triton/runtime/jit.py +++ b/third_party/sunrise/python/triton/runtime/jit.py @@ -549,7 +549,7 @@ def create_binder(self): def run(self, *args, grid, warmup, **kwargs): kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug - + # parse options device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) diff --git a/third_party/sunrise/script/docker_clean_sunrise_flagtree.sh b/third_party/sunrise/script/docker_clean_sunrise_flagtree.sh new file mode 100644 index 000000000..a0b324a63 --- /dev/null +++ b/third_party/sunrise/script/docker_clean_sunrise_flagtree.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# 需要在triton的根目录执行这个脚本 +if [ ! -d 'python/triton' ] || [ ! -d 'third_party/sunrise' ] ; then + echo "This script must be executed in triton project root directory!" + exit 1 +fi + +if [ $# -eq 1 ] && [ $1 = 'all' ] ; then + rm -f python/triton/FileCheck + rm -f third_party/sunrise/backend/lib/*.bc +fi + +rm -rf python/triton.egg-info +rm -rf python/triton/_C +rm -rf build + +echo "--- OK ---" diff --git a/third_party/sunrise/script/docker_sunrise_env.sh b/third_party/sunrise/script/docker_sunrise_env.sh new file mode 100644 index 000000000..ec4c0c46f --- /dev/null +++ b/third_party/sunrise/script/docker_sunrise_env.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +function print_usage() { + echo "Usage: source build_ttenv.sh " +} + +# 确保这个脚本由 source 命令执行 +if [[ "$0" == "${BASH_SOURCE[0]}" ]]; then + echo "This script must be executed by 'source' or '.' command." + exit 1 +fi + +if [ $# -ne 1 ] || [ ! -d $1 ] ; then + print_usage + return 1 +fi + +# 需要在triton的根目录执行这个脚本 +if [ ! -d 'python/triton' ] || [ ! -d 'third_party/sunrise' ] ; then + echo "This script must be executed in triton project root directory!" + return 1 +fi + +LLVM_INSTALL_DIR=$1 +export PYBIND11_SYSPATH=$CONDA_ENV_DIR/lib/python3.10/site-packages/pybind11/ # you can see by pip show pybind11 +export PYBIND11_INCLUDE_DIR=$PYBIND11_SYSPATH/include +export LLVM_INCLUDE_DIRS=$LLVM_INSTALL_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_INSTALL_DIR/lib +export LLVM_SYSPATH=$LLVM_INSTALL_DIR +export MLIR_DIR=$LLVM_LIBRARY_DIR/cmake/mlir + +export TRITON_OFFLINE_BUILD=1 +export TRITON_BUILD_PROTON=OFF +export TRITON_BUILD_WITH_CLANG_LLD=1 +export MAX_JOBS=50 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/tangrt/lib/linux-x86_64/stub + +export FLAGTREE_BACKEND=sunrise + +# 拷贝install_dir中缺失的bitcode和Filecheck + +cp $1/stpu/bitcode/*.bc third_party/sunrise/backend/lib +if [ $? -ne 0 ] ; then + echo "copy stpu bitcode failed." + return 1 +fi + +# 必须有libtang.so.0.19.2这个文件 +if [ ! -f /usr/local/tangrt/lib/linux-x86_64/stub/libtang.so.0.19.2 ] ; then + ln -s /usr/local/tangrt/lib/linux-x86_64/stub/libtang.so /usr/local/tangrt/lib/linux-x86_64/stub/libtang.so.0.19.2 +fi + +echo "--- OK ---" From b1291cf1e00fa9a253004a164679e47c248595db Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Mon, 19 Jan 2026 20:36:59 +0800 Subject: [PATCH 4/7] fix bug --- .gitignore | 1 - lib/Dialect/TritonGPU/IR/CMakeLists.txt | 1 + python/triton/backends/driver.py | 46 ++-------- .../python/triton/backends/__init__.py | 47 ++++++++++ .../python/triton/backends/compiler.py | 90 +++++++++++++++++++ .../sunrise/python/triton/backends/driver.py | 85 ++++++++++++++++++ 6 files changed, 230 insertions(+), 40 deletions(-) create mode 100644 third_party/sunrise/python/triton/backends/__init__.py create mode 100644 third_party/sunrise/python/triton/backends/compiler.py create mode 100644 third_party/sunrise/python/triton/backends/driver.py diff --git a/.gitignore b/.gitignore index 29d98a797..bb7ed373f 100644 --- a/.gitignore +++ b/.gitignore @@ -68,7 +68,6 @@ ptxas third_party/nvidia/backend/include third_party/nvidia/backend/lib/cupti third_party/sunrise/backend/lib -third_party/sunrise/python/triton/backends/* # Docs docs/_build/ diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index 855a7162d..dc6104c70 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -16,4 +16,5 @@ add_triton_library(TritonGPUIR MLIRGPUDialect TritonIR TritonTools + ${_EXTRA_LINK_LIBS} ) diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index b18fe9cc7..8eeb09dc3 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -39,47 +39,15 @@ class GPUDriver(DriverBase): def __init__(self): # TODO: support other frameworks than torch import torch - try: - import torch_ptpu - _is_ptpu = True - except ImportError as e: - _is_ptpu = False - if _is_ptpu: - self.get_device_capability = torch.ptpu.get_device_capability - self.get_current_stream = lambda dev_idx: torch.ptpu.current_stream(dev_idx).ptpu_stream - self.get_current_device = torch.ptpu.current_device - self.set_current_device = torch.ptpu.set_device - return - + self.get_device_capability = torch.cuda.get_device_capability try: from torch._C import _cuda_getCurrentRawStream - _is_cuda = True - except ImportError as e: - _cuda_getCurrentRawStream = None - _is_cuda = True if torch.version.cuda else False - if _is_cuda: - self.get_device_capability = torch.cuda.get_device_capability - if _cuda_getCurrentRawStream is not None: - self.get_current_stream = _cuda_getCurrentRawStream - else: - self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - return - - try: - import torch_dipu - _is_dipu = True - except ImportError as e: - _is_dipu = False - if _is_dipu: - self.get_device_capability = torch.cuda.get_device_capability - self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).dipu_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - return - + self.get_current_stream = _cuda_getCurrentRawStream + except ImportError: + self.get_current_stream = lambda idx: torch.cuda.current_stream(idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device # TODO: remove once TMA is cleaned up def assemble_tensormap_to_arg(self, tensormaps_info, args): - return args + return args \ No newline at end of file diff --git a/third_party/sunrise/python/triton/backends/__init__.py b/third_party/sunrise/python/triton/backends/__init__.py new file mode 100644 index 000000000..69a8dab0a --- /dev/null +++ b/third_party/sunrise/python/triton/backends/__init__.py @@ -0,0 +1,47 @@ +import importlib +import inspect +import sys +from dataclasses import dataclass +from typing import Type, TypeVar, Union +from types import ModuleType +from .driver import DriverBase +from .compiler import BaseBackend + +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points +else: + from importlib_metadata import entry_points + +T = TypeVar("T", bound=Union[BaseBackend, DriverBase]) + + +def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]: + ret: list[Type[T]] = [] + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr): + ret.append(attr) + if len(ret) == 0: + raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}") + if len(ret) > 1: + raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}") + return ret[0] + + +@dataclass(frozen=True) +class Backend: + compiler: Type[BaseBackend] + driver: Type[DriverBase] + + +def _discover_backends() -> dict[str, Backend]: + backends = dict() + for ep in entry_points().select(group="triton.backends"): + compiler = importlib.import_module(f"{ep.value}.compiler") + driver = importlib.import_module(f"{ep.value}.driver") + backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), # type: ignore + _find_concrete_subclasses(driver, DriverBase)) # type: ignore + return backends + + +backends: dict[str, Backend] = _discover_backends() diff --git a/third_party/sunrise/python/triton/backends/compiler.py b/third_party/sunrise/python/triton/backends/compiler.py new file mode 100644 index 000000000..9bbc5eadb --- /dev/null +++ b/third_party/sunrise/python/triton/backends/compiler.py @@ -0,0 +1,90 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Union +from types import ModuleType + + +@dataclass(frozen=True) +class GPUTarget(object): + # Target backend, e.g., cuda, hip + backend: str + # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip) + arch: Union[int, str] + warp_size: int + + +class Language(Enum): + """The input language being compiled by the backend.""" + TRITON = 0 + GLUON = 1 + + +class BaseBackend(metaclass=ABCMeta): + + def __init__(self, target: GPUTarget) -> None: + self.target = target + assert self.supports_target(target) + + @staticmethod + @abstractmethod + def supports_target(target: GPUTarget): + raise NotImplementedError + + @abstractmethod + def hash(self) -> str: + """Returns a unique identifier for this backend""" + raise NotImplementedError + + @abstractmethod + def parse_options(self, options: dict) -> object: + """ + Converts an `options` dictionary into an arbitrary object and returns it. + This function may contain target-specific heuristics and check the legality of the provided options + """ + raise NotImplementedError + + @abstractmethod + def add_stages(self, stages: dict, options: object) -> None: + """ + Populates `stages` dictionary with entries of the form: + ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes] + The value of each entry may populate a `metadata` dictionary. + Stages will be run sequentially (in inseriton order) and can communicate using `metadata`. + All stages are expected to return a `str` object, except for the last stage which returns + a `bytes` object for execution by the launcher. + """ + raise NotImplementedError + + @abstractmethod + def load_dialects(self, context): + """ + Load additional MLIR dialects into the provided `context` + """ + raise NotImplementedError + + @abstractmethod + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations + """ + raise NotImplementedError + + @staticmethod + def parse_attr(desc): + assert isinstance(desc, str) + ret = [] + if "D" in desc: + ret += [["tt.divisibility", 16]] + return ret + + @staticmethod + def get_arg_specialization(arg, ty, **kwargs): + """ + Return a string unique to each possible specialization of the argument + """ + if ty == "int" and arg % 16 == 0 and kwargs.get("align", False): + return "D" + if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False): + return "D" + return "" diff --git a/third_party/sunrise/python/triton/backends/driver.py b/third_party/sunrise/python/triton/backends/driver.py new file mode 100644 index 000000000..b18fe9cc7 --- /dev/null +++ b/third_party/sunrise/python/triton/backends/driver.py @@ -0,0 +1,85 @@ +from abc import ABCMeta, abstractmethod +from typing import Callable, List, Protocol, Sequence + + +class Benchmarker(Protocol): + + def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]: + pass + + +class DriverBase(metaclass=ABCMeta): + + @classmethod + @abstractmethod + def is_active(self): + pass + + @abstractmethod + def get_current_target(self): + pass + + @abstractmethod + def get_active_torch_device(self): + pass + + @abstractmethod + def get_benchmarker(self) -> Benchmarker: + """ + Return the benchmarking function that this backend should use by default. + """ + raise NotImplementedError + + def __init__(self) -> None: + pass + + +class GPUDriver(DriverBase): + + def __init__(self): + # TODO: support other frameworks than torch + import torch + try: + import torch_ptpu + _is_ptpu = True + except ImportError as e: + _is_ptpu = False + if _is_ptpu: + self.get_device_capability = torch.ptpu.get_device_capability + self.get_current_stream = lambda dev_idx: torch.ptpu.current_stream(dev_idx).ptpu_stream + self.get_current_device = torch.ptpu.current_device + self.set_current_device = torch.ptpu.set_device + return + + try: + from torch._C import _cuda_getCurrentRawStream + _is_cuda = True + except ImportError as e: + _cuda_getCurrentRawStream = None + _is_cuda = True if torch.version.cuda else False + if _is_cuda: + self.get_device_capability = torch.cuda.get_device_capability + if _cuda_getCurrentRawStream is not None: + self.get_current_stream = _cuda_getCurrentRawStream + else: + self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + return + + try: + import torch_dipu + _is_dipu = True + except ImportError as e: + _is_dipu = False + if _is_dipu: + self.get_device_capability = torch.cuda.get_device_capability + self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).dipu_stream + self.get_current_device = torch.cuda.current_device + self.set_current_device = torch.cuda.set_device + return + + + # TODO: remove once TMA is cleaned up + def assemble_tensormap_to_arg(self, tensormaps_info, args): + return args From 45a9e49d6f1dd447825d37996c42eb2f3000eafb Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Mon, 19 Jan 2026 21:01:56 +0800 Subject: [PATCH 5/7] add _EXTRA_LINK_LIBS for cmake --- lib/Analysis/CMakeLists.txt | 1 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 2 ++ lib/Conversion/TritonToTritonGPU/CMakeLists.txt | 2 ++ lib/Dialect/TritonGPU/IR/CMakeLists.txt | 1 + lib/Dialect/TritonGPU/Transforms/CMakeLists.txt | 2 ++ lib/Target/LLVMIR/CMakeLists.txt | 4 +++- 6 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index ae1c60067..0aacae2cb 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -25,4 +25,5 @@ add_triton_library(TritonAnalysis TritonIR TritonGPUIR TritonNvidiaGPUIR + ${_EXTRA_LINK_LIBS} ) diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index 081870fd3..182a65b2f 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -46,4 +46,6 @@ add_triton_library(TritonGPUToLLVM TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms + + ${_EXTRA_LINK_LIBS} ) diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index b8a2a1297..17a0ebc42 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -15,4 +15,6 @@ add_triton_library(TritonToTritonGPU TritonIR ProtonIR TritonGPUIR + + ${_EXTRA_LINK_LIBS} ) diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index dc6104c70..d697ef366 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -16,5 +16,6 @@ add_triton_library(TritonGPUIR MLIRGPUDialect TritonIR TritonTools + ${_EXTRA_LINK_LIBS} ) diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 4d82119d6..0978525ef 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -59,4 +59,6 @@ add_triton_library(TritonGPUTransforms TritonNvidiaGPUIR TritonToTritonGPU MLIRTransformUtils + + ${_EXTRA_LINK_LIBS} ) diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index c3a0010b8..0f4a3b0a5 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -22,7 +22,9 @@ add_triton_library(TritonLLVMIR MLIRSupport MLIRTargetLLVMIRExport TritonGPUToLLVM - ) + + ${_EXTRA_LINK_LIBS} +) set_source_files_properties( LLVMIRTranslation.cpp From 48df8bca9ca5136fd88cf26488c6a6efe376345f Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Tue, 20 Jan 2026 15:26:03 +0800 Subject: [PATCH 6/7] add setup_helper for sunrise --- setup_tools/setup_helper.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/setup_tools/setup_helper.py b/setup_tools/setup_helper.py index 7ee3dc84b..194322dcb 100644 --- a/setup_tools/setup_helper.py +++ b/setup_tools/setup_helper.py @@ -410,3 +410,18 @@ def check_env(env_val): pre_hock=lambda: check_env('LLVM_SYSPATH'), post_hock=set_llvm_env, ) + +# sunrise +cache.store( + file="sunrise-llvm21-x86_64", + condition=("sunrise" == flagtree_backend), + url = "https://abc.efg", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +cache.store( + file="sunriseTritonPlugin.so", condition=("sunrise" == flagtree_backend) and (not flagtree_plugin), url= + "https://abc.efg", + copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="1f0b7e67" +) From da806835debbae7a20f9412a730a4bf5aa880b1c Mon Sep 17 00:00:00 2001 From: lisirui1 Date: Tue, 20 Jan 2026 16:02:04 +0800 Subject: [PATCH 7/7] fix bug when FLAGTREE_PLUGIN=0 --- setup_tools/setup_helper.py | 38 +++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/setup_tools/setup_helper.py b/setup_tools/setup_helper.py index 194322dcb..4c9554ab1 100644 --- a/setup_tools/setup_helper.py +++ b/setup_tools/setup_helper.py @@ -4,6 +4,7 @@ import functools from pathlib import Path import hashlib +import sysconfig from distutils.sysconfig import get_python_lib from . import utils @@ -31,13 +32,20 @@ def install_extension(*args, **kargs): except Exception: pass - def get_backend_cmake_args(*args, **kargs): + if "editable_wheel" in sys.argv: + editable = True + else: + editable = False + handle_plugin_backend(editable) try: - return activated_module.get_backend_cmake_args(*args, **kargs) + # cmake_args = configs.activated_module.get_backend_cmake_args(*args, **kargs) + cmake_args = activated_module.get_backend_cmake_args(*args, **kargs) except Exception: - return [] - + cmake_args = [] + if editable: + cmake_args += ["-DEDITABLE_MODE=ON"] + return cmake_args def get_device_name(): return device_mapping[flagtree_backend] @@ -282,6 +290,26 @@ def handle_flagtree_backend(): if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends: ext_sourcedir = os.path.abspath(f"./third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" +def handle_plugin_backend(editable): + if flagtree_backend in ["iluvatar", "mthreads", "sunrise"]: + if editable is False: + src_build_plugin_path = str( + os.getenv("HOME")) + "/.flagtree/" + flagtree_backend + "/" + flagtree_backend + "TritonPlugin.so" + dst_build_plugin_dir = sysconfig.get_paths()['purelib'] + "/triton/_C" + if not os.path.exists(dst_build_plugin_dir): + os.makedirs(dst_build_plugin_dir) + dst_build_plugin_path = dst_build_plugin_dir + "/" + flagtree_backend + "TritonPlugin.so" + shutil.copy(src_build_plugin_path, dst_build_plugin_path) + src_install_plugin_path = str( + os.getenv("HOME")) + "/.flagtree/" + flagtree_backend + "/" + flagtree_backend + "TritonPlugin.so" + if flagtree_backend in ("mthreads", "sunrise"): + dst_install_plugin_dir = os.path.dirname( + os.path.abspath(__file__)) + "/../../third_party/" + flagtree_backend + "/python/triton/_C" + else: + dst_install_plugin_dir = os.path.dirname(os.path.abspath(__file__)) + "/../triton/_C" + if not os.path.exists(dst_install_plugin_dir): + os.makedirs(dst_install_plugin_dir) + shutil.copy(src_install_plugin_path, dst_install_plugin_dir) def set_env(env_dict: dict): for env_k, env_v in env_dict.items(): @@ -306,6 +334,8 @@ def check_env(env_val): download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock, required=True) +handle_plugin_backend(False) + handle_flagtree_backend() cache = FlagTreeCache()