diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 00000000..cff80136 --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,227 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + use-local-version: + description: "Use local version" + required: false + type: boolean + default: false + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. 
{'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.26 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + + - name: Install additional CUDA libraries + run: | + CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 "-" $2'}) + sudo apt-get update + sudo apt-get install -y libcusparse-$CUDA_VERSION libcusolver-$CUDA_VERSION + sudo apt-get clean + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. 
+ export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + pip install jinja2 + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo 
::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . + else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" + export DG_USE_LOCAL_VERSION=${{ inputs.use-local-version && '1' || '0' }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? 
+ + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . --atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000..ee250aa4 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,53 @@ +name: Build wheels + +on: + workflow_dispatch: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + 
python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "The release tag to upload the wheel to" + required: false + type: string + use-local-version: + description: "Use local version" + required: false + type: boolean + default: false + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} + use-local-version: ${{ inputs.use-local-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..a7b3e6b8 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,95 @@ +# This workflow will: +# - Create a new Github release +# - Build wheels for supported architectures +# - Deploy the wheels to the Github release +# - Release the static code to PyPi +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Build wheels and deploy + +on: + push: + tags: + - v* + +jobs: + setup_release: + name: Create Release + runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} + steps: + - name: Get the tag version + id: extract_branch + run: echo 
"branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT" + shell: bash + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + release_name: ${{ steps.extract_branch.outputs.branch }} + + build_wheels: + name: Build Wheel + needs: setup_release + strategy: + fail-fast: false + matrix: + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. 
+ cxx11_abi: ["FALSE", "TRUE"] + exclude: + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true + use-local-version: false + + publish_package: + name: Publish package + needs: [build_wheels] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 + # We don't want to download anything CUDA-related here + pip install torch --index-url https://download.pytorch.org/whl/cpu + - name: Build core package + env: + DG_USE_LOCAL_VERSION: "0" + DG_SKIP_CUDA_BUILD: "1" + run: | + python setup.py sdist --dist-dir=dist + - name: Deploy + env: + TWINE_USERNAME: "__token__" + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m twine upload dist/* diff --git a/.gitignore b/.gitignore index 3e6e4e5a..d0cdf6ca 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,16 @@ dist # Third-party links created by `setup.py develop` deep_gemm/include/cute deep_gemm/include/cutlass + +# VS Code settings +/.vscode + +# clangd settings +/.clang* +/.cache + +# Generated stub files +stubs/ + +# Symlinks to compiled extensions +deep_gemm/*.so \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f12a969..79f1964d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,8 +26,8 @@ 
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORC link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) # The main Python API entrance -pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp) -target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda) +pybind11_add_module(_C csrc/python_api.cpp) +target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES} torch_python) # Enable kernel code indexing with CMake-based IDEs cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu) diff --git a/README.md b/README.md index 0c208695..04a289db 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] Fix TMA multicast compatibility for indivisible shapes - [x] Skip useless computation on M - [x] NVRTC as a faster compiler -- [ ] Sanitizer for testing +- [x] Sanitizer for testing - [x] Weight gradient kernels for dense models - [x] Weight gradient kernels for MoE models - [ ] Better `get_best_configs` modeling @@ -69,9 +69,7 @@ cat develop.sh # Test all GEMM implements python tests/test_layout.py python tests/test_attention.py -python tests/test_bf16.py -python tests/test_fp8.py -python tests/test_lazy_init.py +python tests/test_core.py ``` ### Installation diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 9a40394e..4fbe9930 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -1,21 +1,25 @@ #pragma once +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" -#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp" #include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp" #include 
"../jit_kernels/impls/smxx_clean_logits.hpp" +#endif #include "layout.hpp" namespace deep_gemm::attention { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE static void fp8_gemm_nt_skip_head_mid(const std::pair& a, const std::pair& b, const torch::Tensor& d, - const std::tuple &head_splits, + const std::tuple& head_splits, std::optional> recipe, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { @@ -49,21 +53,21 @@ static void fp8_gemm_nt_skip_head_mid(const std::pairget_arch_major(); const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) { - sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); + // NOTES: Only granularity 128 and FP8 are exposed in the API + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, + 128, 128, major_a, major_b, compiled_dims, epilogue_type); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -74,7 +78,8 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, const torch::Tensor& weights, const torch::Tensor& cu_seq_len_k_start, const torch::Tensor& cu_seq_len_k_end, - const bool& clean_logits) { + const bool& clean_logits, + const int& max_seqlen_k) { const auto& [seq_len, num_heads, 
head_dim] = get_shape<3>(q); const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first); const auto& [seq_len_, num_heads_] = get_shape<2>(weights); @@ -102,27 +107,45 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, constexpr int seq_len_alignment = 4; constexpr int block_kv = 256; const auto aligned_seq_len = align(seq_len, seq_len_alignment); - const auto aligned_seq_len_kv = align(seq_len_kv + block_kv, 4); - auto logits = torch::empty({aligned_seq_len, aligned_seq_len_kv}, q.options().dtype(torch::kFloat)); - logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + + torch::Tensor logits; + int stride_logits; + if (max_seqlen_k == 0) { + stride_logits = align(seq_len_kv + block_kv, 4); + logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + } else { + stride_logits = align(max_seqlen_k, block_kv); + logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)}); + DG_HOST_ASSERT(not clean_logits); + } // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 or arch_major == 10) { smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, - seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment); + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } // Clean unfilled logits if (clean_logits) - smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, aligned_seq_len_kv); + smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, stride_logits); return logits; } 
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) { - const auto& [batch_size] = get_shape<1>(context_lens); + const bool is_context_lens_2d = context_lens.dim() == 2; + int batch_size = 0, next_n = 0; + if (is_context_lens_2d) { + batch_size = context_lens.size(0); + next_n = context_lens.size(1); + } else { + DG_HOST_ASSERT(context_lens.dim() == 1); + batch_size = context_lens.size(0); + } DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); DG_HOST_ASSERT(context_lens.is_contiguous()); @@ -131,7 +154,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 or arch_major == 10) { - smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -149,15 +172,24 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, const bool& clean_logits) { const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q); const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache); - const auto& [batch_size_] = get_shape<1>(context_lens); const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights); - const auto& [batch_size__, max_block_len] = get_shape<2>(block_table); + const auto& [batch_size_, max_block_len] = get_shape<2>(block_table); const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta); const auto& num_sms = device_runtime->get_num_sms(); const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0); const auto& block_table_stride = block_table.stride(0); - DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__); + const bool is_context_lens_2d 
= context_lens.dim() == 2; + if (is_context_lens_2d) { + const auto& [batch_size__, next_n_] = get_shape<2>(context_lens); + DG_HOST_ASSERT(batch_size == batch_size__ and next_n == next_n_); + } else { + DG_HOST_ASSERT(context_lens.dim() == 1); + const auto& [batch_size__] = get_shape<1>(context_lens); + DG_HOST_ASSERT(batch_size == batch_size__); + } + + DG_HOST_ASSERT(batch_size == batch_size_); DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n); DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1); DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast(sizeof(float))); @@ -198,8 +230,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, ); // Allocate output - constexpr int num_math_warp_groups = 4; - const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv); + constexpr int split_kv = 256; + const auto& aligned_max_context_len = align(max_context_len, split_kv); auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat)); logits = logits.slice(-1, 0, max_context_len); @@ -207,15 +239,17 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 or arch_major == 10) { smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta, - batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, - kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups); + batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } // Clean unfilled logits - if (clean_logits) + if (clean_logits) { + DG_HOST_ASSERT(not is_context_lens_2d); smxx_clean_logits(logits, std::nullopt, context_lens, next_n, 
batch_size * next_n, max_context_len, aligned_max_context_len); + } return logits; } diff --git a/csrc/apis/einsum.hpp b/csrc/apis/einsum.hpp index e53ad7d7..1f8ff674 100644 --- a/csrc/apis/einsum.hpp +++ b/csrc/apis/einsum.hpp @@ -3,13 +3,20 @@ #include "../utils/exception.hpp" #include "../utils/format.hpp" #include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" +#include "gemm.hpp" +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp" #include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" #include "../jit_kernels/impls/smxx_cublaslt.hpp" +#endif namespace deep_gemm::einsum { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const std::optional& c) { // Currently FP32 only support the accumulated expression @@ -48,7 +55,7 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor } } -static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) { +static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { const auto& [b , h , r ] = get_shape<3>(A); const auto& [h_, d , r_] = get_shape<3>(B); const auto& [b_, h__, d_] = get_shape<3>(D); @@ -58,10 +65,20 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); - cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 
10) { + sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } } -static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) { +static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { const auto& [b , h , d ] = get_shape<3>(A); const auto& [h_, d_ , r ] = get_shape<3>(B); const auto& [b_, h__, r_] = get_shape<3>(D); @@ -71,14 +88,25 @@ static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const to DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); - cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d); + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 10) { + sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } } static void einsum(const std::string& expr, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, - const std::optional& c) { + const std::optional& c, + const bool& use_cublaslt) { DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); @@ -91,16 +119,95 @@ static void einsum(const std::string& expr, // TODO: support any expression // TODO: canonicalize expression if (expr == "bmk,bnk->mn") { + DG_HOST_ASSERT(not use_cublaslt); bmk_bnk_mn(a, b, d, c); } else if (expr == "bhr,hdr->bhd") { DG_HOST_ASSERT(not c.has_value()); - bhr_hdr_bhd(a, b, d); + bhr_hdr_bhd(a, b, d, use_cublaslt); } else if (expr == "bhd,hdr->bhr") { DG_HOST_ASSERT(not c.has_value()); - bhd_hdr_bhr(a, b, d); 
+ bhd_hdr_bhr(a, b, d, use_cublaslt); } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); } } +static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims) { + // Shape must be `[B, M, K] @ [B, N, K].T` + const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + const auto& major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1); + DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1); + DG_HOST_ASSERT(d.stride(-1) == 1); + + // Type and shape checks + const auto& [batch_size , m , k ] = get_shape<3>(a); + const auto& [batch_size_ , n , k_] = get_shape<3>(b); + const auto& [batch_size__, m_, n_] = get_shape<3>(d); + DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (batch_size == 0 or gemm::early_return(m, n, k, d, c)) + return; + + // Transform scaling factors + const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims); + } else { + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, 
k, major_a, major_b, major_sfb, compiled_dims); + } +} + +static void fp8_einsum(const std::string& expr, + const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::tuple& recipe) { + // Some hardcoded Einstein sum kernels + const auto arch_major = device_runtime->get_arch_major(); + if (expr == "bhr,hdr->bhd") { + // Permute dims to satisfy the order of (batch_size, m, n, k) + // (batch_size, m, n, k): (h, b, d, r) + const auto& perm_a = a.first.permute({1, 0, 2}); + const auto& perm_sfa = a.second.permute({1, 0, 2}); + const auto& perm_d = d.permute({1, 0, 2}); + const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,hdr->bhr" and arch_major == 10) { + // (batch_size, m, n, k): (h, b, r, d) + const auto& perm_a = a.first.permute({1, 0, 2}); + const auto& perm_sfa = a.second.permute({1, 0, 2}); + const auto& perm_b = b.first.permute({0, 2, 1}); + const auto& perm_sfb = b.second.permute({0, 2, 1}); + const auto& perm_d = d.permute({1, 0, 2}); + const auto& perm_c = c.has_value() ? 
std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,bhr->hdr" and arch_major == 10) { + // (batch_size, m, n, k): (h, d, r, b) + const auto& perm_a = a.first.permute({1, 2, 0}); + const auto& perm_sfa = a.second.permute({1, 2, 0}); + const auto& perm_b = b.first.permute({1, 2, 0}); + const auto& perm_sfb = b.second.permute({1, 2, 0}); + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn"); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} +#endif } // namespace deep_gemm::einsum diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 7db0dc89..f63f6028 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -1,23 +1,61 @@ #pragma once +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm90_bf16_gemm.hpp" #include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" -#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm100_bf16_gemm.hpp" +#endif + +#include "../jit_kernels/impls/smxx_cublaslt.hpp" #include "layout.hpp" namespace deep_gemm::gemm { -static void fp8_gemm_nt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { +static bool early_return(const int& m, const int &n, const int& k, + const torch::Tensor& d, const std::optional& c) { + // Do nothing if the problem is empty + if (m == 0 or n == 0) + return true; + + // Checks + const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + if (is_cd_same) + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + if (c.has_value()) { + 
check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // No accumulation + if (k == 0) { + if (not is_cd_same) + c.has_value() ? d.copy_(c.value()) : d.zero_(); + return true; + } + + // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) + if (c.has_value() and not is_cd_same) + d.copy_(c.value()); + return false; +} + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + +static void fp8_fp4_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { // Shape must be `[M, K] @ [N, K].T` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -30,110 +68,116 @@ static void fp8_gemm_nt(const std::pair& a, check_major_type_cd(d); // Type and shape checks - const auto& [m , k ] = get_shape<2>(a.first); - const auto& [n , k_] = get_shape<2>(b.first); - const auto& [m_, n_] = get_shape<2>(d); + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(n > 0 and k > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); - // Check C as well - if (c.has_value()) { - check_major_type_cd(c.value()); - DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - } - - // Do nothing if the problem is empty - if (m == 0) + // Early return 
for trivial cases + if (early_return(m, n, k, d, c)) return; // Transform SFA and SFB into compute-required layout - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128)); - const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); - const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast); // Dispatch into different implements - const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - if (std::get<1>(recipe.value()) == 1) { + const int gran_n = recipe.has_value() ? 
std::get<1>(recipe.value()) : std::get<0>(recipe_b.value()); + if (gran_n == 1) { sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); } else { - sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } -static void fp8_gemm_nn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void fp8_gemm_tn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& 
disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, - {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void fp8_gemm_tt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& m_indices, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { +static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + 
const bool& disable_ue8m0_cast, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { // Shape must be `[M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); if (fp8_requires_k_major()) DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(m_indices.is_contiguous()); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); // Type and shape checks - const auto& [m, k] = get_shape<2>(a.first); - const auto& [num_groups, n, k_] = get_shape<3>(b.first); - const auto& [m_, n_] = get_shape<2>(d); - const auto& m__ = static_cast(m_indices.numel()); - DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } // D must be N-major check_major_type_cd(d); @@ -143,36 +187,36 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pairget_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - 
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); + const auto& major_sfb = get_major_type_ab(sfb); + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); + sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } -static void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& m_indices, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, - d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); +static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout) { + m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, grouped_layout, recipe, 
recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt); } static std::optional> m_grouped_fp8_gemm_nt_masked(const std::pair& a, @@ -193,15 +237,14 @@ static std::optional> m_grouped_fp8_gemm_nt_masked(const std DG_HOST_ASSERT(masked_m.is_contiguous()); // Type and shape checks - const auto& [num_groups, m, k] = get_shape<3>(a.first); - const auto& [num_groups_, n, k_] = get_shape<3>(b.first); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); - const auto& num_groups___ = static_cast(masked_m.numel()); + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); @@ -215,10 +258,8 @@ static std::optional> m_grouped_fp8_gemm_nt_masked(const std check_major_type_cd(d); // Transform scaling factors - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); - const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, 
b.second, m, n, k, recipe, std::nullopt, std::nullopt, num_groups, num_groups, disable_ue8m0_cast); // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); @@ -228,11 +269,9 @@ static std::optional> m_grouped_fp8_gemm_nt_masked(const std num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims, max_block_n, enable_overlap, signal); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -250,30 +289,32 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair(d); + const auto& [sum_k_ , m_] = get_shape<2>(a.first); + const auto& [sum_k__, n_] = get_shape<2>(b.first); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous()); DG_HOST_ASSERT(b.first.is_contiguous()); DG_HOST_ASSERT(d.is_contiguous()); - if (c.has_value()) { - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().is_contiguous()); - } + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); - // Do nothing if empty - if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) return; // Transform SF with padding - const auto& [_, m] = 
get_shape<2>(a.first); - const auto& [__, n] = get_shape<2>(b.first); const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 10) { - fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, - cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -294,23 +335,18 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); const auto& sum_mk = a.first.numel(); const auto& sum_nk = b.first.numel(); - int sum_k = 0; - for (const auto& k: ks) - sum_k += k; - DG_HOST_ASSERT(sum_mk == m * sum_k); - DG_HOST_ASSERT(sum_nk == n * sum_k); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(sum_mk == static_cast(sum_k) * m); + DG_HOST_ASSERT(sum_nk == static_cast(sum_k) * n); // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous()); DG_HOST_ASSERT(b.first.is_contiguous()); DG_HOST_ASSERT(d.is_contiguous()); - if (c.has_value()) { - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().is_contiguous()); - } + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); - // Do nothing if empty - if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + // Early return for trivial cases + if (early_return(m, n, accumulate(ks.begin(), ks.end(), 0), d, c)) return; // Transform SF with padding @@ -326,13 +362,15 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pairget_arch_major(); if (arch_major == 9) { - sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, + 
sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } } +#endif +#if DG_TENSORMAP_COMPATIBLE static void bf16_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, @@ -350,20 +388,12 @@ static void bf16_gemm_nt(const torch::Tensor& a, const auto& [n , k_] = get_shape<2>(b); const auto& [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(n > 0 and k > 0); DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); - // Check C as well - if (c.has_value()) { - check_major_type_cd(c.value()); - DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - } - - // Do nothing if the problem is empty - if (m == 0) + // Early return for trivial cases + if (early_return(m, n, k, d, c)) return; // Dispatch into different implements @@ -402,26 +432,36 @@ static void bf16_gemm_tt(const torch::Tensor& a, } static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& d, const torch::Tensor& m_indices, - const std::string& compiled_dims) { + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { // Shape must be `[M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a); const auto& major_b = get_major_type_ab(b); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); - DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(m_indices.is_contiguous()); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); // Type and shape checks const auto& [m, k] = get_shape<2>(a); const 
auto& [num_groups, n, k_] = get_shape<3>(b); const auto& [m_, n_] = get_shape<2>(d); - const auto& m__ = static_cast(m_indices.numel()); - DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } // D must be N-major check_major_type_cd(d); @@ -433,13 +473,26 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { - sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices, + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims, + use_psum_layout, expected_m_for_psum_layout); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } } +static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout) { + m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2), + d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt); +} + static void 
m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const torch::Tensor& masked_m, const int& expected_m, const std::string& compiled_dims) { @@ -470,11 +523,52 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T if (arch_major == 9) { sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } } +static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::string& compiled_dims) { + // Shape checks + const auto& [num_groups, m, n] = get_shape<3>(d); + const auto& [sum_k_ , m_] = get_shape<2>(a); + const auto& [sum_k__, n_] = get_shape<2>(b); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + + // Contiguity checks + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} 
+#endif + static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const std::optional& c) { // Shape must be `[M, K] @ [N, K].T` @@ -487,11 +581,8 @@ static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const auto& [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - if (c.has_value()) - DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type()); - - // Do nothing if the problem is empty - if (m == 0 or n == 0) + // Early return for trivial cases + if (early_return(m, n, k, d, c)) return; cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b); diff --git a/csrc/apis/hyperconnection.hpp b/csrc/apis/hyperconnection.hpp new file mode 100644 index 00000000..0a85b10f --- /dev/null +++ b/csrc/apis/hyperconnection.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp" +#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp" +#endif + +namespace deep_gemm::hyperconnection { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const std::optional& num_splits) { + // A and B must be K-major, D must be N-major + DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K); + check_major_type_cd(d); + + // S must be contiguous + DG_HOST_ASSERT(sqr_sum.is_contiguous()); + + // Type and shape checks + const auto& [m, k ] = get_shape<2>(a); + const auto& [n, k_] = get_shape<2>(b); + if (num_splits.has_value()) { + const auto& [num_splits_, m_, n_] = get_shape<3>(d); + const auto& [num_splits__, m__] = get_shape<2>(sqr_sum); + DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1); + 
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } else { + const auto& [m_, n_] = get_shape<2>(d); + const auto& [m__] = get_shape<1>(sqr_sum); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else if (arch_major == 10) { + sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"), + py::arg("num_splits") = std::nullopt); +#endif +} + +} // namespace deep_gemm::hyperconnection diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index 852378f3..6d14fbad 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -1,20 +1,34 @@ #pragma once #include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" + +#if DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/smxx_layout.hpp" +#endif namespace deep_gemm::layout { +#if DG_TENSORMAP_COMPATIBLE static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, const int& mn, const int& k, - const std::tuple& recipe, + const std::optional>& recipe, + const std::optional>& recipe_ab, const std::optional& num_groups, const bool& is_sfa, const bool& disable_ue8m0_cast) { - 
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); - const auto& gran_k = std::get<2>(recipe); const auto& arch_major = device_runtime->get_arch_major(); + int gran_mn, gran_k; + if (recipe.has_value()) { + DG_HOST_ASSERT(not recipe_ab.has_value()); + gran_mn = is_sfa ? std::get<0>(recipe.value()) : std::get<1>(recipe.value()); + gran_k = std::get<2>(recipe.value()); + } else { + DG_HOST_ASSERT(recipe_ab.has_value()); + std::tie(gran_mn, gran_k) = recipe_ab.value(); + } + // Pre-transform checks check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); @@ -22,30 +36,44 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) return get_mn_major_tma_aligned_tensor(sf); - // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { - DG_HOST_ASSERT(not disable_ue8m0_cast); - return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); - } - - // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous + // (FP32, 128, 128) on SM90: no need to transform, check SFB requirements if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); - // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { + // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) { DG_HOST_ASSERT(not disable_ue8m0_cast); - const auto& broadcasted = sf.index_select(-2, torch::arange(mn, 
at::TensorOptions().device(sf.device())).floor_divide_(128)); + const auto& broadcasted = gran_mn == 1 ? sf : + sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); } - // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major - if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) + // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); DG_HOST_UNREACHABLE("Unknown SF transformation"); } +static std::tuple transform_sf_pair_into_required_layout( + const torch::Tensor& sfa, const torch::Tensor& sfb, + const int& m, const int& n, const int& k, + std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::optional& num_groups_a, + const std::optional& num_groups_b, + const bool& disable_ue8m0_cast = false) { + DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + if (not recipe_a.has_value() and not recipe.has_value()) + recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type()); + const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast); + const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast); + const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value()); + const int gran_k_b = recipe_b.has_value() ? 
std::get<1>(recipe_b.value()) : std::get<2>(recipe.value()); + return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b); +} + static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, const std::vector& ks, const torch::Tensor& ks_tensor, diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp index bcd08bcd..9da425cb 100644 --- a/csrc/apis/runtime.hpp +++ b/csrc/apis/runtime.hpp @@ -1,6 +1,8 @@ #pragma once +#if DG_TENSORMAP_COMPATIBLE #include "../jit/compiler.hpp" +#endif #include "../jit/device_runtime.hpp" namespace deep_gemm::runtime { diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu index 6419a853..1b96da2f 100644 --- a/csrc/indexing/main.cu +++ b/csrc/indexing/main.cu @@ -4,21 +4,24 @@ #include #include #include -#include // Attention kernels #include #include #include #include -#include // Einsum kernels #include #include +// Hyperconnection kernels +#include +#include + // Layout kernels #include +#include using namespace deep_gemm; diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 55bcc60e..3dc0cfbf 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -24,14 +24,20 @@ class Compiler { static std::filesystem::path library_include_path; static std::filesystem::path cuda_home; static std::string library_version; + static std::filesystem::path cuobjdump_path; static std::string get_library_version() { - std::stringstream ss; + std::vector buffer; for (const auto& f: collect_files(library_include_path / "deep_gemm")) { std::ifstream in(f, std::ios::binary); - ss << in.rdbuf(); + DG_HOST_ASSERT(in.is_open()); + + // Append into the buffer + buffer.insert(buffer.end(), + std::istreambuf_iterator(in), + std::istreambuf_iterator()); } - return get_hex_digest(ss.str()); + return get_hex_digest(buffer); } static void prepare_init(const std::string& library_root_path, @@ -40,6 +46,7 @@ class Compiler { Compiler::library_include_path = Compiler::library_root_path / "include"; 
Compiler::cuda_home = cuda_home_path_by_python; Compiler::library_version = get_library_version(); + Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; } std::string signature, flags; @@ -51,6 +58,7 @@ class Compiler { DG_HOST_ASSERT(not library_include_path.empty()); DG_HOST_ASSERT(not cuda_home.empty()); DG_HOST_ASSERT(not library_version.empty()); + DG_HOST_ASSERT(not cuobjdump_path.empty()); // Cache settings cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; @@ -62,8 +70,8 @@ class Compiler { flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " "--ptxas-options=--register-usage-level=10", get_env("DG_JIT_CPP_STANDARD", 20)); - if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) - flags += " --ptxas-options=--verbose"; + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0) or get_env("DG_JIT_PTXAS_CHECK", 0)) + flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; if (get_env("DG_JIT_WITH_LINEINFO", 0)) flags += " -Xcompiler -rdynamic -lineinfo"; } @@ -103,25 +111,57 @@ class Compiler { // Compile into a temporary CUBIN const auto tmp_cubin_path = get_tmp_file_path(); - compile(code, dir_path, tmp_cubin_path); + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_PTX")) { + // Dump PTX if needed + const auto tmp_ptx_path = get_tmp_file_path(); + compile(code, dir_path, tmp_cubin_path, tmp_ptx_path); + + // Replace into the cache directory + std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx"); + } else { + compile(code, dir_path, tmp_cubin_path); + } // Replace into the cache directory - make_dirs(dir_path); - std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin"); + const auto cubin_path = dir_path / "kernel.cubin"; + std::filesystem::rename(tmp_cubin_path, cubin_path); + + // Disassemble if needed + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_SASS")) { + // Dump into a temporary SASS + const auto tmp_sass_path = 
get_tmp_file_path(); + disassemble(cubin_path, tmp_sass_path); + + // Replace into the current directory + std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass"); + } // Put into the runtime cache - const auto& runtime = kernel_runtime_cache->get(dir_path); + const auto runtime = kernel_runtime_cache->get(dir_path); DG_HOST_ASSERT(runtime != nullptr); return runtime; } - virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0; + static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { + // Disassemble the CUBIN file to SASS + const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running cuobjdump command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("cuobjdump failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "cuobjdump failed"); + } + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional &ptx_path = std::nullopt) const = 0; }; DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); class NVCCCompiler final: public Compiler { std::filesystem::path nvcc_path; @@ -159,17 +199,19 @@ class NVCCCompiler final: public Compiler { const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " - 
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda", + "-O3 --expt-relaxed-constexpr --expt-extended-lambda", flags, library_include_path.c_str(), arch); } - void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { // Write the code into the cache directory const auto& code_path = dir_path / "kernel.cu"; put(code_path, code); // Compile - const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) printf("Running NVCC command: %s\n", command.c_str()); const auto& [return_code, output] = call_external_command(command); @@ -178,6 +220,22 @@ class NVCCCompiler final: public Compiler { DG_HOST_ASSERT(false and "NVCC compilation failed"); } + // Compile to PTX if needed + if (ptx_path.has_value()) { + const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); + const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); + if (ptx_return_code != 0) { + printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); + DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); + } + } + + // Check local memory usage + if (get_env("DG_JIT_PTXAS_CHECK", 0)) + DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); + // Print PTXAS log if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) 
printf("%s", output.c_str()); @@ -210,11 +268,13 @@ class NVRTCCompiler final: public Compiler { // Override the compiler flags // Only NVRTC >= 12.9 supports arch-specific family suffix const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9); - flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}", + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", flags, include_dirs, arch, pch_flags); } - void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { // Write the code into the cache directory const auto& code_path = dir_path / "kernel.cu"; put(code_path, code); @@ -257,6 +317,17 @@ class NVRTCCompiler final: public Compiler { } } + if (ptx_path.has_value()) { + // Get PTX size and data if needed + size_t ptx_size; + DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); + std::string ptx_data(ptx_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data())); + + // Write into the file system + put(ptx_path.value(), ptx_data); + } + // Get CUBIN size and data size_t cubin_size; DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 2583f2c1..ae881a03 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "../utils/exception.hpp" @@ -15,10 +16,12 @@ class DeviceRuntime { // cuBLASLt utils static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; + +public: + // Create the cuBLASLt handle ourselves cublasLtHandle_t cublaslt_handle{}; std::shared_ptr cublaslt_workspace; -public: explicit DeviceRuntime() { cublaslt_workspace = 
std::make_shared(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA))); DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); @@ -37,8 +40,13 @@ class DeviceRuntime { } std::shared_ptr get_prop() { - if (cached_prop == nullptr) - cached_prop = std::make_shared(*at::cuda::getCurrentDeviceProperties()); + if (cached_prop == nullptr) { + int device_idx; + cudaDeviceProp prop; + DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx)); + DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx)); + cached_prop = std::make_shared(prop); + } return cached_prop; } @@ -82,6 +90,10 @@ class DeviceRuntime { return compile_mode; } + int get_l2_cache_size() { + return get_prop()->l2CacheSize; + } + void set_tc_util(const int& new_tc_util) { DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); tc_util = new_tc_util; diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index e05cf92c..34447f91 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -2,12 +2,46 @@ #include #include +#include #include #include "../utils/exception.hpp" +#include "../utils/compatibility.hpp" namespace deep_gemm { +// Lazy loading all driver symbols +static void* get_driver_handle() { + static void* handle = nullptr; + if (handle == nullptr) { + handle = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL); + DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA driver `libcuda.so.1`"); + } + return handle; +} + +// Macro to define wrapper functions named `lazy_cu{API name}` +#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... 
args) -> decltype(name(args...)) { \ + using FuncType = decltype(&name); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(dlsym(get_driver_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA driver API"); \ + } \ + return func(std::forward(args)...); \ +} + +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); + #if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) // Use CUDA runtime API @@ -80,8 +114,8 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s LibraryHandle *library_opt = nullptr) { LibraryHandle library; KernelHandle kernel; - DG_CUDA_DRIVER_CHECK(cuModuleLoad(&library, cubin_path.c_str())); - DG_CUDA_DRIVER_CHECK(cuModuleGetFunction(&kernel, library, func_name.c_str())); + DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str())); if (library_opt != nullptr) *library_opt = library; @@ -89,7 +123,7 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s } static void unload_library(const LibraryHandle& library) { - const auto& error = cuModuleUnload(library); + const auto& error = lazy_cuModuleUnload(library); DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); } @@ -97,7 +131,7 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, const cudaStream_t& stream, const int& smem_size, const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { if (smem_size > 0) - DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, 
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); + DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); LaunchConfigHandle config; config.gridDimX = grid_dim.x; @@ -127,9 +161,8 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, template static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { void *ptr_args[] = { &args... }; - return cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); + return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); } - #endif } // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 5427d138..333ab77b 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -4,6 +4,7 @@ #include "../../utils/math.hpp" #include "../../utils/layout.hpp" +#include "../../utils/system.hpp" namespace deep_gemm { @@ -58,7 +59,8 @@ struct GemmConfig { // Templated configs GemmType gemm_type; KernelType kernel_type; - at::ScalarType ab_dtype, cd_dtype; + MmaKind mma_kind; + at::ScalarType a_dtype, b_dtype, cd_dtype; cute::UMMA::Major major_a; cute::UMMA::Major major_b; bool with_accumulation; @@ -101,9 +103,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const MulticastConfig& multicast_config) { - const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); + const int& ab_elem_size = static_cast(get_element_size(mma_kind)); const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); const int& load_block_m = 
ArchSpec::get_ab_load_block_m(multicast_config, block_m); @@ -121,7 +123,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne // SF shared memory const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = - ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype); + ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype); const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); // M-barriers and tensor memory pointers @@ -153,22 +155,41 @@ template static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, const int& m, const int& n, const int& k, const int& num_groups, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const at::ScalarType& a_dtype, const at::ScalarType& b_dtype, + const at::ScalarType& cd_dtype, const bool& with_accumulation, const int& num_sms, const int& max_block_n = 256, const bool& enable_overlap = false) { - DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); + const auto mma_kind = (a_dtype == torch::kBFloat16 ? 
MmaKind::BF16 : MmaKind::MXFP8FP4); + if (mma_kind == MmaKind::BF16) { + DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); + } else { + DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); + DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); + } DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); // Select M/N block sizes - auto block_ms = std::vector{64, 128, 256}; + auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m); if (gemm_type == GemmType::MGroupedContiguous) block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; - if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance - block_ms = std::vector{64, 128}; - const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype, max_block_n); + if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout) + block_ms = std::vector{64, 128}; // Exclude 256 for performance + auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); + + // Filter block_ns by max_block_n (sgl-release feature) + if (max_block_n < 256) { + std::erase_if(block_ns, [&](int bn) { return bn > max_block_n; }); + } + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + block_ms = std::vector{128}; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + block_ns = std::vector{128}; // K block size is selected in a fixed manner - const auto& block_k = 128 / static_cast(c10::elementSize(ab_dtype)); + const auto& block_k = (mma_kind == MmaKind::BF16 ? 
64 : 128); // Some util functions const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { @@ -189,7 +210,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k for (const auto& block_n: block_ns) { const int& num_waves = get_num_waves(block_m, block_n); const auto& last_util = get_last_wave_util(block_m, block_n); - if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k)) + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k)) continue; bool success = false; @@ -221,8 +242,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, false}; - const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + is_legal_on_a = false; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + is_legal_on_b = false; + const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; bool order[2] = {false, true}; if (best_block_m > best_block_n) @@ -238,15 +267,15 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k constexpr int smem_capacity = ArchSpec::smem_capacity; int best_num_stages = 0; SharedMemoryConfig best_smem_config; - for (int num_stages = 12; num_stages > 0; -- num_stages) { - if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) + for (int num_stages = 32; num_stages > 0; -- num_stages) { + if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, 
best_block_m, best_block_n, block_k)) continue; best_smem_config = get_smem_config(gemm_type, kernel_type, m, n, k, best_block_m, best_block_n, block_k, major_a, major_b, - ab_dtype, cd_dtype, + mma_kind, cd_dtype, num_stages, best_multicast_config); if (best_smem_config.smem_size <= smem_capacity) { best_num_stages = num_stages; @@ -258,7 +287,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Recompute the minimal number of SMs required // NOTES: less L2 cache usage and less GPU frequency drop int num_min_sms = num_sms; - if (ArchSpec::should_minimize_num_sms()) { + if (get_env("DG_JIT_MINIMIZE_NUM_SMS", 0)) { num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); DG_HOST_ASSERT(num_min_sms <= num_sms); @@ -267,7 +296,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const auto& config = GemmConfig { .gemm_type = gemm_type, .kernel_type = kernel_type, - .ab_dtype = ab_dtype, + .mma_kind = mma_kind, + .a_dtype = a_dtype, + .b_dtype = b_dtype, .cd_dtype = cd_dtype, .major_a = major_a, .major_b = major_b, @@ -289,21 +320,22 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Only SM100 BF16 kernels support tensor core control if (config.tc_util < 100) - DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16); + DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16); // Print configs for the first time if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, - ab_dtype, cd_dtype, with_accumulation, num_sms); + mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms); static std::set printed; if (printed.count(key) == 0) { printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, 
K: %d, groups: %d, " - "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, " + "A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, " "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, - static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), + static_cast(major_a), static_cast(major_b), static_cast(mma_kind), + c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype), static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, static_cast(best_multicast_config.is_multicast_on_a), diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index d0d16980..dd1e6024 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -12,7 +12,17 @@ namespace deep_gemm { struct SM100ArchSpec { static constexpr int smem_capacity = 232448; - static std::vector get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) { + static std::vector get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) { + std::vector candidates{128, 256}; + if ((kernel_type == KernelType::Kernel1D1D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) { + // NOTES: `block_m = 32/64` is smaller than `LAYOUT_AD_M`, should be careful in handling this + if (m <= 32) candidates.push_back(32); + if (m <= 64) candidates.push_back(64); + } + return candidates; + } + + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& 
cd_dtype) { // 16 is for better SM usage // Stride 32 is due to low-performance swizzle-16/32B std::vector candidates = {16}; @@ -43,42 +53,36 @@ struct SM100ArchSpec { } static std::pair get_sf_uttcp_aligned_block_sizes( - const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) { + const int& block_m, const int& block_n, const MmaKind& mma_kind) { constexpr int num_utccp_aligned_elems = 128; - DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0); - switch (ab_dtype) { - case torch::kBFloat16: return {0, 0}; - case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + switch (mma_kind) { + case MmaKind::BF16: return {0, 0}; + case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; default: DG_HOST_UNREACHABLE("Unknown dtype"); } } static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k) { - // TODO: consider more carefully for BF16 GEMMs - // 2SM BF16 UMMA does not support `N % 32 != 0` - if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0) - return false; - - // Layout A/D does not support `block_m == 64` and `block_n % 16 != 0` - if (block_m == 64 or block_n % 16 != 0) + // Layout A/D does not support `block_n % 16 != 0` + if (block_n % 16 != 0) return false; // Performance is lower with 1D1D and `block_m == 256` - if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128) + if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m > 128) return false; - // 1D2D kernels' maximum block N is 128 - // 1D2D kernels require more friendly block Ns - if 
(kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0)) + // For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck + if (k <= 256 and (block_n > 128 or block_m > 128)) return false; // Check tensor memory validity int sf_block_m = 0, sf_block_n = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; } if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) @@ -86,32 +90,29 @@ struct SM100ArchSpec { // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA - return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0; + return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0; } - static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const int& block_m, const int& block_n, const int& block_k) { return true; } - static bool should_minimize_num_sms() { - return false; - } - static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, - const int& m, const int& n, const int& block_m, const int& block_n, - const int& num_sms) { + const int& m, const int& n, const int& block_m, const int& block_n, + const int& num_sms) { // TODO: support other layouts return { false, - is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous), + is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == 
GemmType::Normal or gemm_type == GemmType::KGroupedContiguous + or (gemm_type == GemmType::Batched and num_groups <= 32)), }; } static ThreadConfig get_thread_config(const KernelType& kernel_type, const int& block_m, const int& block_n) { - return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D2D ? block_m : 128); + return ThreadConfig::sm100(128, 128); } static int get_smem_cd_size(const KernelType& kernel_type, @@ -119,19 +120,19 @@ struct SM100ArchSpec { const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { constexpr static int layout_ad_m = 128; - return (kernel_type != KernelType::Kernel1D2D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2; + return std::min(block_m, layout_ad_m) * swizzle_cd_mode * 2; } static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, const int& block_m, const int& block_n, const int& block_k, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { - if (ab_dtype == torch::kBFloat16) + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == MmaKind::BF16) return {0, 0}; int smem_sfa_per_stage = 0; int smem_sfb_per_stage = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); smem_sfa_per_stage = sf_block_m * 4; smem_sfb_per_stage = sf_block_n * 4; } else { @@ -149,7 +150,6 @@ struct SM100ArchSpec { static int get_barrier_smem_size(const int& num_stages) { // TODO: remove SF barriers for BF16 GEMMs // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers - // NOTES: 1D2D kernel will not use the with-SF full barriers // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages // NOTES: the last barrier is for tensor core utilization control return num_stages * 8 * 3 + 2 * 8 * 2 + 8; diff --git 
a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index d411206b..2fd2e9ec 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -10,12 +10,29 @@ namespace deep_gemm { struct SM90ArchSpec { static constexpr int smem_capacity = 232448; + + static std::vector get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) { + std::vector candidates{64, 128, 256}; + if ((kernel_type == KernelType::Kernel1D2D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) { + // NOTES: `block_m = 16/32` is smaller than MMA M size, should be careful in handling this + if (m <= 16) candidates.push_back(16); + if (m <= 32) candidates.push_back(32); + } + return candidates; + } + + static std::vector get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) { + int start = 16; - static std::vector get_block_n_candidates(const at::ScalarType& cd_dtype, const int& max_block_n) { - // Avoid bank conflicts for FP32 output - const auto& start = cd_dtype == torch::kFloat ? 
8 : 16; + // Avoid bank conflicts for 1D1D kernel FP32 output std::vector candidates; - for (int i = start; i <= max_block_n; i += 16) + if (kernel_type == KernelType::Kernel1D1D and cd_dtype == torch::kFloat) { + candidates.push_back(16); + start = 24; + } + + // Push the strided options + for (int i = start; i <= 256; i += 16) candidates.push_back(i); return candidates; } @@ -43,7 +60,8 @@ struct SM90ArchSpec { static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, + const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k) { // SM90 FP32 output does not support `block_m == 256` if (cd_dtype == at::kFloat and block_m == 256) @@ -58,32 +76,28 @@ struct SM90ArchSpec { return false; } + // When B is N Major, use swizzle 128B for better performance; only affects SM90 BF16 GEMM + if (major_b == cute::UMMA::Major::MN and block_n >= 128 and block_n % 64 != 0) + return false; + // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` // Or too many register spills if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) return false; - // Avoid bank conflicts for FP32 output - if (cd_dtype == torch::kFloat and block_n % 16 == 0) - return false; - // The block sizes cannot be too large (for enough registers), so at least one dim less than 128 return block_m <= 128 or block_n <= 128; } - static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const int& block_m, const int& block_n, const int& block_k) { // Unrolling both stages and `num_former_iters` will 
cause large code size - if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) + if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) return num_stages <= 4; return true; } - static bool should_minimize_num_sms() { - return true; - } - static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { @@ -91,6 +105,9 @@ struct SM90ArchSpec { if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4) return {false, false}; + if (gemm_type == GemmType::Batched) + return {false, false}; + return { is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even @@ -101,27 +118,27 @@ struct SM90ArchSpec { static ThreadConfig get_thread_config(const KernelType& kernel_type, const int& block_m, const int& block_n) { - return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128); + return ThreadConfig::sm90(128, (block_m <= 64 ? 
1 : 2) * 128); } static int get_smem_cd_size(const KernelType& kernel_type, const int& block_m, const int& block_n, const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { - return block_m * block_n * static_cast(c10::elementSize(cd_dtype)); + // NOTES: 1024 is for TMA swizzling alignment requirement + return align(block_m * block_n * static_cast(c10::elementSize(cd_dtype)), 1024); } static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, const int& block_m, const int& block_n, const int& block_k, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { - if (ab_dtype == torch::kBFloat16) + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == MmaKind::BF16) return {0, 0}; - int smem_sfa_per_stage = block_m * static_cast(sizeof(float)); + // NOTES: 128 is for 2D TMA alignment requirement + int smem_sfa_per_stage = align(block_m * static_cast(sizeof(float)), 128); int smem_sfb_per_stage = 0; - if (kernel_type == KernelType::Kernel1D1D) { - // NOTES: `128` is for 2D TMA alignment requirement + if (kernel_type == KernelType::Kernel1D1D) smem_sfb_per_stage = align(block_n * 4, 128); - } return {smem_sfa_per_stage, smem_sfb_per_stage}; } diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index d3b6b494..b3da4372 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -2,8 +2,9 @@ #include -#include "../../utils/math.hpp" #include "../heuristics/sm90.hpp" +#include "../../jit/handle.hpp" +#include "../../utils/math.hpp" #include "../../utils/system.hpp" #include "../../utils/exception.hpp" @@ -35,10 +36,12 @@ static std::string to_string(const cute::UMMA::Major& major) { static std::string to_string(const GemmType& type) { switch (type) { - case GemmType::Normal: return "GemmType::Normal"; - case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; - case GemmType::MGroupedMasked: return 
"GemmType::MGroupedMasked"; - case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + case GemmType::Batched: return "GemmType::Batched"; } DG_HOST_UNREACHABLE("Unknown GEMM type"); } @@ -48,6 +51,8 @@ static std::string to_string(const at::ScalarType& dtype) { case torch::kInt: return "int"; case torch::kFloat: return "float"; case torch::kBFloat16: return "cutlass::bfloat16_t"; + case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t"; + case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t"; default: DG_HOST_UNREACHABLE("Unsupported dtype"); } } @@ -62,6 +67,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; + case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; default: DG_HOST_UNREACHABLE("Unsupported dtype"); } } @@ -95,6 +101,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, if (swizzle_mode != 0) smem_inner_dim = swizzle_mode / elem_size; + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_inner_dim % 128 == 0); + CUtensorMap tensor_map; const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; @@ -105,7 +115,7 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, gmem_inner_dim, gmem_outer_dim, smem_inner_dim, 
smem_outer_dim, gmem_outer_stride, swizzle_mode, swizzle_base, elem_size); } - DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), @@ -114,14 +124,18 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, } static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, - const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2, - const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2, + int gmem_dim_0, int gmem_dim_1, int gmem_dim_2, + int smem_dim_0, int smem_dim_1, int smem_dim_2, const int& gmem_stride_0, const int& gmem_stride_1, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false) { const auto& elem_size = static_cast(t.element_size()); if (swizzle_mode != 0) - DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size); + smem_dim_0 = swizzle_mode / elem_size; + + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_dim_0 % 128 == 0); CUtensorMap tensor_map; const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; @@ -133,7 +147,7 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2, gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size); } - DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), 3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), @@ -201,7 +215,7 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, 
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, const torch::Tensor& t, int shape_mn, int shape_k, - const int& block_mn, const int& block_k, + const int& block_mn, const int& gran_k, const int& num_groups, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false) { @@ -212,7 +226,7 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); return make_tma_2d_desc(t, - shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, + shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, block_mn, 1, shape_mn, swizzle_mode, swizzle_base, diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 9229f1e8..6b652695 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -23,8 +23,7 @@ class SM100BF16GemmRuntime final: public LaunchRuntime { void* grouped_layout; CUtensorMap tensor_map_a; CUtensorMap tensor_map_b; - CUtensorMap tensor_map_c; - CUtensorMap tensor_map_d; + CUtensorMap tensor_map_cd; }; static std::string generate_impl(const Args& args) { @@ -67,7 +66,7 @@ static void __instantiate_kernel() {{ DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, - args.tensor_map_c, args.tensor_map_d)); + args.tensor_map_cd)); } }; @@ -78,14 +77,13 @@ static void sm100_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + 
d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); - const auto& cd = c.value_or(d); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, @@ -96,30 +94,15 @@ static void sm100_bf16_gemm(const torch::Tensor& a, config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(cd.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); - - // Duplicate the accumulator if necessary - if (c.has_value()) { - if (c->data_ptr() == d.data_ptr()) { - DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); - } else { - // ReSharper disable once CppExpressionWithoutSideEffects - d.copy_(c.value()); - } - } + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, @@ -129,12 +112,278 @@ static void sm100_bf16_gemm(const torch::Tensor& a, .grouped_layout = nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, - .tensor_map_c = tensor_map_c, - .tensor_map_d = tensor_map_d + .tensor_map_cd = tensor_map_cd }; const auto& code = SM100BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm100_bf16_gemm", code); 
MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args)); } +static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + + const auto& config = get_best_config( + gemm_type, KernelType::KernelNoSF, + // NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + 
config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::KernelNoSF, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + 
SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto& k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::KernelNoSF, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + 
config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch kernel + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& 
load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = b, .n = d, .k = r, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + 
config.smem_config.swizzle_a_mode); + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = b, .n = r, .k = d, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 30dcb149..188e1dd8 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -13,10 +13,11 @@ namespace deep_gemm { -class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime { +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { public: struct Args { int m, n, k, num_groups; + int gran_k_a, gran_k_b; const std::string& compiled_dims; const std::optional& epilogue_type; @@ -28,8 +29,7 @@ class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime(&sm100_fp8_gemm_1d1d_impl< + {}, {}, {}, {}, 
{}, {}, {}, {}, {}, {}, @@ -49,12 +50,14 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {} >); }}; )", to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + args.gran_k_a, args.gran_k_b, get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, args.num_groups, @@ -63,7 +66,8 @@ static void __instantiate_kernel() {{ args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, + to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype), get_default_epilogue_type(args.epilogue_type)); } @@ -73,23 +77,24 @@ static void __instantiate_kernel() {{ args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_sfa, args.tensor_map_sfb, - args.tensor_map_c, args.tensor_map_d)); + args.tensor_map_cd)); } }; -static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const std::optional& c, - const torch::Tensor& d, - const int& m, const int& n, const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims, - const std::optional& epilogue_type = std::nullopt) { - const auto& aligned_k = align(k, 128); +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + 
const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); const auto& cd = c.value_or(d); @@ -103,35 +108,22 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(cd.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + config.block_m, gran_k_a, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, 1, 0); - - // Duplicate the accumulator if necessary - if (c.has_value()) { - if (c->data_ptr() == d.data_ptr()) { - DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == 
d.strides()); - } else { - // ReSharper disable once CppExpressionWithoutSideEffects - d.copy_(c.value()); - } - } + config.block_n, gran_k_b, 1, 0); // Launch - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = 1, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, .gemm_config = config, @@ -143,26 +135,35 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_c = tensor_map_c, - .tensor_map_d = tensor_map_d + .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); - MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } -static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const torch::Tensor& d, - const torch::Tensor& m_indices, - const int& num_groups, const int& m, const int& n, const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { - const auto& aligned_k = align(k, 128); +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const 
cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + const auto& config = get_best_config( - GemmType::MGroupedContiguous, KernelType::Kernel1D1D, - m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + gemm_type, KernelType::Kernel1D1D, + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Create tensor descriptors @@ -176,52 +177,54 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - config.smem_config.swizzle_cd_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + config.block_m, gran_k_a, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, num_groups, 0); 
+ config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = m_indices.data_ptr(), + .grouped_layout = grouped_layout.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_c = tensor_map_d, - .tensor_map_d = tensor_map_d + .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); - MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } -static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const int& num_groups, const int& m, const int& n, const int& k, - const int& expected_m, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { - const auto& aligned_k = align(k, 128); +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const 
int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D1D, expected_m, n, k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Create tensor descriptors @@ -235,20 +238,22 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, - config.smem_config.swizzle_cd_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, num_groups, 0); + config.block_m, gran_k_a, num_groups, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, num_groups, 0); + config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -260,22 +265,21 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, 
const t .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_c = tensor_map_d, - .tensor_map_d = tensor_map_d + .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); - MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } -static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const std::optional& c, - const torch::Tensor& d, - const int& m, const int& n, - const std::vector& ks, const torch::Tensor& ks_tensor, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); int sum_k = 0, sum_sf_k = 0; @@ -290,11 +294,11 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::Kernel1D1D, m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Create tensor descriptors - const 
auto& cd = c.value_or(d); const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, @@ -305,31 +309,22 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& config.block_k, static_cast(b.stride(0)), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(1)), num_groups, - config.smem_config.swizzle_cd_mode); - const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(cd.stride(1)), num_groups, - config.smem_config.swizzle_cd_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, config.block_m, config.block_k, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, config.block_n, config.block_k, 1, 0); - // Duplicate the accumulator if necessary - if (c.has_value()) { - DG_HOST_ASSERT(c->data_ptr() == d.data_ptr()); - DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); - } - // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .gran_k_a = 128, + .gran_k_b = 128, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -341,12 +336,79 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& .tensor_map_b = 
tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_c = tensor_map_c, - .tensor_map_d = tensor_map_d + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); +} + +static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D1D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m); + const auto& [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.block_k, load_block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size, + inner_block_a, outer_block_a, 1, + a.stride(major_a == cute::UMMA::Major::K ? 1 : 2), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n); + const auto& [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.block_k, load_block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size, + inner_block_b, outer_block_b, 1, + b.stride(major_b == cute::UMMA::Major::K ? 
1 : 2), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, batch_size, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, batch_size, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = batch_size, + .gran_k_a = 128, + .gran_k_b = 128, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); - MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args)); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + MAYBE_LAUNCH(SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args)); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..4f3ce5b1 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include 
"../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_mma_threads, num_cast_and_reduce_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_mma_threads, args.num_cast_and_reduce_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_mma_threads = 128; + constexpr int num_cast_and_reduce_threads = 128; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + DG_HOST_ASSERT(n <= 128 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& 
swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = (num_stages * 4 + 1) * 8; + const int smem_tmem_ptr = 4; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers + smem_tmem_ptr; + + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + // Launch + const SM100BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_mma_threads = 
num_mma_threads, + .num_cast_and_reduce_threads = num_cast_and_reduce_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); + SM100BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index ca5e8b25..31b5c9a6 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -21,7 +21,7 @@ class SM90BF16GemmRuntime final: public LaunchRuntime { void *grouped_layout; CUtensorMap tensor_map_a; CUtensorMap tensor_map_b; - CUtensorMap tensor_map_d; + CUtensorMap tensor_map_cd; }; static std::string generate_impl(const Args& args) { @@ -32,26 +32,31 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm90_bf16_gemm_impl< + {}, {}, {}, {}, {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}, {}, + {}, {}, {}, - {}, {}, {} + {} >); }}; )", // TODO: add CD dtype + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.thread_config.num_tma_threads, 
args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype)); } @@ -61,7 +66,7 @@ static void __instantiate_kernel() {{ args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, - args.tensor_map_d)); + args.tensor_map_cd)); } }; @@ -72,14 +77,11 @@ static void sm90_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - DG_HOST_ASSERT(not c.has_value()); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - - const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -93,7 +95,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, SM90ArchSpec::get_cd_store_block_m(config.block_m), SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), 1, @@ -101,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, @@ -111,7 +113,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, .grouped_layout = 
nullptr, .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, + .tensor_map_cd = tensor_map_cd, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_gemm", code); @@ -126,13 +128,14 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); DG_HOST_ASSERT(k % 64 == 0); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -146,7 +149,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, SM90ArchSpec::get_cd_store_block_m(config.block_m), SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), 1, @@ -164,7 +167,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, .grouped_layout = m_indices.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, + .tensor_map_cd = tensor_map_cd, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); @@ -186,7 +189,8 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::KernelNoSF, 
expected_m, n, k, num_groups, major_a, major_b, - torch::kBFloat16, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -200,7 +204,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, SM90ArchSpec::get_cd_store_block_m(config.block_m), SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), num_groups, @@ -218,11 +222,167 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, .grouped_layout = masked_m.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, - .tensor_map_d = tensor_map_d, + .tensor_map_cd = tensor_map_cd, }; const auto& code = SM90BF16GemmRuntime::generate(args); const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args)); } +static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto& k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::KernelNoSF, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + 
a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch kernel + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, + 
device_runtime->get_num_sms()); + + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = b, .n = d, .k = r, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto& config = get_best_config( + GemmType::Batched, KernelType::KernelNoSF, + b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, + tensor_a.scalar_type(), tensor_b.scalar_type(), 
+ tensor_d.scalar_type(), false, + device_runtime->get_num_sms()); + + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.smem_config.swizzle_a_mode); + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.smem_config.swizzle_b_mode); + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.smem_config.swizzle_cd_mode); + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = b, .n = r, .k = d, + .num_groups = h, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index 4b778acf..e21fed31 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -27,7 +27,7 @@ class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime CUtensorMap tensor_map_b_base; CUtensorMap 
tensor_map_sfa; CUtensorMap tensor_map_sfb; - CUtensorMap tensor_map_d; + CUtensorMap tensor_map_cd; }; static std::string generate_impl(const Args& args) { @@ -41,6 +41,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}, @@ -52,6 +53,7 @@ static void __instantiate_kernel() {{ get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.num_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, @@ -67,7 +69,7 @@ static void __instantiate_kernel() {{ args.m, args.n, args.k, args.tensor_map_a_base, args.tensor_map_b_base, args.tensor_map_sfa, args.tensor_map_sfb, - args.tensor_map_d)); + args.tensor_map_cd)); } }; @@ -84,7 +86,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -103,11 +106,11 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, config.block_m, config.block_k, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, config.block_n, config.block_k, 1, 0); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m, true), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), 1, - 0); + const auto& 
tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + 0); // Launch const SM90FP8Gemm1D1DRuntime::Args& args = { @@ -126,7 +129,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_b_base = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_d = tensor_map_d, + .tensor_map_cd = tensor_map_cd, }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); @@ -134,7 +137,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, MAYBE_LAUNCH(SM90FP8Gemm1D1DRuntime::launch(runtime, args)); } -static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, +static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const std::optional& c, const torch::Tensor& d, @@ -152,7 +155,8 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::Kernel1D1D, m, n, max_k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -178,11 +182,11 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te config.block_m, config.block_k, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, config.block_n, config.block_k, 1, 0); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, - SM90ArchSpec::get_cd_store_block_m(config.block_m, true), - SM90ArchSpec::get_cd_store_block_n(config.block_n), - static_cast(d.stride(-2)), num_groups, 
- config.smem_config.swizzle_cd_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); // Launch const SM90FP8Gemm1D1DRuntime::Args& args = { @@ -201,7 +205,7 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te .tensor_map_b_base = tensor_map_b_base, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_d = tensor_map_d, + .tensor_map_cd = tensor_map_cd, }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index f08bce8f..4879dca6 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -15,6 +15,7 @@ namespace deep_gemm { class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { public: struct Args { + cute::UMMA::Major major_sfb; int m, n, k, num_groups; const std::string& compiled_dims; const std::optional& epilogue_type; @@ -37,10 +38,11 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< - {}, {}, {}, {}, {}, {}, {}, {}, + {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}, {}, {}, @@ -50,10 +52,11 @@ static void __instantiate_kernel() {{ }}; )", // TODO: add CD dtype + to_string(args.major_sfb), get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, - args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, 
args.gemm_config.smem_config.swizzle_cd_mode, args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, @@ -77,17 +80,17 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const std::optional& c, const torch::Tensor& d, const int& m, const int& n, const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, const std::string& compiled_dims, const std::optional& epilogue_type = std::nullopt) { DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D2D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -113,7 +116,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, @@ -139,16 +143,16 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons const torch::Tensor& d, const torch::Tensor& m_indices, const int& num_groups, const int& m, const int& n, const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, const std::string& 
compiled_dims) { DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D2D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -174,7 +178,8 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, @@ -206,14 +211,14 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co const int& max_block_n, const bool& enable_overlap, const c10::optional& signal) { - const auto& aligned_k = align(k, 128); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + torch::kFloat8_e4m3fn, torch::kFloat8_e4m3fn, + d.scalar_type(), false, device_runtime->get_num_sms(), max_block_n, enable_overlap); // Requires no TMA splits @@ -239,7 +244,8 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, @@ -263,4 +269,71 @@ static std::optional> sm90_m_grouped_fp8_gemm_masked_1d2d(co std::nullopt; } +static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + 
const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D2D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, + config.block_k, load_block_m, 1, + a.stride(1), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, + config.block_k, load_block_n, 1, + b.stride(1), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, batch_size, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + 
.num_groups = batch_size, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..aeea2623 --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_math_threads, num_tma_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, 
args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_math_threads, args.num_tma_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_math_threads = 128; + constexpr int num_tma_threads = 128; + constexpr int num_threads = num_math_threads + num_tma_threads; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + // Only support small N for now + DG_HOST_ASSERT(n <= 32 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? 
make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = num_stages * 2 * 8; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + smem_size = SM90ArchSpec::smem_capacity; + + // Launch + const SM90BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_math_threads = num_math_threads, + .num_tma_threads = num_tma_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); + SM90BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_clean_logits.hpp 
b/csrc/jit_kernels/impls/smxx_clean_logits.hpp index cdb472d2..fdb91a03 100644 --- a/csrc/jit_kernels/impls/smxx_clean_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -13,7 +13,7 @@ class SMXXCleanLogitsRuntime final: public LaunchRuntime int next_n; int seq_len; int seq_len_kv; - uint64_t stride_kv; + uint64_t stride_logits; int* cu_seq_len_k_start; int* cu_seq_len_k_end; @@ -41,7 +41,7 @@ static void __instantiate_kernel() {{ static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.seq_len, args.seq_len_kv, static_cast(args.stride_kv), + args.seq_len, args.seq_len_kv, static_cast(args.stride_logits), args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits )); } @@ -52,7 +52,7 @@ static void smxx_clean_logits(const torch::Tensor& logits, const torch::Tensor& cu_seq_len_k_end, const int& next_n, const int& seq_len, const int& seq_len_kv, - const uint64_t &stride_kv) { + const uint64_t &stride_logits) { const int block_kv = 8192; const int num_warps = 8; const int smem_size = block_kv * sizeof(float); @@ -62,7 +62,7 @@ static void smxx_clean_logits(const torch::Tensor& logits, .next_n = next_n, .seq_len = seq_len, .seq_len_kv = seq_len_kv, - .stride_kv = stride_kv, + .stride_logits = stride_logits, .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? 
cu_seq_len_k_start.value().data_ptr() : nullptr, .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), .logits = logits.data_ptr(), diff --git a/csrc/jit_kernels/impls/smxx_cublaslt.hpp b/csrc/jit_kernels/impls/smxx_cublaslt.hpp index 08816073..dc20e334 100644 --- a/csrc/jit_kernels/impls/smxx_cublaslt.hpp +++ b/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -3,6 +3,11 @@ #include #include #include +#include + +#include "../../jit/device_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/compatibility.hpp" namespace deep_gemm { @@ -32,8 +37,6 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, const bool& accumulate) { cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; cudaDataType_t scale_type = CUDA_R_32F; - const int& math_sms = device_runtime->get_num_sms(); - bool fp8_fast_accumulate = false; // Operation description cublasLtMatmulDesc_t desc; @@ -41,9 +44,17 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + +#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + const int& math_sms = device_runtime->get_num_sms(); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); +#endif + +#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + bool fp8_fast_accumulate = false; if (a.scalar_type() == torch::kFloat8_e4m3fn) DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate))); +#endif // Get cuBLASLt handle, workspace, and stream const auto& handle = device_runtime->get_cublaslt_handle(); diff --git 
a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp index f5856399..f3b82e3d 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -9,16 +9,18 @@ namespace deep_gemm { -class SM90FP8MQALogitsRuntime final: public LaunchRuntime { +class SMXXFP8MQALogitsRuntime final: public LaunchRuntime { public: struct Args { int seq_len; int seq_len_kv; - int stride_kv; + int max_seqlen_k; + int stride_logits; int num_heads, head_dim; + bool is_compressed_logits; + int num_q_stages; int num_kv_stages; - int block_q; int block_kv; @@ -52,6 +54,7 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< {}, {}, + {}, {}, {}, {}, {}, {}, {} @@ -59,6 +62,7 @@ static void __instantiate_kernel() {{ }}; )", arch, arch, args.num_heads, args.head_dim, + args.is_compressed_logits, args.block_q, args.block_kv, args.num_q_stages, args.num_kv_stages, args.num_specialized_threads, args.num_math_threads); @@ -66,7 +70,8 @@ static void __instantiate_kernel() {{ static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, - args.seq_len, args.seq_len_kv, static_cast(args.stride_kv), + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, static_cast(args.stride_logits), args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits, args.tensor_map_q, args.tensor_map_kv, @@ -81,18 +86,22 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, const torch::Tensor& cu_seq_len_k_start, const torch::Tensor& cu_seq_len_k_end, const torch::Tensor& logits, - const int& seq_len, const int& seq_len_kv, const int& stride_kv, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, const int& num_heads, const int& head_dim, const int& seq_len_alignment) { constexpr int block_qh = 128; constexpr int block_kv = 256; 
constexpr int num_specialized_threads = 128; - constexpr int num_math_threads = 512; constexpr int num_q_stages = 3, num_kv_stages = 3; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); const int block_q = block_qh / num_heads; DG_HOST_ASSERT(block_qh % num_heads == 0); DG_HOST_ASSERT(seq_len_alignment % block_q == 0); + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + // Construct TMAs DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, @@ -120,13 +129,16 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; smem_size += 4; DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); // Launch - const SM90FP8MQALogitsRuntime::Args& args = { + const SMXXFP8MQALogitsRuntime::Args& args = { .seq_len = seq_len, .seq_len_kv = seq_len_kv, - .stride_kv = stride_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, .num_q_stages = num_q_stages, .num_kv_stages = num_kv_stages, .block_q = block_q, @@ -144,9 +156,9 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, num_specialized_threads + num_math_threads, smem_size) }; - const auto& code = SM90FP8MQALogitsRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_mqa_logits", code); - SM90FP8MQALogitsRuntime::launch(runtime, args); + const auto& code = SMXXFP8MQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_fp8_mqa_logits", code); + SMXXFP8MQALogitsRuntime::launch(runtime, args); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp 
b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp index 38bbfb9d..1240aad8 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -14,8 +14,10 @@ class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime(sizeof(int)); DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); @@ -66,6 +72,8 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, .split_kv = split_kv, .num_sms = num_sms, .batch_size = batch_size, + .next_n = next_n, + .is_context_lens_2d = is_context_lens_2d, .context_lens = context_lens.data_ptr(), .schedule_metadata = schedule_metadata.data_ptr(), .launch_args = LaunchArgs(1, num_threads, smem_size) @@ -83,6 +91,7 @@ class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime(&sm{}_fp8_paged_mqa_logits< {}, {}, {}, {}, + {}, {}, {}, {}, {}, {} @@ -129,6 +139,7 @@ static void __instantiate_kernel() {{ )", arch, arch, args.next_n, args.num_heads, args.head_dim, args.block_kv, + args.is_context_lens_2d, args.num_q_stages, args.num_kv_stages, args.split_kv, args.num_specialized_threads, args.num_math_threads); @@ -158,17 +169,18 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, const int& batch_size, const int& next_n, const int& num_heads, const int& head_dim, const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, const int& kv_cache_stride_bytes, const int& logits_stride, const int& block_table_stride, const int& num_sms, - const int& num_math_warp_groups) { + const int& split_kv) { const int num_specialized_threads = 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; const int num_math_threads = num_math_warp_groups * 128; - const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 
128 : 0; - const int num_q_stages = 3, num_kv_stages = 3; - const int split_kv = num_math_warp_groups * block_kv; - DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0); + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); // Construct TMAs DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); @@ -184,23 +196,39 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, next_n * num_heads, 1, next_n * num_heads, 0); // Calculate shared memory size - const int swizzle_alignment = head_dim * 8; - - const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); - const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); - const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); - - const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); - const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); - const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); - - // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 - const int smem_umma_barriers = num_math_warp_groups * 2 * 8; - const int smem_tmem_ptr = 4; - - const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; - DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); - DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = 
head_dim * 8; + + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + } else { + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= 
SM100ArchSpec::smem_capacity); + } // Launch const SMXXFP8PagedMQALogitsRuntime::Args& args = { @@ -209,6 +237,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .num_heads = num_heads, .head_dim = head_dim, .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, .block_table_stride = block_table_stride, .logits_stride = logits_stride, .num_q_stages = num_q_stages, @@ -225,11 +254,11 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .num_specialized_threads = num_specialized_threads, .num_math_threads = num_math_threads, .launch_args = LaunchArgs(num_sms, - num_specialized_threads + num_math_threads + num_extra_threads, + num_specialized_threads + num_math_threads, smem_size) }; const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); - const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code); + const auto& runtime = compiler->build("smxx_fp8_paged_mqa_logits", code); SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); } diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp index 9a312ff0..c7d4ff4e 100644 --- a/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -1,6 +1,7 @@ #pragma once #include "../../jit/kernel_runtime.hpp" +#include "../../jit/compiler.hpp" #include "../../utils/exception.hpp" #include "../../utils/format.hpp" #include "../../utils/math.hpp" diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index 14f3b15d..6376ea21 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -9,6 +9,7 @@ #include "apis/attention.hpp" #include "apis/einsum.hpp" +#include "apis/hyperconnection.hpp" #include "apis/gemm.hpp" #include "apis/layout.hpp" #include "apis/runtime.hpp" @@ -524,9 +525,20 @@ TORCH_LIBRARY(deep_gemm, m) { /* * einsum */ - m.def(R"(einsum(str expr, Tensor a, Tensor b, Tensor d, Tensor? 
c=None) -> ())"); - m.impl("einsum", torch::kCUDA, [](const std::string& expr, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const c10::optional& c) { - deep_gemm::einsum::einsum(expr, a, b, d, c); + m.def(R"(einsum(str expr, Tensor a, Tensor b, Tensor d, Tensor? c=None, bool use_cublaslt=False) -> ())"); + m.impl("einsum", torch::kCUDA, [](const std::string& expr, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, const c10::optional& c, bool use_cublaslt) { + deep_gemm::einsum::einsum(expr, a, b, d, c, use_cublaslt); + }); + + m.def(R"(fp8_einsum(str expr, Any a, Any b, Tensor d, Tensor? c=None, int[] recipe=[1, 128, 128]) -> ())"); + m.impl("fp8_einsum", torch::kCUDA, [](const std::string& expr, + const c10::IValue& a_input, const c10::IValue& b_input, + const torch::Tensor& d, + const c10::optional& c, + c10::IntArrayRef recipe) { + auto [a_val, a_scale] = parse_tensor_or_tuple(a_input); + auto [b_val, b_scale] = parse_tensor_or_tuple(b_input); + deep_gemm::einsum::fp8_einsum(expr, {a_val, a_scale}, {b_val, b_scale}, d, c, to_recipe_tuple_default(recipe)); }); } diff --git a/csrc/utils/compatibility.hpp b/csrc/utils/compatibility.hpp new file mode 100644 index 00000000..9e2d6720 --- /dev/null +++ b/csrc/utils/compatibility.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1 +#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1)) + +// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1 +#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) + +// `cublasGetErrorString` is supported since CUDA Runtime API 11.4.2 +#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042) + +// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8 +#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION 
>= 11080) \ No newline at end of file diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index b61bc09d..2aa27066 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -5,6 +5,8 @@ #include #include +#include "compatibility.hpp" + namespace deep_gemm { class DGException final : public std::exception { @@ -54,7 +56,7 @@ do { \ if (e != CUDA_SUCCESS) { \ std::stringstream ss; \ const char *name, *info; \ - cuGetErrorName(e, &name), cuGetErrorString(e, &info); \ + lazy_cuGetErrorName(e, &name), lazy_cuGetErrorString(e, &info); \ ss << static_cast(e) << " (" << name << ", " << info << ")"; \ throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ } \ @@ -74,6 +76,25 @@ do { \ #endif #ifndef DG_CUBLASLT_CHECK + +#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE +inline const char* cublasGetStatusString(cublasStatus_t status) { + switch(status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS error"; + } +} +#endif + #define DG_CUBLASLT_CHECK(cmd) \ do { \ const auto& e = (cmd); \ diff --git a/csrc/utils/hash.hpp b/csrc/utils/hash.hpp index fad1231f..6bdc3f2c 100644 --- a/csrc/utils/hash.hpp +++ b/csrc/utils/hash.hpp @@ -4,7 +4,7 @@ namespace deep_gemm { -static uint64_t fnv1a(const std::string& data, const uint64_t& seed) { 
+static uint64_t fnv1a(const std::vector& data, const uint64_t& seed) { uint64_t h = seed; const uint64_t& prime = 0x100000001b3ull; for (const char& c: data) { @@ -14,7 +14,7 @@ static uint64_t fnv1a(const std::string& data, const uint64_t& seed) { return h; } -static std::string get_hex_digest(const std::string& data) { +static std::string get_hex_digest(const std::vector& data) { const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); @@ -32,4 +32,8 @@ static std::string get_hex_digest(const std::string& data) { return oss.str(); } +static std::string get_hex_digest(const std::string& data) { + return get_hex_digest(std::vector{data.begin(), data.end()}); +} + } // namespace deep_gemm diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp index 53c79dff..bde41711 100644 --- a/csrc/utils/layout.hpp +++ b/csrc/utils/layout.hpp @@ -35,22 +35,41 @@ static bool fp8_requires_k_major() { // Tensor utils template static auto get_shape(const torch::Tensor& t) { + DG_HOST_ASSERT(t.dim() == N); return [&t] (std::index_sequence) { return std::make_tuple(static_cast(t.sizes()[Is])...); }(std::make_index_sequence()); } +static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [mn, k] = get_shape<2>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(mn, k); +} + +static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [num_groups, mn, k] = get_shape<3>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? 
(k *= 2) : (mn *= 2); + } + return std::make_tuple(num_groups, mn, k); +} + // Recipe static std::tuple get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); return {1, 128, 128}; } else if (arch_major == 10) { DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt); return sfb_dtype == torch::kFloat ? - std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels + std::make_tuple(1, 128, 128): // Legacy format std::make_tuple(1, 1, 128); // 1D1D kernels } DG_HOST_UNREACHABLE("Unknown recipe"); @@ -62,14 +81,14 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf, const int& gran_mn, const int& gran_k, const std::optional& num_groups, const bool& tma_stride_check = false, - const bool& contiguous_check = false, + const bool& sm90_sfb_check = false, const std::optional& type_check = std::nullopt) { // Type check if (type_check.has_value()) DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); // Always do shape checks - const auto& sf_dtype = sf.scalar_type(); + const auto sf_dtype = sf.scalar_type(); DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); if (num_groups.has_value()) @@ -81,13 +100,18 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf, if (tma_stride_check) { if (num_groups.has_value()) DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); - DG_HOST_ASSERT(sf.stride(-2) == 1); + // Check contiguity in the MN direction + DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1); DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); } - // Hopper SFB must be contiguous - if (contiguous_check) - DG_HOST_ASSERT(sf.is_contiguous()); + // SM90 SFB must be 
contiguous, or contiguous after transposing the last two dimensions + if (sm90_sfb_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1)); + DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or + (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1)); + } return sf; } diff --git a/csrc/utils/math.hpp b/csrc/utils/math.hpp index ae75d10c..f70ecf0a 100644 --- a/csrc/utils/math.hpp +++ b/csrc/utils/math.hpp @@ -6,6 +6,9 @@ namespace deep_gemm { +// TODO: Use `torch::kFloat4_e2m1fn_x2` +constexpr auto kPackedFP4 = torch::kUInt8; + template static T ceil_div(const T& a, const T& b) { return (a + b - 1) / b; diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index df01daf1..a311555f 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -15,6 +15,11 @@ from . import deep_gemm_cpp # noqa: F401 # Registers ops into torch.ops without touching CUDA +# Legacy Triton kernels for A100 +try: + from . import legacy +except Exception as e: + print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}') def _find_cuda_home() -> str: cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') @@ -76,6 +81,7 @@ def _fn(*args, **kwargs): bf16_gemm_tt = _wrap_op('bf16_gemm_tt') m_grouped_bf16_gemm_nt_contiguous = _wrap_op('m_grouped_bf16_gemm_nt_contiguous') m_grouped_bf16_gemm_nt_masked = _wrap_op('m_grouped_bf16_gemm_nt_masked') +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked # cuBLASLt GEMMs cublaslt_gemm_nt = _wrap_op('cublaslt_gemm_nt') @@ -89,8 +95,9 @@ def _fn(*args, **kwargs): get_paged_mqa_logits_metadata = _wrap_op('get_paged_mqa_logits_metadata') fp8_paged_mqa_logits = _wrap_op('fp8_paged_mqa_logits') -# Einsum kernel +# Einsum kernels einsum = _wrap_op('einsum') +fp8_einsum = _wrap_op('fp8_einsum') # Layout kernels transform_sf_into_required_layout = _wrap_op('transform_sf_into_required_layout') @@ -120,7 +127,7 @@ def _verify_ops_loaded(): 
'get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor', 'fp8_gemm_nt_skip_head_mid', 'fp8_mqa_logits', 'get_paged_mqa_logits_metadata', 'fp8_paged_mqa_logits', - 'einsum', + 'einsum', 'fp8_einsum', 'cublaslt_gemm_nt', 'cublaslt_gemm_nn', 'cublaslt_gemm_tn', 'cublaslt_gemm_tt', ] @@ -137,3 +144,5 @@ def _verify_ops_loaded(): if __debug__: _verify_ops_loaded() + +__version__ = '2.3.0' diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 237f688c..f93b96ee 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -5,7 +5,7 @@ namespace deep_gemm { -enum class KGroupedIndexType { +enum class IndexType { MN, K, SF_K, @@ -51,6 +51,8 @@ struct Scheduler { uint32_t current_group_idx = 0; // Only used for masked layout uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; // Only used for k-grouped layout uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; uint32_t next_group_idx, next_shape_k; @@ -70,14 +72,18 @@ struct Scheduler { num_m_blocks = ceil_div(shape_m, BLOCK_M); num_n_blocks = ceil_div(shape_n, BLOCK_N); current_shape_k = shape_k; - if constexpr (kGemmType == GemmType::Normal) { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { num_blocks = num_m_blocks * num_n_blocks; - } else if (kGemmType == GemmType::MGroupedContiguous) { + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { num_blocks = num_m_blocks * num_n_blocks; this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::MGroupedMasked) { + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::KGroupedContiguous) { + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) 
{ + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { this->grouped_layout = grouped_layout; get_next_k_group(current_group_idx, current_shape_k); next_group_idx = current_group_idx + 1; @@ -123,7 +129,7 @@ struct Scheduler { } } - template + template __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { if constexpr (kGemmType == GemmType::Normal) { @@ -131,20 +137,24 @@ struct Scheduler { } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; return offset * shape_dim + block_idx * block_size; - } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { const auto offset = kWithGroupOffset ? current_group_idx : 0; return offset * shape_dim + block_idx * block_size; } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { auto offset = 0; if constexpr (kWithGroupOffset) { - if constexpr (kIndexType == KGroupedIndexType::MN) + if constexpr (kIndexType == IndexType::MN) offset = current_group_idx * shape_dim; - else if constexpr (kIndexType == KGroupedIndexType::K) + else if constexpr (kIndexType == IndexType::K) offset = current_k_cumsum; - else if constexpr (kIndexType == KGroupedIndexType::SF_K) + else if constexpr (kIndexType == IndexType::SF_K) offset = current_sf_k_cumsum; } return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? 
current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; } } @@ -168,7 +178,29 @@ struct Scheduler { } get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); - } else if (kGemmType == GemmType::KGroupedContiguous) { + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { while (true) { // End of the task if (current_group_idx == kNumGroups) @@ -189,6 +221,19 @@ struct Scheduler { } get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } } else { if (next_block_idx >= num_blocks) return false; @@ -207,7 +252,8 @@ struct 
Scheduler { __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { if (num_blocks_in_group == 1) return false; - if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::KGroupedContiguous) { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { return true; } else { DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); @@ -224,12 +270,15 @@ struct Scheduler { // For SM90 only // ReSharper disable once CppNotAllPathsReturnValue __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { - if constexpr (kGemmType == GemmType::Normal) { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { return true; } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; } else if constexpr (kGemmType == GemmType::MGroupedMasked) { return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); } } }; diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh index db1a364a..537cbe08 100644 --- a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -3,39 +3,13 @@ #include #include #include +#include #include +#include namespace deep_gemm::sm100 { -template -constexpr uint32_t get_inner_block_atom_size() { - return kSwizzleMode == 0 ? 
BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); -} - -template -__device__ __forceinline__ void -tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, - dtype_t* smem_ptr, const uint32_t& inner_idx, const int32_t& outer_idx) { - DG_STATIC_ASSERT(1 <= kNumMulticast and kNumMulticast <= 2, "Invalid multicast config"); - DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == - static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); - - // 2-CTA function will send signals to the leader CTA only - const auto copy_func = kNumMulticast == 1 ? cute::SM90_TMA_LOAD_2D::copy : cute::SM100_TMA_2SM_LOAD_2D::copy; - - // Issue multiple TMAs - constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); - #pragma unroll - for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { - copy_func(desc_ptr, reinterpret_cast(barrier_ptr), - static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), - smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, inner_idx + i * BLOCK_INNER_ATOM, outer_idx); - } -} - __device__ __forceinline__ cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, uint32_t stride_byte_offset, uint32_t leading_byte_offset) { @@ -123,7 +97,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id const auto& layout_type = to_umma_layout_type(); const auto& num_non_contiguous = 128 / get_atom_base(layout_type); if constexpr (kMajorMode == cute::UMMA::Major::K) { - // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); // Atom size: 8 x `kSwizzleMode` (in bytes, on K) @@ -157,8 +132,8 @@ cute::UMMA::SmemDescriptor 
make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id } __device__ __forceinline__ -uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) { - desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id; +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; return static_cast(static_cast(desc)) << 32; } @@ -180,6 +155,20 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() { asm volatile("tcgen05.fence::after_thread_sync;"); } +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + // UMMA versions with relaxed assertions struct SM100_MMA_F16BF16_SS { __device__ static void @@ -257,4 +246,21 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS { } }; +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + } // namespace `deep_gemm::sm100` diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh 
b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index d910a2df..0874b675 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -1,11 +1,14 @@ #pragma once -#include #include +#include #include #include +#include #include +#include +#include namespace deep_gemm::sm90 { @@ -93,43 +96,53 @@ struct BF16MMA { static constexpr int kNumAccum = M * N / 128; }; -template +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template struct BF16MMASelector { static constexpr auto select_mma() { using namespace cute::SM90::GMMA; - if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); - if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); - if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); - if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); - if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); - if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); - if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); - if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); - if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); - if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); - if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); - if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); - if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); - if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); - if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); - if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); - if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); - if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); - if constexpr (N == 152) return 
MMA_64x152x16_F32BF16BF16_SS(); - if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); - if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); - if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); - if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); - if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); - if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); - if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); - if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); - if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); - if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); - if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); - if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); - if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if 
constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); } static constexpr auto select_type() { @@ -139,6 +152,51 @@ struct BF16MMASelector { using type = decltype(select_type()); }; +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? 
ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; template struct SM90_U32x2_STSM_N { @@ -146,7 +204,7 @@ struct SM90_U32x2_STSM_N { copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" - :: "l"(smem_dst), "r"(src[0]), "r"(src[1])); + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); } }; @@ -155,7 +213,7 @@ struct SM90_U32x2_LDSM_N { copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst_0), "=r"(dst_1) - : "l"(smem_src)); + : "l"(__cvta_generic_to_shared(smem_src))); } }; @@ -164,7 +222,7 @@ struct SM90_U32x4_LDSM_N { 
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) - : "l"(smem_src)); + : "l"(__cvta_generic_to_shared(smem_src))); } }; @@ -186,47 +244,12 @@ __forceinline__ __device__ void warpgroup_wait() { asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); } -// TODO: replace with CUTLASS solution -union GmmaDescriptor { - __host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {} - - __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {} - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept { - desc_ = t.desc_; - return *this; - } - - __host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept { - desc_ = t.desc_; - return *this; - } - - uint64_t desc_; - uint32_t reg32_[2]; - uint16_t reg16_[4]; - - struct { - uint16_t start_address_: 14, : 2; - uint16_t leading_byte_offset_: 14, : 2; - uint16_t stride_byte_offset_: 14, : 2; - uint8_t : 1, base_offset_: 3, : 4; - uint8_t : 6, layout_type_: 2; - } bitfield; - - // Decay to an `uint64_t` - __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } -}; - template -__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, - const int& leading_byte_offset = 0, - const int& stride_byte_offset = 1024) { - GmmaDescriptor desc; +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; 
const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); desc.bitfield.start_address_ = uint_ptr >> 4; desc.bitfield.layout_type_ = layout_type; @@ -236,48 +259,74 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout return desc; } -__device__ __forceinline__ void -tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast = 1) { - constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - if (num_tma_multicast == 1) { - cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); - } else if (cute::block_rank_in_cluster() == 0) { - cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1); - } +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); } -__device__ __forceinline__ void -tma_3d_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& crd_2) { - constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); - cute::SM90_TMA_LOAD_3D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1, crd_2); +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 
1 : get_inner_block_atom_size(); } -// Tensormap related -__device__ __forceinline__ void tensor_map_release_cta() { - asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; } -__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { - auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); - asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); } -__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { - auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); - const auto new_int64_addr = reinterpret_cast(new_addr); - asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); -} - -__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const 
uint64_t& new_stride) { - auto smem_int_desc = __cvta_generic_to_shared(smem_desc); - asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); -#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) - asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); -#else - DG_STATIC_ASSERT(false, "Invalid CUDA version"); -#endif +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride 
between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } } } // namespace `deep_gemm::sm90` diff --git a/deep_gemm/include/deep_gemm/common/tma_utils.cuh b/deep_gemm/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 00000000..bd54adc2 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals 
to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + 
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/deep_gemm/include/deep_gemm/common/types.hpp b/deep_gemm/include/deep_gemm/common/types.hpp index 23e73424..410c5469 100644 --- a/deep_gemm/include/deep_gemm/common/types.hpp +++ b/deep_gemm/include/deep_gemm/common/types.hpp @@ -2,13 +2,36 @@ namespace deep_gemm { +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ 
__device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + enum class GemmType { - Normal = 0, - MGroupedContiguous = 1, - MGroupedMasked = 2, - KGroupedContiguous = 3, + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, }; +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + enum class KernelType { Kernel1D1D = 0, Kernel1D2D = 1, diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index d590e614..ef098b31 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -100,52 +100,56 @@ __forceinline__ __device__ uint32_t get_lane_idx() { __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { uint32_t ret; - asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); return ret; } __device__ __forceinline__ float2 ld_shared(const float2* ptr) { float2 ret; - asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); return ret; } __device__ __forceinline__ float4 ld_shared(const float4* ptr) { float4 ret; - asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr)); + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); return ret; } __device__ 
__forceinline__ uint4 ld_shared(const uint4* ptr) { uint4 ret; - asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr)); + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); return ret; } __device__ __forceinline__ float ld_shared(const float* ptr) { float ret; - asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); return ret; } __device__ __forceinline__ void st_shared(const float* ptr, float val) { - asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); } __device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { - asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); } __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { - asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); } __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { - asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(ptr), "r"(x), "r"(y)); + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); } __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { - asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w)); + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +__device__ __forceinline__ void 
st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); } template diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 6b235354..0227b3e8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -14,10 +14,10 @@ using namespace deep_gemm::sm100; template = 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; @@ -40,10 +50,12 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; constexpr uint32_t kNumTMAStoreStages = 2; - DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); // Overwrite shape constants if the 
compiler gives shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; @@ -63,18 +75,25 @@ sm100_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); - DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + // Automatically deduce the number of epilogue stages (1 or 2), according to the 
tensor memory size // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; @@ -87,9 +106,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, if (warp_idx == 0 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_d); - if constexpr (kWithAccumulation) - cute::prefetch_tma_descriptor(&tensor_map_c); + cute::prefetch_tma_descriptor(&tensor_map_cd); } // D/A/B shared memory @@ -129,7 +146,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Arrive at all CTAs tmem_full_barriers[i]->init(1); // Arrive only at the leader CTA - tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); } if constexpr (kTensorCoreUtilControl < 100) tensor_core_full_barrier->init(1); @@ -162,25 +179,25 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); - #pragma unroll for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { // Wait consumer release empty_barriers[stage_idx]->wait(phase ^ 1); // Compute offsets // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` 
is actually the k index default for K-major, while `k_b_idx` may be MN-major // And for all m-grouped GEMMs, A must be K-majored - DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets @@ -190,14 +207,20 @@ sm100_bf16_gemm_impl(int* grouped_layout, } // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx); + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx); + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx); + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx); + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); // Arrive at full barriers constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; @@ -219,8 +242,10 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto instr_desc = cute::UMMA::make_instr_desc(); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -271,10 +296,12 @@ sm100_bf16_gemm_impl(int* grouped_layout, if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); #pragma unroll for (uint32_t w = 0; w < kNumMWaves; ++ w) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); mma_t::fma(a_desc, b_desc, accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, k_block_idx > 0 or k > 0, @@ -298,7 +325,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, tensor_core_phase ^= 1; // Sleep for certain cycles - constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; const auto& start_clock = clock64(); if (cute::elect_one_sync()) @@ -314,7 +341,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } - } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { // Epilogue warp groups const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); @@ -358,10 +385,10 @@ sm100_bf16_gemm_impl(int* 
grouped_layout, // Wait shared memory to be released if (epilogue_warp_idx == 0) cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; // Store into shared memory @@ -421,25 +448,32 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Synchronize all threads and issue TMA cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } cute::tma_store_arrive(); } } } } - // Deallocate tensor memory by warp 1 + // Deallocate tensor memory by the last UMMA store warp // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == 1) + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) Allocator().free(0, kNumTmemCols); } #else if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); #endif } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh 
b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 8a0130ba..86303347 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -115,8 +115,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, // Issue TMAs if (cute::elect_one_sync()) { - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); } // Arrive at full barriers @@ -258,8 +258,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, #else if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); #endif } -} \ No newline at end of file +} diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 322c0fd9..45a603ad 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -14,6 +14,7 @@ namespace deep_gemm { using namespace deep_gemm::sm100; template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, @@ -31,8 +33,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_b, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, - const __grid_constant__ cute::TmaDescriptor tensor_map_c, - const __grid_constant__ cute::TmaDescriptor tensor_map_d) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { #if 
(defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; @@ -43,18 +44,24 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Configs constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; constexpr uint32_t kNumTMAStoreStages = 2; - constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t); constexpr uint32_t kNumUTCCPAlignedElems = 128; DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); // Overwrite shape constants if the compiler gives shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad); + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); // Utils bool is_leader_cta = cute::block_rank_in_cluster() == 0; @@ -69,22 +76,29 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 
1 : kNumMulticast); constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; - constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); - DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = 
constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; @@ -103,9 +117,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); - cute::prefetch_tma_descriptor(&tensor_map_d); - if constexpr (kWithAccumulation) - cute::prefetch_tma_descriptor(&tensor_map_c); + cute::prefetch_tma_descriptor(&tensor_map_cd); } // D/A/B shared memory @@ -113,10 +125,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); }); auto smem_a = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); auto smem_b = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // SFA/SFB shared memory @@ -158,7 +170,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Arrive at all CTAs tmem_full_barriers[i]->init(1); // Arrive only at the leader CTA - tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); } // Make initialized barrier visible in async proxy @@ -195,18 +207,19 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Compute offsets // NOTES: the group is always 
concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( shape_n, BLOCK_N, n_block_idx, m_block_idx); // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major // And for all m-grouped GEMMs, A must be K-majored - DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( shape_k, BLOCK_K, k_block_idx, m_block_idx); // Add 2 CTA offsets @@ -216,25 +229,34 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx); + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx); + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx); + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx); - auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); // Issue SFA and SFB TMAs at certain stages // No swizzling, so one TMA for one SF is enough - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0) { - tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, - scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); - tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, - scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); - num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); } // Arrive at full barriers @@ -248,15 +270,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // TODO: refactor `UMMA_M` calculation constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? 
kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); auto sf_desc = make_sf_desc(nullptr); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = make_umma_desc(smem_a[0], 0, 0); - auto b_desc = make_umma_desc(smem_b[0], 0, 0); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -301,19 +322,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Do SF copy at certain stages // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { - using cute_utccp_t = cute::conditional_t; - - // SFA and SFB copy - // TODO: process shared memory descriptor by addition + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; replace_smem_desc_addr(sf_desc, smem_ptr); cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; @@ -325,16 +347,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Issue UMMA in the leader CTA using mma_t 
= cute::conditional_t; - const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); #pragma unroll for (uint32_t w = 0; w < kNumMWaves; ++ w) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); mma_t::fma(a_desc, b_desc, accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, k_block_idx > 0 or k > 0, @@ -378,11 +404,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, full_barriers[stage_idx]->wait(phase); // Transpose for UTCCP at certain stages - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0) { + if (k_block_idx % kNumSFAStagesPerLoad == 0) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + 
i * kNumUTCCPAlignedElems); @@ -394,7 +423,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, with_sf_full_barriers[stage_idx]->arrive(0u); } } - } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { // Epilogue warp groups const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); @@ -438,10 +467,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Wait shared memory to be released if (epilogue_warp_idx == 0) cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); // Store into shared memory @@ -500,25 +529,32 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Synchronize all threads and issue TMA cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } cute::tma_store_arrive(); } } } } - // Deallocate tensor 
memory by warp 1 + // Deallocate tensor memory by the last UMMA store warp // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == 1) + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) Allocator().free(0, kNumTmemCols); } #else if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); #endif } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh deleted file mode 100644 index 658f883c..00000000 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ /dev/null @@ -1,533 +0,0 @@ -#pragma once -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunknown-attributes" - -#include -#include - -#include -#include -#include -#include - -namespace deep_gemm { - -using namespace deep_gemm::sm100; - -template -__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) -sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, - uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - const __grid_constant__ cute::TmaDescriptor tensor_map_a, - const __grid_constant__ cute::TmaDescriptor tensor_map_b, - const __grid_constant__ cute::TmaDescriptor tensor_map_d, - const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { -#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) - using Barrier = cutlass::arch::ClusterTransactionBarrier; - using Allocator = cute::conditional_t; - - // Scaling checks - DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); - - // Configs - constexpr uint32_t LAYOUT_AD_M = 128; - constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; - constexpr uint32_t kNumTMAStoreStages = 2; 
- DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); - DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); - DG_STATIC_ASSERT(BLOCK_M == kNumEpilogueThreads, "Invalid block M"); - - // Overwrite shape constants if the compiler gives - shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; - shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; - shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; - const auto shape_k_scales = ceil_div(shape_k, BLOCK_K); - - // Utils - bool is_leader_cta = cute::block_rank_in_cluster() == 0; - const auto warp_idx = cutlass::canonical_warp_idx_sync(); - const auto lane_idx = get_lane_idx(); - - // Align to 1024 bytes for swizzle-128B - extern __shared__ __align__(1024) uint8_t smem_buffer[]; - - // 2-CTA MMA - constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); - constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); - DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); - DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); - - // Share memory sizes - // NOTES: do not use `LOAD_BLOCK_M` for SFA, as we need full SFA for promotion - constexpr bool kMustUseUniformedSFB = (BLOCK_K % BLOCK_N == 0); - constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; - constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; - constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); - constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); - DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, 
"Shared memory of A/B must be aligned to 1024 bytes"); - DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); - - // Must have 2 epilogue stages - constexpr uint32_t kNumEpilogueStages = 2; - - // Real tensor memory size and offsets - constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; - constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); - - // Prefetch TMA descriptors at the very beginning - if (warp_idx == 0 and cute::elect_one_sync()) { - cute::prefetch_tma_descriptor(&tensor_map_a); - cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_d); - cute::prefetch_tma_descriptor(&tensor_map_sfa); - } - - // Data on shared memory (layout as ordered below) - cd_dtype_t* smem_cd[kNumTMAStoreStages]; - cutlass::float_e4m3_t* smem_a[kNumStages]; - cutlass::float_e4m3_t* smem_b[kNumStages]; - float* smem_sfa[kNumStages]; - - // Fill D/A/B pointers - #pragma unroll - for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) - smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - } - - // Fill SFA/SFB - auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) - smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); - - // Fill barriers - auto barrier_start_ptr = reinterpret_cast(smem_buffer + - SMEM_CD_SIZE + - kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + - kNumStages * SMEM_SFA_SIZE_PER_STAGE); - auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); - auto empty_barriers = 
PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); - auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); - auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); - - // Fill the tensor memory pointer - auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2); - DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); - - // Initialize barriers - if (warp_idx == 1 and cute::elect_one_sync()) { - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - // Arrive at all CTAs - full_barriers[i]->init(kNumMulticast); - empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32); - } - #pragma unroll - for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { - // Arrive at all CTAs - tmem_full_barriers[i]->init(1); - // Arrive only at the leader CTA - tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); - } - - // Make initialized barrier visible in async proxy - cutlass::arch::fence_barrier_init(); - } else if (warp_idx == 2) { - // Allocate tensor memory - Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); - } - kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); - - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - const uint32_t num_iterations = ceil_div(shape_k, kNumStages * BLOCK_K); - auto launch_k_iterations = [=](const auto& func) { - if constexpr (kNumLastStages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(num_iterations - 1, NotDivisibleK{}); - } - }; - - // Block scheduler - uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); - - // Register configurations - constexpr uint32_t kNumNonEpilogueRegisters = 64; - constexpr uint32_t kNumEpilogueRegisters = 216; - DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 65535, "Too many registers"); - - // Dispatch warps into different roles - if (warp_idx == 0) { - // Adjust registers - cutlass::arch::warpgroup_reg_dealloc(); - - // TMA load warp - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? 
kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - - // Compute offsets - // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>( - shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::K)>( - shape_n, BLOCK_N, n_block_idx, m_block_idx); - - // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major - // And for all grouped GEMMs, A must be K-majored - DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_block_idx = k_iter * kNumStages + s; - uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_b_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::MN)>( - shape_k, BLOCK_K, k_block_idx, m_block_idx); - - // Add 2 CTA offsets - if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; - n_idx += kIsMulticastOnA ? 
0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); - } - - // Issue TMAs - if (cute::elect_one_sync()) { - if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_idx, m_idx); - if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_idx); - if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); - if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); - - // Issue SFA TMA - tma_copy( - &tensor_map_sfa, full_barriers[s], - smem_sfa[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_k_scales, 1, k_block_idx)); - } - - // Arrive at full barriers - constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE; - if (is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); - if (not is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(0u); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - if (is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(); - if (not is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(0u); - } - }); - } - } else if (warp_idx == 1 and is_leader_cta) { - // Adjust registers - cutlass::arch::warpgroup_reg_dealloc(); - - // MMA issue warp - // NOTES: only the leader CTA will do this - // Make instruction descriptor - // TODO: refactor `UMMA_M` calculation - constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? 
kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - auto instr_desc = cute::UMMA::make_instr_desc(); - auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - - // Checks for MMA instructions - // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits - DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or - (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or - (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), - "Invalid MMA instruction shape"); - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) { - // Wait TMA full - auto iter_idx = scheduler.current_iter * num_iterations + k_iter; - full_barriers[s]->wait(iter_idx & 1); - - // Wait tensor memory empty - auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; - auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; - tmem_empty_barriers[accum_stage_idx]->wait(accum_stage_phase ^ 1); - - // Issue UMMA in the leader CTA - if (s < kNumInnerStages) { - using cute_mma_t = cute::conditional_t; - tcgen05_after_thread_sync(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - auto b_desc = make_umma_desc(smem_b[s], 0, k * UMMA_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - auto a_desc = make_umma_desc(smem_a[s], w * LAYOUT_AD_M, k * UMMA_K); - cute_mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k > 0, - 
runtime_instr_desc); - } - } - tcgen05_before_thread_sync(); - } - - // Commit to the TMA empty and tensor memory full barrier - auto umma_arrive = [](const uint64_t* barrier) { - if constexpr (kNumMulticast == 1) { - cutlass::arch::umma_arrive(barrier); - } else { - constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; - cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); - } - }; - umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); - } - }); - } - } else if (warp_idx < kNumNonEpilogueThreads / 32) { - // Adjust registers - cutlass::arch::warpgroup_reg_dealloc(); - } else if (warp_idx >= kNumNonEpilogueThreads / 32) { - // Adjust registers - cutlass::arch::warpgroup_reg_alloc(); - - // Epilogue warp groups - const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; - const auto epilogue_thread_idx_in_warpgroup = epilogue_thread_idx % 128; - const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); - const auto epilogue_warpgroup_idx = epilogue_thread_idx / 128; - - // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, - // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. 
- // NOTES: we also forbid two CTAs to share the same SM and its tensor memory - DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); - - // TMA checks - constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); - DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); - DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); - - // Persistently schedule over blocks - while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - constexpr uint32_t kNumElemsPerLDTM = 16; - DG_STATIC_ASSERT(kNumElemsPerLDTM == 16 and BLOCK_N % kNumElemsPerLDTM == 0 and BLOCK_K % kNumElemsPerLDTM == 0, "Invalid LDTM width"); - - // SFB stuffs - uint32_t num_former_iters = BLOCK_N, num_full_iters = BLOCK_N; - if constexpr (not kMustUseUniformedSFB) { - num_former_iters = min(BLOCK_N, BLOCK_K - ((n_block_idx * BLOCK_N) % BLOCK_K)); - num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N); - } - num_former_iters /= kNumElemsPerLDTM, num_full_iters /= kNumElemsPerLDTM; - const auto sfb_offset = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); - const auto sfb_ptr = sfb + (sfb_offset + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; - - // Launch promotion - float accum[BLOCK_N] = {0}; - launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); - - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) { - // Load SFB - float sf_0 = 0, sf_1 = 0; - if (s < kNumInnerStages) { - const auto k_block_idx = k_iter * kNumStages + s; - sf_0 = __ldg(sfb_ptr + k_block_idx); - sf_1 = num_former_iters < num_full_iters ? 
__ldg(sfb_ptr + k_block_idx + shape_k_scales) : 0; - } - - // Wait UMMA arrival - auto iter_idx = scheduler.current_iter * num_iterations + k_iter; - auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages; - auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1; - tmem_full_barriers[accum_stage_idx]->wait(accum_stage_phase); - tcgen05_after_thread_sync(); - - // Commit to the TMA empty barrier for all CTAs after loading SFA - float sfa = s < kNumInnerStages ? ld_shared(smem_sfa[s] + epilogue_thread_idx) : 0; - sf_0 *= sfa, sf_1 *= sfa; - __syncwarp(); - if (lane_idx < kNumMulticast) - empty_barriers[s]->arrive(lane_idx); - __syncwarp(); - - // Do promotion like the SM90 kernel - if (s < kNumInnerStages) { - uint32_t values[kNumElemsPerLDTM]; - #pragma unroll - for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerLDTM; ++ i) { - // Load from tensor memory - cute::SM100_TMEM_LOAD_32dp32b16x::copy( - accum_stage_idx * kNumMWaves * BLOCK_N + epilogue_warpgroup_idx * BLOCK_N + i * kNumElemsPerLDTM, - values[ 0], values[ 1], values[ 2], values[ 3], - values[ 4], values[ 5], values[ 6], values[ 7], - values[ 8], values[ 9], values[10], values[11], - values[12], values[13], values[14], values[15]); - cutlass::arch::fence_view_async_tmem_load(); - - // Promote - const auto sf = (kMustUseUniformedSFB or i < num_former_iters) ? 
sf_0 : sf_1; - #pragma unroll - for (uint32_t j = 0; j < kNumElemsPerLDTM; ++ j) - accum[i * kNumElemsPerLDTM + j] += *reinterpret_cast(&values[j]) * sf; - } - } - - // Commit to the tensor memory empty barrier (only at the leader CTA) - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - }); - - // Flush TMA stores - // NOTES: for the first store, we have to flush all previous TMA, - // as we don't share pipeline stages between two blocks - if (epilogue_thread_idx_in_warpgroup == 0) - cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); - - // Write shared memory - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Epilogue store and addition - // Issue every swizzled atom and pipeline: store shared, add C, and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; - #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s) { - // Wait shared memory to be released - if (s >= kNumTMAStoreStages) { - if (epilogue_thread_idx_in_warpgroup == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); - } - - // The pipeline stage - const auto tma_stage_idx = s % kNumTMAStoreStages; - const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx); - const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); - const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N; - - // Store into shared memory - #pragma unroll - for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - 
new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - // NOTES: if you want to do accumulation, please notice that you need two accumulation barriers - const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - st_shared(smem_ptr, - *reinterpret_cast(&accum[offset + 0]), - *reinterpret_cast(&accum[offset + 1]), - *reinterpret_cast(&accum[offset + 2]), - *reinterpret_cast(&accum[offset + 3])); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - st_shared(smem_ptr, - cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]), - cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]), - cast_into_bf16_and_pack(accum[offset + 4], accum[offset + 5]), - cast_into_bf16_and_pack(accum[offset + 6], accum[offset + 7])); - } - } - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); - if (epilogue_thread_idx_in_warpgroup == 0) { - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_d, local_smem_cd, - n_idx, m_idx + epilogue_warpgroup_idx * STORE_BLOCK_M); - cute::tma_store_arrive(); - } - } - } - - // Deallocate 
tensor memory by warp 1 - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == 1) - Allocator().free(0, kNumTmemCols); - } - - // To safely deconstruct all barriers, we need a cluster sync - // TODO: optimize it by another round of barrier waits - if constexpr (kNumMulticast > 1) - cute::cluster_sync(); -#else - if (blockIdx.x == 0 and threadIdx.x == 0) - DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); -#endif -} - -}; // namespace deep_gemm - -#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 4a53421f..7058c40f 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -19,10 +19,12 @@ using namespace deep_gemm::sm100; template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1) + uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads, + uint32_t kNumMathWarpGroups = kNumMathThreads / 128> +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, const uint64_t logits_stride, const uint64_t block_table_stride, const uint32_t* context_lens, float* logits, @@ -39,9 +41,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, const auto& lane_idx = get_lane_idx(); // Prefetch TMA descriptors - static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); @@ -53,78 +53,58 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Shared memory 
configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); - static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); - - static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); - static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); - - static constexpr uint32_t SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast(sizeof(uint32_t)); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); - // Q data and barriers on shared memory + // Q and KV data on shared memory auto smem_q = PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + 
SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); - }); - auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); - - // Separate math warpgroups and tma load warps into KV groups - // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); - - // Per group KV data and barriers on shared memory - const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; auto smem_kv = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i); + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; 
}); - // UMMA barriers and TMEM pointer on shared memroy - auto umma_barrier_ptr = reinterpret_cast(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups); + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); - auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); - const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 ~ 16 - const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 ~ 20 - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20 + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers if (is_tma_load_warp and cute::elect_one_sync()) { - if (kv_group_idx == 0) { - #pragma unroll - for (uint32_t i = 0; i < kNumQStages; ++ i) 
{ - full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); - } + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); } - if (kv_group_idx < kNumMathWarpGroups) { - #pragma unroll - for (uint32_t i = 0; i < kNumKVStages; ++ i) { - full_kv_barriers[i]->init(1); - empty_kv_barriers[i]->init(128); - } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); } cutlass::arch::fence_barrier_init(); } @@ -143,12 +123,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 32; - constexpr uint32_t kNumMathRegisters = 104; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); - DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { @@ -160,21 +141,20 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings - // Construct instruction with layout F - constexpr uint32_t UMMA_M = 64; + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); if (is_tma_load_warp) { // TMA warp-group for loading data 
cutlass::arch::warpgroup_reg_dealloc(); - if (kv_group_idx >= kNumMathWarpGroups) - return; const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { - if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx); + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -198,6 +178,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; num_kv = next_num_kv; + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); @@ -205,25 +193,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, issue_tma_q(q_stage_idx, q_idx + 1); } - // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { - kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ? 
- __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); - } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); - // Issue TMA KV if (cute::elect_one_sync()) { - tma_3d_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), - smem_kv[kv_stage_idx], 0, 0, kv_block_idx); - tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -244,32 +233,29 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t q_idx = batch_size, kv_idx; uint32_t next_q_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; - uint32_t umma_phase = 0; - - auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx); - }); + uint32_t umma_phase = 1; while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { if (q_idx != next_q_idx) { CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + 
full_q_barriers[q_stage_idx]->wait(q_phase); } q_idx = next_q_idx; kv_idx = next_kv_idx; CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); - DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); - #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - empty_umma_barriers[i]->wait(umma_phase & 1); + empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { auto a_desc = make_umma_desc( - smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K); + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); auto b_desc = make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); @@ -284,10 +270,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Offsets const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - float weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? 
kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -309,9 +297,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { - #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } } @@ -320,75 +307,82 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset), - ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset)); + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); - empty_umma_barriers[warpgroup_idx]->arrive(); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1); + full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - static 
constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { - // Load from the tensor memory - constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM"); - auto tmem_load = [&](auto... 
Is) { - if constexpr (kNumLDTMElems == 16) { - cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - // Transform - const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) { - auto a = make_float2(fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k]), 0), - fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k + 2]), 0)); - auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]); - return __ffma2_rn(a, b, sum); - }; + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - // Intra-thread reduction auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 8; ++ j) { - sum_0 = transform_2(j, 0, sum_0); - sum_1 = transform_2(j, 1, sum_1); + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); } - auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv); - // Inter-thread reduction + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + #pragma unroll - 
for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = 1u << j; - v.x += __shfl_xor_sync(0xffffffffu, v.x, offset); - v.y += __shfl_xor_sync(0xffffffffu, v.y, offset); + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + v_0_offset] = v.x; - logits[kv_offset + i * logits_stride + v_1_offset] = v.y; + logits[kv_offset + i * logits_stride + thread_idx] = result; } } } else { diff --git a/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..4e4ff21d --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); 
+ cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, 
tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages 
<= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + 
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block 
M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; 
u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 9186e683..7a77e4e8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -18,26 +19,41 @@ namespace deep_gemm { using namespace deep_gemm::sm90; -template __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, const __grid_constant__ cute::TmaDescriptor tensor_map_a, const __grid_constant__ cute::TmaDescriptor tensor_map_b, - const __grid_constant__ cute::TmaDescriptor tensor_map_d) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { #if 
(defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + // Types - using WGMMA = typename BF16MMASelector::type; + using WGMMA = typename BF16MMASelector::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); // Overwrite shape constants if the compiler gives shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; @@ -45,13 +61,15 @@ sm90_bf16_gemm_impl(int* grouped_layout, shape_k = SHAPE_K != 0 ? 
SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_idx(); @@ -59,37 +77,28 @@ sm90_bf16_gemm_impl(int* grouped_layout, if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); - cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(&tensor_map_cd); } __syncwarp(); // Align to 1024 bytes for swizzle-128B extern __shared__ __align__(1024) uint8_t smem_buffer[]; - DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); - // Data on shared memory + // D/A/B shared memory auto smem_d = reinterpret_cast(smem_buffer); - __nv_bfloat16* smem_a[kNumStages]; - __nv_bfloat16* smem_b[kNumStages]; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages]; - Barrier* 
empty_barriers[kNumStages]; - - // Fill shared memory pointers - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - } + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + i; - } + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); // Initialize barriers if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { @@ -106,79 +115,83 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Synchronize all threads to make barrier visible in normal memory model (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); - struct DivisibleK {}; - struct NotDivisibleK {}; - auto launch_k_iterations = [=](const auto& func) { - if constexpr (kNumLastStages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}); - func(num_iterations - 1, NotDivisibleK{}); - } - }; - // Register reconfigurations constexpr uint32_t kNumTMARegisters = 48; - constexpr uint32_t kNumMathRegisters = 224; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; // Block scheduler uint32_t m_block_idx, n_block_idx; auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); // NOTES: only one thread (or warp) will be used - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; - - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. 
- const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - - constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; - auto& full_barrier = *full_barriers[s]; - uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - } - - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }); + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? 
kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } } // To safely deconstruct distributed shared barriers, we need another round of empty waits if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); } } } else { @@ -187,12 +200,24 @@ sm90_bf16_gemm_impl(int* grouped_layout, // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? 
BLOCK_M : WGMMA::M * 2; DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + // Empty barrier arrival auto empty_barrier_arrive = [&](uint32_t s) { if constexpr (kNumTMAMulticast == 1) { @@ -203,53 +228,42 @@ sm90_bf16_gemm_impl(int* grouped_layout, } }; - cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + // TODO: remove some useless computation for unaligned Ms + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? 
kNumStages : kNumLastStages; + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); - // TODO: remove some useless computation for unaligned Ms + // Commit WGMMA instructions #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; - - // Commit WGMMA instructions - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, shifted_accum, 1); - } - warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); } - - // Notify barrier arrival - empty_barrier_arrive(s); } - - // Wait 
unaligned cases + warpgroup_commit_batch(); #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - empty_barrier_arrive(s); - } - }); + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } // TMA checks constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); @@ -260,12 +274,16 @@ sm90_bf16_gemm_impl(int* grouped_layout, "Unaligned TMA store or too many TMA store instructions"); DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); @@ -313,8 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, ); } } - } - else { + } else { // Use `st.shared` if STSM is not available #pragma unroll for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { @@ -323,24 +340,31 @@ sm90_bf16_gemm_impl(int* grouped_layout, auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { st_shared(smem_d_0 + 
i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); } } } cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); // Use TMA store to write back to global memory - constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; - DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - n_block_idx * BLOCK_N + in_block_n_offset, - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } cute::tma_store_arrive(); } __syncwarp(); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh index d40308eb..191a4fe2 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -106,10 +106,11 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, const uint32_t& k_idx = sk_idx % SHAPE_K; const uint32_t& s_idx = sk_idx / SHAPE_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[stage_idx], k_idx, m_block_idx 
* BLOCK_M + s_idx * SHAPE_M, 1); - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); } } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index 4c57cbe0..cdd28fcb 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -21,6 +21,7 @@ using namespace deep_gemm::sm90; template = 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); @@ -73,7 +74,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cute::prefetch_tma_descriptor(&tensor_map_b_base); cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); - cute::prefetch_tma_descriptor(&tensor_map_d); + cute::prefetch_tma_descriptor(&tensor_map_cd); } __syncwarp(); @@ -189,8 +190,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Prepare next tensor map sum_k += scheduler.current_shape_k; if (scheduler.next_group_idx < kNumGroups) { - tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + sum_k * shape_m); - tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + sum_k * shape_n); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); + 
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); @@ -217,10 +218,10 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, auto& full_barrier = *full_barriers[stage_idx]; const uint32_t& k_idx = k_block_idx * BLOCK_K; const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; - tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); - tma_copy(&tensor_map_sfb, reinterpret_cast(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); - tma_copy(current_tensor_map_a, reinterpret_cast(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); - tma_copy(current_tensor_map_b, reinterpret_cast(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); } } @@ -330,7 +331,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Use TMA store to write back to global memory if (warp_idx % 4 == 0 and cute::elect_one_sync()) { cute::SM90_TMA_REDUCE_ADD_2D::copy( 
- &tensor_map_d, smem_d_0, n_block_idx * BLOCK_N, + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); cute::tma_store_arrive(); } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index ea4b5057..51c4d26c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -30,10 +30,11 @@ __device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_ dispatch_num_former_iters(num_former_iters, func); } -template ::type; using Barrier = cutlass::arch::ClusterTransactionBarrier; - DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); // Overwrite shape constants if the compiler gives shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; @@ -63,13 +64,19 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, // Shared memory static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)); + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + // Configs const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); @@ -98,9 +105,9 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, }); constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); auto smem_sfa = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE); + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); }); - auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); // Fill barriers auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); @@ -127,7 +134,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, // Register reconfigurations constexpr uint32_t kNumTMARegisters = 40; - constexpr uint32_t kNumMathRegisters = 232; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 
248 : 232; // Block scheduler uint32_t m_block_idx, n_block_idx; @@ -148,7 +155,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, cutlass::arch::warpgroup_reg_dealloc(); // NOTES: only one thread (or warp) will be used - if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Assign TMA multicast number into A and B @@ -163,20 +171,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, empty_barriers[stage_idx]->wait(phase ^ 1); // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; const uint32_t k_idx = k_block_idx * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + tma_copy(&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); - tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a, batch_idx); + tma_copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), num_tma_multicast_a); // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + tma_copy(&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); + num_tma_multicast_b, batch_idx); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + 
SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); } } @@ -214,18 +225,26 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); - auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + #pragma unroll for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) - st_shared(smem_sfb + i, __ldg(local_sfb + i)); + st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); } cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Accumulation for WGMMA or CUDA promotion - constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? 
BLOCK_M : WGMMA::M * 2; DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; // Empty barrier arrival auto empty_barrier_arrive = [&]() { @@ -267,8 +286,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, // Read A scales // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1 + m_offset); + auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? 
ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; // Commit WGMMA instructions #pragma unroll @@ -291,6 +310,10 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) empty_barrier_arrive(); + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + // Promote with scales // NOTES: making it as predicates is very important for performance, comparing to two loops float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; @@ -302,7 +325,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, #pragma unroll for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; shifted_accum[i * 4 + 2] += (predicate ? 
scale_1_0 : scale_1_1) * accum[i * 4 + 2]; @@ -328,10 +351,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, "Unaligned TMA store or too many TMA store instructions"); DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); @@ -380,18 +407,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int *signal, } } cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); // Use TMA store to write back to global memory // TODO: compatible with FP32 output constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; - DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset), - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + 
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } cute::tma_store_arrive(); } __syncwarp(); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh index 52f4be6d..d58c7162 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -27,11 +27,13 @@ static constexpr int to_swizzle_cute_type() { } template __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) -void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv, +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, uint32_t* cu_seq_len_k_start, uint32_t* cu_seq_len_k_end, float* logits, @@ -125,6 +127,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, cons const auto& get_next_block_q_idx = [&]() -> cute::tuple { return {block_q_idx + gridDim.x, q_iter_idx + 1}; }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); @@ -132,8 +135,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, cons #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); - start = min(start, min(__ldg(cu_seq_len_k_start + q_idx), seq_len_kv)); - end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv)); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); } start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage @@ -160,8 +165,8 @@ void 
sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, cons // Prefetch const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { - tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); - tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); }; if (cute::elect_one_sync() and block_q_idx < num_q_blocks) @@ -187,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, cons empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); // Issue TMA KV - tma_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); - tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -299,8 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, cons // Store into the global memory // NOTES: we have redundant writes here, consider more carefully const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; - logits[q_idx * stride_kv + kv_offset + v_0_offset] = v_0; - logits[q_idx * stride_kv + kv_offset + v_1_offset] = v_1; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < 
seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + } else { + logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + } } } num_total_kv_blocks += num_kv_blocks; diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index edba79d1..482a85a8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -14,14 +14,17 @@ namespace deep_gemm { template __global__ __launch_bounds__(32, 1) -void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t* context_lens, uint32_t* schedule_metadata) { +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, uint32_t* schedule_metadata) { DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); const uint32_t lane_idx = get_lane_idx(); uint32_t num_segs[kAlignedBatchSize / 32]; #pragma unroll for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { - const uint32_t& context_len = (k * 32 + lane_idx < batch_size ? __ldg(context_lens + k * 32 + lane_idx) : 0); + const uint32_t q_idx = k * 32 + lane_idx; + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + const uint32_t& context_len = (q_idx < batch_size ? 
__ldg(context_lens + lens_idx) : 0); num_segs[k] = ceil_div(context_len, SPLIT_KV); } @@ -54,7 +57,8 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t* c } } -template +template struct PagedMQALogitsScheduler { uint32_t batch_size; const uint32_t* context_lens; @@ -63,6 +67,11 @@ struct PagedMQALogitsScheduler { uint32_t end_q_idx, end_kv_idx; uint32_t current_num_kv; + __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { + const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; + } + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) { this->batch_size = batch_size; @@ -70,10 +79,10 @@ struct PagedMQALogitsScheduler { const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups; + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; - current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0; + current_num_kv = get_num_kv(current_q_idx); } __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { @@ -84,11 +93,11 @@ struct PagedMQALogitsScheduler { if (q_idx == end_q_idx and kv_idx == end_kv_idx) return false; - current_kv_idx += kNumMathWarpGroups; + current_kv_idx += kNumBlocksPerSplit; if (current_kv_idx >= current_num_kv) { ++ current_q_idx; current_kv_idx = 0; - current_num_kv = current_q_idx < batch_size ? 
ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0; + current_num_kv = get_num_kv(current_q_idx); } return true; @@ -103,6 +112,7 @@ using namespace deep_gemm::sm90; template @@ -209,7 +219,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, constexpr uint32_t kNumMathRegisters = 104; // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); // Q and KV pipeline @@ -229,8 +239,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { if (kv_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); - tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx); + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); } }; @@ -265,7 +275,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // TODO: deal with `-1`? if (kv_idx == 0 or kv_block_idx_ptr == 32) { kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ? + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? 
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); } const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); @@ -276,10 +286,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, // Issue TMA KV if (cute::elect_one_sync()) { - tma_3d_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), - smem_kv[kv_stage_idx], 0, 0, kv_block_idx); - tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..e3bf9847 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory 
(layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for 
(uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? 
kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + 
warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. 
+ uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh index b5d24e7d..cc9e5e6b 100644 --- a/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -9,7 +9,7 @@ namespace deep_gemm { template __global__ __launch_bounds__(kNumWarps * 32, 1) -void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv, +void smxx_clean_logits(const uint32_t seq_len, const uint32_t 
seq_len_kv, const uint64_t stride_logits, const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { const uint32_t& num_sms = gridDim.x; const uint32_t& sm_idx = blockIdx.x; @@ -40,14 +40,14 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { - const auto& right = min(left + BLOCK_KV, static_cast(stride_kv)); + const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); if (right <= ks or ke <= left) { - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (right - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); } else { if (left < aligned_ks) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (aligned_ks - left) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); if (aligned_ke < right) - cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + aligned_ke, (right - aligned_ke) * sizeof(float)); + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); } } } @@ -58,9 +58,9 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; for (uint32_t j = aligned_ks; j < ks; ++ j) - logits[i * stride_kv + j] = neg_inf; + logits[i * stride_logits + j] = neg_inf; for (uint32_t j = ke; j < aligned_ke; ++ j) - logits[i * stride_kv + j] = neg_inf; + logits[i * stride_logits + j] = neg_inf; } } diff --git a/deep_gemm/legacy/__init__.py b/deep_gemm/legacy/__init__.py new file mode 100644 index 00000000..cce39ec7 
--- /dev/null +++ b/deep_gemm/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or rewrite in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/deep_gemm/legacy/a_fused_k_grouped_gemm.py b/deep_gemm/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7b42f152 --- /dev/null +++ b/deep_gemm/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = 
tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, 
META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep_gemm/legacy/a_fused_m_grouped_gemm.py b/deep_gemm/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 00000000..3f1f5294 --- /dev/null +++ b/deep_gemm/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) 
+ + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + 
mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/deep_gemm/legacy/b_fused_k_grouped_gemm.py b/deep_gemm/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 00000000..a642204b --- /dev/null +++ b/deep_gemm/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/deep_gemm/legacy/m_grouped_gemm.py 
b/deep_gemm/legacy/m_grouped_gemm.py new file mode 100644 index 00000000..e685a9ab --- /dev/null +++ b/deep_gemm/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, 
mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/deep_gemm/legacy/tune_options.py b/deep_gemm/legacy/tune_options.py new file mode 100644 index 00000000..ed6a7f77 --- /dev/null +++ b/deep_gemm/legacy/tune_options.py @@ -0,0 +1,28 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] 
+ config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/deep_gemm/testing/__init__.py 
b/deep_gemm/testing/__init__.py index 2537dbf1..13a9d78d 100644 --- a/deep_gemm/testing/__init__.py +++ b/deep_gemm/testing/__init__.py @@ -1,3 +1,4 @@ -from . import bench, numeric +from . import bench, numeric, utils from .bench import * from .numeric import * +from .utils import * diff --git a/deep_gemm/testing/bench.py b/deep_gemm/testing/bench.py index 8bba422c..2c752da2 100644 --- a/deep_gemm/testing/bench.py +++ b/deep_gemm/testing/bench.py @@ -79,37 +79,34 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, trace_path: str = None, flush_l2: bool = True, with_multiple_kernels: bool = False): - # Conflict with Nsight Systems - using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0)) + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) - # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle flush_l2_size = int(8e9 // 4) # For some auto-tuning kernels with prints fn() # Profile - suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None - profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], 
schedule=schedule) with profiler: for i in range(2): for _ in range(num_tests): if flush_l2: torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() fn() - - if not using_nsys: - profiler.step() - - # Return 1 if using Nsight Systems - if using_nsys: - return 1 + profiler.step() # Parse the profiling table - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names if not with_multiple_kernels: diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py index 37a88d43..58b9b92f 100644 --- a/deep_gemm/testing/numeric.py +++ b/deep_gemm/testing/numeric.py @@ -5,6 +5,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double(), y.double() denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 sim = 2 * (x * y).sum() / denominator return 1 - sim diff --git a/deep_gemm/testing/utils.py b/deep_gemm/testing/utils.py new file mode 100644 index 00000000..2d202d41 --- /dev/null +++ b/deep_gemm/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = 
saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index 1a47e155..c65026e5 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -15,35 +15,35 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - padded_n = align(n, 128) + padded_n = align(n, gran_k) x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) x_padded[:, :n] = x - x_view = x_padded.view(m, -1, 128) + x_view = x_padded.view(m, -1, gran_k) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf -def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(0) % 128 == 0 +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 m, n = x.shape - x_view = x.view(-1, 128, n) + x_view = x.view(-1, gran_k, n) x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf -def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, 
device=x.device) + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf @@ -58,3 +58,50 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = 
n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/develop.sh b/develop.sh index 3a71e249..e784347a 100755 --- a/develop.sh +++ b/develop.sh @@ -15,7 +15,7 @@ python setup.py build # Find the .so file in build directory and create symlink in current directory so_file=$(find build -name "*.so" -type f | head -n 1) if [ -n "$so_file" ]; then - ln -sf "$so_file" . + ln -sf "../$so_file" deep_gemm/ else echo "Error: No SO file found in build directory" >&2 exit 1 diff --git a/install.sh b/install.sh index 6e5e6f22..5c7021c6 100755 --- a/install.sh +++ b/install.sh @@ -7,7 +7,7 @@ cd "$script_dir" rm -rf build dist rm -rf *.egg-info python setup.py bdist_wheel -pip install dist/*.whl +pip install dist/*.whl --force-reinstall # Open users' original directory cd "$original_dir" diff --git a/scripts/generate_pyi.py b/scripts/generate_pyi.py new file mode 100644 index 00000000..df7490d4 --- /dev/null +++ b/scripts/generate_pyi.py @@ -0,0 +1,890 @@ +import re +from pathlib import Path + + +def build_cpp_function_index(root_path): + func_index = {} + extensions = {'.cpp', '.cc', '.cxx', '.c', '.hpp', '.h'} + + pattern = re.compile( + r'([\w:\s*<&>,\[\]\(\)]+?)' + r'\s+' + r'([a-zA-Z_][a-zA-Z0-9_:]*)' + r'\s*\(', + ) + + for file_path in Path(root_path).rglob('*'): + if file_path.suffix.lower() not in extensions: + continue + if not file_path.is_file(): + continue + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + except Exception as e: + print(f'Failed to read file {file_path}: {e}') + continue + + # Remove the compile directives and comments + lines = content.split('\n') + clean_lines = 
[line for line in lines if not line.strip().startswith(('#', '//'))] + content = '\n'.join(clean_lines) + + for match in pattern.finditer(content): + return_type_part = match.group(1).strip() + full_func_name = match.group(2).strip() + + if not return_type_part or not re.match(r'^[a-zA-Z_]', return_type_part): + continue + + first_token = return_type_part.split()[0] + if first_token in {'return', 'if', 'else', 'for', 'while', 'switch', 'case', 'throw', 'catch', 'auto'}: + continue + + # Extract base name + if '::' in full_func_name: + base_name = full_func_name.split('::')[-1] + else: + base_name = full_func_name + + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', base_name): + continue + + # Find matching ')' + paren_start = match.end() - 1 + paren_count = 0 + pos = paren_start + while pos < len(content): + ch = content[pos] + if ch == '(': + paren_count += 1 + elif ch == ')': + paren_count -= 1 + if paren_count == 0: + break + elif paren_count < 0: + pos = -1 + break + pos += 1 + else: + continue + + if pos == -1: + continue + + # Check context before match: should be at statement boundary + match_start = match.start() + context_before = content[max(0, match_start - 50):match_start] + if context_before and re.search(r'[a-zA-Z0-9_]$', context_before.rstrip()): + continue + + # Check for definition or header declaration + is_header = file_path.suffix.lower() in {'.h', '.hpp', '.cuh'} + after_paren = content[pos+1:pos+500] + has_brace = '{' in after_paren + has_semicolon = ';' in after_paren.split('{')[0] + + if has_brace or (is_header and has_semicolon): + sig_start = match.start(1) + full_signature = content[sig_start:pos+1].strip() + if base_name not in func_index: + func_index[base_name] = full_signature + + return func_index + + +class BracketTracker: + """ + Tracks nesting levels of various brackets in C++ code: + - () → paren + - [] → bracket + - {} → brace + - <> → angle (treated as template brackets only at top level) + Provides is_top_level() to check if 
currently outside all brackets. + """ + def __init__(self): + self.paren = 0 # () + self.bracket = 0 # [] + self.brace = 0 # {} + self.angle = 0 # <> + + def update(self, char: str): + """ + Update internal counters based on the given character. + """ + if char == '(': + self.paren += 1 + elif char == ')': + self.paren -= 1 + elif char == '[': + self.bracket += 1 + elif char == ']': + self.bracket -= 1 + elif char == '{': + self.brace += 1 + elif char == '}': + self.brace -= 1 + # Angle brackets < > are only treated as template delimiters + # when not inside (), [], or {} + elif char == '<' and self._in_top_level_of_other_brackets(): + self.angle += 1 + elif char == '>' and self.angle > 0 and self._in_top_level_of_other_brackets(): + self.angle -= 1 + + def _in_top_level_of_other_brackets(self): + """ + Check if not inside parentheses, square brackets, or braces (for correct template bracket recognition). + """ + return self.paren == 0 and self.bracket == 0 and self.brace == 0 + + def is_top_level(self): + """ + Check if completely at top level (all bracket counters are zero). + """ + return (self.paren == 0 and + self.bracket == 0 and + self.brace == 0 and + self.angle == 0) + + +def extract_m_def_statements(root_path): + """ + Scan all c files under root_path and extract all m.def(...) statements. + """ + results = [] + extensions = {'.hpp', '.cpp', '.h', '.cc'} + + # Regex: match m.def( ... 
def extract_m_def_statements(root_path):
    """
    Scan C/C++ sources under ``root_path`` and extract all ``m.def(...)``
    statements (pybind11 bindings), including multi-line ones.

    Returns a list of ``{'file': path, 'm_def_statements': [stmt, ...]}``
    dicts, one per file containing at least one binding.
    """
    results = []
    extensions = {'.hpp', '.cpp', '.h', '.cc'}
    # Compiled once and reused both as the per-line filter and as the
    # "is this '(' the one that opens m.def" check below (the old code
    # created this regex but never used it, re-compiling per line and
    # filtering with a literal that missed spellings like ``m.def (``).
    pattern = re.compile(r'm\.def\s*\(')

    for file_path in Path(root_path).rglob('*'):
        if file_path.suffix.lower() not in extensions:
            continue
        if not file_path.is_file():
            continue

        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
        except Exception as e:
            print(f'Failed to read file {file_path}: {e}')
            continue

        m_def_list = []
        lines = content.splitlines(keepends=True)
        i = 0
        while i < len(lines):
            line = lines[i]
            if pattern.search(line):
                stripped = line.lstrip()
                # Skip bindings that are commented out.
                if stripped.startswith('//') or stripped.startswith('/*'):
                    i += 1
                    continue

                # Balance parentheses, starting at the '(' that opens
                # m.def, across as many lines as needed.
                paren_count = 0
                j = i
                found_start = False
                while j < len(lines):
                    current_line = lines[j]
                    for k, char in enumerate(current_line):
                        if char == '(':
                            if not found_start and pattern.search(current_line[:k + 1]):
                                found_start = True
                            if found_start:
                                paren_count += 1
                        elif char == ')':
                            if found_start:
                                paren_count -= 1
                                if paren_count == 0:
                                    # Complete statement found.
                                    full_stmt = ''.join(lines[i:j + 1]).rstrip()
                                    m_def_list.append(full_stmt)
                                    i = j
                                    break
                    if paren_count <= 0 and found_start:
                        break
                    j += 1
            i += 1

        if m_def_list:
            results.append({
                'file': str(file_path),
                'm_def_statements': m_def_list,
            })

    return results


def parse_m_def_statement(m_def_str):
    """
    Parse a single ``m.def(...)`` statement.

    Returns a dict with:
      - ``python_function_name``: the exposed Python name (first argument)
      - ``cpp_function_name``: bound C++ function name, or None for lambdas
      - ``is_lambda``: True when the second argument is a lambda
      - ``num_args``: number of trailing ``py::arg``-style arguments
      - ``default_args``: {arg index: C++ default-value expression string}

    Raises ValueError when the statement cannot be parsed.
    """
    result = {
        'python_function_name': None,
        'num_args': 0,
        'default_args': {},
        'is_lambda': False,
    }

    start = m_def_str.find('m.def(')
    if start == -1:
        raise ValueError(f'[{m_def_str}] Could not find m.def start position')

    # Locate the ')' matching the '(' of "m.def(".
    paren_count = 0
    content_start = start + len('m.def(')
    content_end = -1
    for i in range(content_start, len(m_def_str)):
        ch = m_def_str[i]
        if ch == '(':
            paren_count += 1
        elif ch == ')':
            if paren_count == 0:
                content_end = i
                break
            else:
                paren_count -= 1
    if content_end == -1:
        raise ValueError(f'[{m_def_str}] m.def parentheses not closed')

    args_content = m_def_str[content_start:content_end]

    # Split the argument list on top-level commas only.
    args_list = []
    current = []
    tracker = BracketTracker()
    for ch in args_content:
        if ch in '()[]{}<>':
            tracker.update(ch)
        if ch == ',' and tracker.is_top_level():
            args_list.append(''.join(current).strip())
            current = []
        else:
            current.append(ch)
    if current:
        args_list.append(''.join(current).strip())

    if len(args_list) < 2:
        raise ValueError(f'[{m_def_str}] m.def has insufficient arguments')

    # First argument: the Python-visible name (must be a string literal).
    first = args_list[0].strip()
    str_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"', first)
    if str_match:
        result['python_function_name'] = str_match.group(1)
    else:
        raise ValueError(f'[{m_def_str}] m.def first argument should be a string literal')

    # Second argument: the bound C++ callable (function pointer or lambda).
    cpp_func_part = args_list[1].strip()
    if cpp_func_part.startswith('&'):
        cpp_func_part = cpp_func_part[1:].strip()

    if cpp_func_part.startswith('['):
        result['is_lambda'] = True
        result['cpp_function_name'] = None
    else:
        # Drop any namespace qualification, then take the leading identifier.
        cpp_func_name = cpp_func_part.split('::')[-1] if '::' in cpp_func_part else cpp_func_part
        match = re.match(r'^([a-zA-Z_][a-zA-Z0-9_]*)', cpp_func_name)
        result['cpp_function_name'] = match.group(1) if match else cpp_func_name

    # Remaining arguments: py::arg(...) entries, possibly with defaults.
    py_args = args_list[2:]
    result['num_args'] = len(py_args)

    for idx, arg_expr in enumerate(py_args):
        expr = arg_expr.strip()
        # Find a top-level '=' (the default-value assignment).  Reuses
        # BracketTracker for consistency with the other splitters.
        eq_pos = -1
        depth = BracketTracker()
        for i, ch in enumerate(expr):
            if ch in '()[]{}<>':
                depth.update(ch)
            elif ch == '=' and depth.is_top_level():
                eq_pos = i
                break

        if eq_pos != -1:
            default_val = expr[eq_pos + 1:].strip()
            if not default_val:
                raise ValueError(f'[{expr}] Default value is empty (arg {idx})')
            result['default_args'][idx] = default_val

    return result
def extract_cpp_signature_from_content(cpp_func_name, content):
    """
    Search ``content`` (C++ source text) for ``cpp_func_name`` and return
    its signature — from the start of the return type through the closing
    ')' of the parameter list — or None when not found.
    """
    if not cpp_func_name:
        return None

    # Match "<return type> <name>(" at the start of a line.  The return
    # type may contain namespaces, templates, pointers/references, etc.
    pattern = re.compile(
        r'^\s*'                       # leading whitespace
        r'([\w:\s*<&>,\[\]\(\)]+?)'   # return type (non-greedy)
        r'\s+'                        # at least one space
        r'\b' + re.escape(cpp_func_name) + r'\b'
        r'\s*\(',                     # start of parameter list
        re.MULTILINE
    )

    for match in pattern.finditer(content):
        # The pattern ends with '\(' so the opening parenthesis is always
        # the last matched character (the old re-find branch was dead code).
        paren_start = match.end() - 1

        # Walk forward to the matching ')'.
        paren_count = 0
        pos = paren_start
        while pos < len(content):
            ch = content[pos]
            if ch == '(':
                paren_count += 1
            elif ch == ')':
                paren_count -= 1
                if paren_count == 0:
                    start_sig = match.start(1)
                    return content[start_sig:pos + 1].strip()
            pos += 1
        # Unbalanced parentheses: try the next candidate match.

    return None
def parse_mdef_and_attach_cpp_signatures(item, func_index):
    """
    For every raw m.def statement in ``item``, parse it and look up the
    bound C++ function's signature in the prebuilt ``func_index``.

    Returns {'m_def_statements': [{'raw': stmt, 'parsed': dict}, ...]}
    where each parsed dict gains a 'cpp_signature' key (None for lambdas
    or unresolved functions).
    """
    enriched = []

    for raw_stmt in item['m_def_statements']:
        parsed = parse_m_def_statement(raw_stmt)
        func_name = parsed.get('cpp_function_name')

        signature = None
        if func_name and func_name in func_index:
            signature = func_index[func_name]
        elif not parsed['is_lambda']:
            print(f'Warning: C++ function "{func_name}" not found in any .cpp file')

        parsed['cpp_signature'] = signature
        enriched.append({'raw': raw_stmt, 'parsed': parsed})

    return {'m_def_statements': enriched}


def parse_cpp_signature(cpp_sig):
    """
    Split a C++ signature string into return type and parameter info.

    Returns {'return_type': str, 'parameters': [...], 'num_parameters': int}
    or None when the string does not look like a function signature.
    """
    if not cpp_sig or not cpp_sig.strip():
        return None

    open_paren = cpp_sig.find('(')
    if open_paren == -1:
        return None

    header = cpp_sig[:open_paren].strip()
    if not header:
        return None

    # The function name is the last whitespace-separated token; everything
    # before it is the return type.  Fewer than two tokens means we cannot
    # tell the two apart.
    pieces = header.split()
    if len(pieces) < 2:
        return None
    return_type = ' '.join(pieces[:-1]).strip()

    # Parameter list is everything between the first '(' and the last ')'.
    raw_params = cpp_sig[open_paren + 1:cpp_sig.rfind(')')].strip()

    parameters = []
    if raw_params and raw_params != 'void':  # 'void' == empty parameter list
        for declaration in split_cpp_parameters(raw_params):
            declaration = declaration.strip()
            if not declaration:
                continue
            info = parse_parameter_declaration(declaration)
            if info:
                parameters.append(info)

    return {
        'return_type': return_type,
        'parameters': parameters,
        'num_parameters': len(parameters),
    }
+ """ + if not cpp_sig or not cpp_sig.strip(): + return None + + # Find function name: last identifier before '(' + paren_pos = cpp_sig.find('(') + if paren_pos == -1: + return None + + before_paren = cpp_sig[:paren_pos].strip() + if not before_paren: + return None + + # Function name is the last word in before_paren (may include templates like func) + tokens = before_paren.split() + if len(tokens) < 2: + return None + + # Heuristic: function name is usually the last token (may include <>) + func_name_part = tokens[-1] + return_type = ' '.join(tokens[:-1]).strip() + + # Now extract parameter list content + param_list_str = cpp_sig[paren_pos+1:cpp_sig.rfind(')')].strip() + parameters = [] + + if param_list_str and param_list_str != 'void': # 'void' means no parameters + # Split parameters (handle commas not inside templates/brackets) + param_decls = split_cpp_parameters(param_list_str) + for decl in param_decls: + decl = decl.strip() + if not decl: + continue + # Try to split type and name from right to left + param_info = parse_parameter_declaration(decl) + if param_info: + parameters.append(param_info) + + return { + 'return_type': return_type, + 'parameters': parameters, + 'num_parameters': len(parameters) + } + + +def split_cpp_parameters(param_str: str): + """ + Split a C++ parameter list string by top-level commas, + e.g., 'int a, std::vector b' → ['int a', 'std::vector b'] + """ + if not param_str.strip() or param_str == 'void': + return [] + params = [] + current = [] + tracker = BracketTracker() + + for ch in param_str: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + param = ''.join(current).strip() + if param: # Only add non-empty parameters + params.append(param) + current = [] + else: + current.append(ch) + + if current: + final_param = ''.join(current).strip() + if final_param: # Only add non-empty parameters + params.append(final_param) + return params + + +def parse_parameter_declaration(decl: str): + """ + 
Parse a single parameter declaration, e.g., 'const std::string& name' → {'type': 'const std::string&', 'name': 'name'} + Improved version that better handles template types. + """ + decl = decl.strip() + if not decl: + return None + + # Remove possible default value (starting from top-level '=') + tracker = BracketTracker() + eq_pos = -1 + for i, ch in enumerate(decl): + if ch in '()[]{}<>': + tracker.update(ch) + elif ch == '=' and tracker.is_top_level(): + eq_pos = i + break + + if eq_pos != -1: + decl = decl[:eq_pos].strip() + + # Now decl is 'type name' or just 'type' + # Instead of simple splitting, we'll use a more robust approach + # to find the parameter name + + # First, let's handle the case where there's no explicit parameter name + # (this sometimes happens in function declarations) + if not re.search(r'[a-zA-Z_][a-zA-Z0-9_]*$', decl): + # No parameter name found, just return the type + return { + 'type': decl, + 'name': None + } + + # Use bracket tracking to find where the type ends and name begins + tracker = BracketTracker() + name_start = -1 + + # Scan from the end to find the start of the parameter name + # We look for the first identifier that's outside all brackets + i = len(decl) - 1 + while i >= 0: + ch = decl[i] + + if ch in '()[]{}<>': + tracker.update(ch) + + # If we're at top level and find an identifier character + if tracker.is_top_level() and re.match(r'[a-zA-Z0-9_]', ch): + # Track back to find the start of this identifier + name_start = i + while name_start > 0 and re.match(r'[a-zA-Z0-9_]', decl[name_start - 1]): + name_start -= 1 + + # Check if this might be part of a type keyword (like 'int', 'bool', etc.) 
def extract_cpp_signature_details(item):
    """
    For each m.def entry in ``item``, parse the attached raw C++ signature
    into structured form (return type + parameters) and store it on the
    parsed dict as 'cpp_parsed_signature' (None when unavailable).
    """
    enriched = []
    for stmt_info in item['m_def_statements']:
        parsed = stmt_info['parsed']
        raw_sig = parsed.get('cpp_signature')

        details = None
        if raw_sig:
            try:
                details = parse_cpp_signature(raw_sig)
            except Exception as e:
                print(f'Failed to parse C++ signature: {e}')

        parsed['cpp_parsed_signature'] = details
        enriched.append({'raw': stmt_info['raw'], 'parsed': parsed})

    return {'m_def_statements': enriched}


def cpp_type_to_python_type(cpp_type: str) -> str:
    """
    Map a C++ type spelling to a Python annotation string.

    Handles common std:: containers (pair/tuple/vector/optional,
    recursively), strings, booleans, integer and floating-point types and
    torch::Tensor; anything unrecognized becomes 'Any' (with a warning).
    """
    if not cpp_type:
        return 'Any'

    original = cpp_type.strip()
    if not original:
        return 'Any'

    # Drop qualifiers/specifiers and reference/pointer decorations that do
    # not affect the Python-side type.
    cleaned = re.sub(r'\b(static|inline|constexpr|thread_local|extern|mutable|const|volatile|endif)\b', '', original)
    cleaned = cleaned.replace('&', '').replace('*', '').strip()
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()

    if cleaned == 'void':
        return 'None'

    # Template containers first — their arguments are converted recursively
    # and must not be shadowed by the plain-type checks below.
    if cleaned.startswith('std::pair<'):
        inner = cleaned[len('std::pair<'):-1].strip()
        args = split_template_args(inner)
        if len(args) == 2:
            return f'tuple[{cpp_type_to_python_type(args[0])}, {cpp_type_to_python_type(args[1])}]'
        print(f'Warning: std::pair with unexpected number of args: {cleaned}')
        return 'Any'

    if cleaned.startswith('std::tuple<'):
        inner = cleaned[len('std::tuple<'):-1].strip()
        py_types = [cpp_type_to_python_type(a) for a in split_template_args(inner)]
        return f"tuple[{', '.join(py_types)}]"

    if cleaned.startswith('std::vector<'):
        inner = cleaned[len('std::vector<'):-1].strip()
        args = split_template_args(inner)
        if len(args) == 1:
            return f'list[{cpp_type_to_python_type(args[0])}]'
        print(f'Warning: std::vector with unexpected args: {cleaned}')
        return 'Any'

    if cleaned.startswith('std::optional<'):
        inner = cleaned[len('std::optional<'):-1].strip()
        args = split_template_args(inner)
        if len(args) == 1:
            return f'Optional[{cpp_type_to_python_type(args[0])}]'
        print(f'Warning: std::optional with unexpected args: {cleaned}')
        return 'Any'

    # std::string
    if re.search(r'\bstd::string\b', original):
        return 'str'

    # C-style strings: char*, const char*, char[] (checked on the original
    # spelling because '*' was stripped from `cleaned`).
    if re.search(r'\b(?:const\s+)?char\s*[\*\[]', original):
        return 'str'

    # Boolean
    if re.search(r'\bbool\b', cleaned):
        return 'bool'

    # Floating point BEFORE integers: otherwise 'long double' would match
    # the \blong\b integer pattern below and wrongly come out as 'int'.
    if re.search(r'\b(float|double|long\s+double)\b', cleaned):
        return 'float'

    # Integer types (including fixed-width and common aliases)
    if re.search(r'\b(int|long|short|size_t|ssize_t|ptrdiff_t|'
                 r'int8_t|int16_t|int32_t|int64_t|'
                 r'uint8_t|uint16_t|uint32_t|uint64_t)\b', cleaned):
        return 'int'

    # torch::Tensor
    if re.search(r'\btorch::Tensor\b', original):
        return 'torch.Tensor'

    print(f'Warning: Unrecognized C++ type: {original}')
    return 'Any'
def split_template_args(template_args: str):
    """
    Split template arguments on top-level commas, e.g.
    'int, std::vector<float>' -> ['int', 'std::vector<float>'].
    """
    if not template_args.strip():
        return []

    args = []
    current = []
    tracker = BracketTracker()
    for ch in template_args:
        if ch in '()[]{}<>':
            tracker.update(ch)
        if ch == ',' and tracker.is_top_level():
            args.append(''.join(current).strip())
            current = []
        else:
            current.append(ch)

    if current:
        args.append(''.join(current).strip())
    return args


def cpp_default_to_python_default(cpp_default: str):
    """
    Convert a C++ default-value expression into an equivalent Python
    expression string; unrecognized expressions become 'None' (warned).
    """
    if not cpp_default:
        return 'None'

    s = cpp_default.strip()

    # String literal -> keep verbatim (already valid Python).
    if re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"$', s):
        return s

    # Boolean literals.
    if s == 'false':
        return 'False'
    if s == 'true':
        return 'True'

    # Null-like values: nullptr, NULL, std::nullopt, etc.
    if s in ('nullptr', 'NULL') or 'nullopt' in s:
        return 'None'

    # std::tuple<...>({128, 128}) -> (128, 128)
    tuple_match = re.match(r'std::tuple\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s)
    if tuple_match:
        return tuple_match.group(1).replace('{', '(').replace('}', ')')

    # std::make_tuple(1, 2, 3) -> (1, 2, 3)
    make_tuple_match = re.match(r'std::make_tuple\s*\(\s*(.*?)\s*\)', s)
    if make_tuple_match:
        inner = make_tuple_match.group(1)
        # A single-element C++ tuple needs a trailing comma in Python,
        # otherwise '(5)' is just a parenthesized int, not a tuple.
        if inner and ',' not in inner:
            return f'({inner},)'
        return f'({inner})'

    # std::vector<...>({1, 2, 3}) -> [1, 2, 3]
    vector_match = re.match(r'std::vector\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s)
    if vector_match:
        return vector_match.group(1).replace('{', '[').replace('}', ']')

    # Numeric literals: integers and floats.
    if re.match(r'^[+-]?\d+$', s):
        return s
    if re.match(r'^[+-]?\d*\.\d+([eE][+-]?\d+)?$', s):
        return s

    # Fallback: unrecognized -> warn and return None.
    print(f'Warning: Unrecognized default value: {s}')
    return 'None'
def generate_pyi_function(item_entry):
    """
    Render one `.pyi` stub for a parsed m.def entry.

    Falls back to ``(*args, **kwargs) -> Any`` for lambdas and for
    bindings whose C++ signature could not be resolved.
    """
    import keyword  # stdlib; only needed in this function

    parsed = item_entry['parsed']
    py_name = parsed['python_function_name']

    # Lambdas carry no recoverable signature.
    if parsed.get('is_lambda'):
        return f'def {py_name}(*args, **kwargs) -> Any: ...'

    sig_info = parsed.get('cpp_parsed_signature')
    default_args = parsed.get('default_args', {})

    if not sig_info:
        return f'def {py_name}(*args, **kwargs) -> Any: ...'

    return_type = cpp_type_to_python_type(sig_info['return_type'])
    params = sig_info['parameters']

    # Build the parameter list.  (The old per-index bounds check was dead
    # code: the loop index can never exceed len(params).)
    param_lines = []
    for i, param_info in enumerate(params):
        param_type = cpp_type_to_python_type(param_info['type'])
        param_name = param_info['name'] or f'arg{i}'

        # Python keywords are invalid parameter names; the old hard-coded
        # set missed most of them (lambda, return, in, ...).
        if keyword.iskeyword(param_name):
            param_name = f'{param_name}_'

        if i in default_args:
            py_default = cpp_default_to_python_default(default_args[i])
            param_lines.append(f'    {param_name}: {param_type} = {py_default}')
        else:
            param_lines.append(f'    {param_name}: {param_type}')

    if param_lines:
        params_block = ',\n'.join(param_lines)
        return f'def {py_name}(\n{params_block}\n) -> {return_type}: ...'
    return f'def {py_name}() -> {return_type}: ...'
def generate_pyi_file_content(enhanced_results, module_name: str = 'my_module'):
    """
    Assemble the full text of a .pyi stub file from the enhanced parse
    results of every scanned source file, adding only the imports the
    generated declarations actually need.
    """
    declarations = []
    need_optional = need_torch = need_numpy = False

    for item in enhanced_results:
        for stmt in item['m_def_statements']:
            try:
                decl = generate_pyi_function(stmt)
                declarations.append(decl)
                need_optional = need_optional or 'Optional[' in decl
                need_torch = need_torch or 'torch.Tensor' in decl
                need_numpy = need_numpy or 'numpy.ndarray' in decl or 'py::array' in str(stmt)
            except Exception as e:
                func_name = stmt['parsed'].get('python_function_name', 'unknown')
                declarations.append(f'# ERROR: failed to generate stub for {func_name}: {e}')

    typing_import = 'from typing import Any'
    if need_optional:
        typing_import += ', Optional'
    imports = [typing_import]
    if need_torch:
        imports.append('import torch')
    if need_numpy:
        imports.append('import numpy')

    lines = [f'# Stubs for module: {module_name}', '']
    lines.extend(imports)
    lines.extend(['', ''])
    for decl in declarations:
        lines.extend([decl, '', ''])
    return '\n'.join(lines)


def generate_pyi_file(name, root, output_dir='.'):
    """
    End-to-end driver: scan ``root`` for pybind11 bindings, resolve and
    parse their C++ signatures, and write ``<output_dir>/<name>.pyi``.
    """
    func_index = build_cpp_function_index(root)

    enhanced = []
    for item in extract_m_def_statements(root):
        with_sigs = parse_mdef_and_attach_cpp_signatures(item, func_index)
        enhanced.append(extract_cpp_signature_details(with_sigs))

    content = generate_pyi_file_content(enhanced, module_name=name)

    output_path = Path(output_dir) / f'{name}.pyi'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(content)

    print(f'.pyi file generated: {output_path}')
sys import torch +import platform +import urllib +import urllib.error +import urllib.request from setuptools import find_packages from setuptools.command.build_py import build_py +from packaging.version import parse +from pathlib import Path from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from scripts.generate_pyi import generate_pyi_file -current_dir = os.path.dirname(os.path.realpath(__file__)) + +DG_SKIP_CUDA_BUILD = int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) == 1 +DG_FORCE_BUILD = int(os.getenv('DG_FORCE_BUILD', '0')) == 1 +DG_USE_LOCAL_VERSION = int(os.getenv('DG_USE_LOCAL_VERSION', '1')) == 1 +DG_JIT_USE_RUNTIME_API = int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')) == 1 + +# Compiler flags cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] +if DG_JIT_USE_RUNTIME_API: + cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') + +# Sources +current_dir = os.path.dirname(os.path.realpath(__file__)) sources = ['csrc/python_api.cpp'] build_include_dirs = [ f'{CUDA_HOME}/include', @@ -18,19 +40,75 @@ 'third-party/cutlass/include', 'third-party/fmt/include', ] -build_libraries = ['cuda', 'cudart', 'nvrtc'] -build_library_dirs = [ - f'{CUDA_HOME}/lib64', - f'{CUDA_HOME}/lib64/stubs' -] +build_libraries = ['cudart', 'nvrtc'] +build_library_dirs = [f'{CUDA_HOME}/lib64'] third_party_include_dirs = [ 'third-party/cutlass/include/cute', 'third-party/cutlass/include/cutlass', ] -# Use runtime API -if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')): - cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') +# Release +base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}' + + +def get_package_version(): + with open(Path(current_dir) / 'deep_gemm' / '__init__.py', 'r') as f: + version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE) + 
public_version = ast.literal_eval(version_match.group(1)) + + revision = '' + if DG_USE_LOCAL_VERSION: + # noinspection PyBroadException + try: + status_cmd = ['git', 'status', '--porcelain'] + status_output = subprocess.check_output(status_cmd).decode('ascii').strip() + if status_output: + print(f'Warning: Git working directory is not clean. Uncommitted changes:\n{status_output}') + assert False, 'Git working directory is not clean' + + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + revision = '+local' + return f'{public_version}{revision}' + + +def get_platform(): + if sys.platform.startswith('linux'): + return f'linux_{platform.uname().machine}' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_wheel_url(): + torch_version = parse(torch.__version__) + torch_version = f'{torch_version.major}.{torch_version.minor}' + python_version = f'cp{sys.version_info.major}{sys.version_info.minor}' + platform_name = get_platform() + deep_gemm_version = get_package_version() + cxx11_abi = int(torch._C._GLIBCXX_USE_CXX11_ABI) + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + cuda_version = parse(torch.version.cuda) + cuda_version = f'{cuda_version.major}' + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}-torch{torch_version}-cxx11abi{cxx11_abi}-{python_version}-{platform_name}.whl' + wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename) + return wheel_url, wheel_filename + + +def get_ext_modules(): + if DG_SKIP_CUDA_BUILD: + return [] + + return [CUDAExtension(name='deep_gemm._C', + sources=sources, + include_dirs=build_include_dirs, + 
libraries=build_libraries, + library_dirs=build_library_dirs, + extra_compile_args=cxx_flags)] class CustomBuildPy(build_py): @@ -41,9 +119,24 @@ def run(self): # Second, make clusters' cache setting default into `envs.py` self.generate_default_envs() + # Third, generate and copy .pyi file to build root directory + self.generate_pyi_file() + # Finally, run the regular build build_py.run(self) + def generate_pyi_file(self): + generate_pyi_file(name='_C', root='./csrc', output_dir='./stubs') + pyi_source = os.path.join(current_dir, 'stubs', '_C.pyi') + pyi_target = os.path.join(self.build_lib, 'deep_gemm', '_C.pyi') + + if os.path.exists(pyi_source): + print(f"Copying .pyi file from {pyi_source} to {pyi_target}") + os.makedirs(os.path.dirname(pyi_target), exist_ok=True) + shutil.copy2(pyi_source, pyi_target) + else: + print(f"Warning: .pyi file not found at {pyi_source}") + def generate_default_envs(self): code = '# Pre-installed environment variables\n' code += 'persistent_envs = dict()\n' @@ -72,18 +165,37 @@ def prepare_includes(self): shutil.copytree(src_dir, dst_dir) -if __name__ == '__main__': - # noinspection PyBroadException - try: - cmd = ['git', 'rev-parse', '--short', 'HEAD'] - revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() - except: - revision = '' +class CachedWheelsCommand(_bdist_wheel): + def run(self): + if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print(f'Try to download wheel from URL: {wheel_url}') + try: + with urllib.request.urlopen(wheel_url, timeout=1) as response: + with open(wheel_filename, 'wb') as out_file: + data = response.read() + out_file.write(data) + + # Make the archive + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}' + wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl') + 
os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print('Precompiled wheel not found. Building from source...') + # If the wheel could not be downloaded, build from source + super().run() + +if __name__ == '__main__': # noinspection PyTypeChecker setuptools.setup( name='deep_gemm', - version='2.1.0' + revision, + version=get_package_version(), packages=find_packages('.'), package_data={ 'deep_gemm': [ @@ -92,16 +204,10 @@ def prepare_includes(self): 'include/cutlass/**/*', ] }, - ext_modules=[ - CUDAExtension(name='deep_gemm_cpp', - sources=sources, - include_dirs=build_include_dirs, - libraries=build_libraries, - library_dirs=build_library_dirs, - extra_compile_args=cxx_flags) - ], + ext_modules=get_ext_modules(), zip_safe=False, cmdclass={ 'build_py': CustomBuildPy, + 'bdist_wheel': CachedWheelsCommand, }, ) diff --git a/tests/generators.py b/tests/generators.py index 0d06505a..cd192729 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -1,11 +1,13 @@ import enum import random import torch -from typing import Generator, List +from typing import Generator, List, Optional, Tuple +from deep_gemm.testing import get_arch_major from deep_gemm.utils import ( align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, + per_token_cast_to_fp4, transpose_packed_fp4, get_mk_alignment_for_contiguous_layout ) @@ -34,11 +36,51 @@ def is_k_major(self): def is_mn_major(self): return self.value == 1 + +class QuantConfig: + _legacy_quant_config = (128, 128, False, False) -def get_arch_major() -> int: - major, minor = torch.cuda.get_device_capability() - return major + def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config): + self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value + + def print(self): + print(f' > Testing with gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, ' + f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}') + + def is_legacy(self) -> 
bool: + return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config + + def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]: + recipe, recipe_a, recipe_b = None, None, None + if self.is_legacy(): + recipe = (1, 1, 128) if is_wgrad else None + else: + recipe_a = (1, self.gran_k_a) + recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b) + return recipe, recipe_a, recipe_b + + def max_diff(self) -> float: + if self.is_fp4_a and self.is_fp4_b: + return 0.02 + if self.is_fp4_a or self.is_fp4_b: + return 0.01 + return 0.001 + + @staticmethod + def get_list_from_dtype(dtype: torch.dtype) -> List: + if dtype == torch.bfloat16: + return [None] + quant_config_list = [QuantConfig()] + if get_arch_major() == 10: + quant_config_list.append(QuantConfig((128, 32, False, True))) + return quant_config_list + + +def reset_seed(seed: int = 0): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) def get_ue8m0_usage(kernel_type: KernelType) -> bool: @@ -51,9 +93,6 @@ def get_kernel_types(dtype: torch.dtype) -> tuple: if dtype == torch.bfloat16: return (KernelType.KernelNoSF, ) - # TODO: SM100 1D2D kernels are going to be deprecated - # But if you want to test it, please use: - # `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)` return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, ) @@ -67,61 +106,85 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: yield major_a, major_b +def get_psum_layout_usage() -> tuple: + return (False, True) if get_arch_major() == 10 else (False, ) + + def enumerate_normal(dtype: torch.dtype) -> Generator: assert dtype in (torch.float8_e4m3fn, torch.bfloat16) + quant_config_list = QuantConfig.get_list_from_dtype(dtype) fp32_output_nk = [(256, 7168), (129280, 7168)] bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 
1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] - m_fwd_list, m_bwd_list = [128, 4096], [4096, ] - nk_list = bf16_output_nk + m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] + nk_list = list(bf16_output_nk) # Only BF16 GEMM needs FP32 outputs if dtype == torch.bfloat16: nk_list += fp32_output_nk for kernel_type in get_kernel_types(dtype): - # Forward - for m in m_fwd_list: - for n, k in nk_list: - out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16 - yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype - - # TODO: support BF16 SM90 MN-major kernels - if dtype == torch.bfloat16 and get_arch_major() == 9: - continue - - # Backward - for m in m_bwd_list: - for n, k in nk_list: - override_major = MajorTypeAB.MNMajor - override_kernel_type = kernel_type - if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: - override_major = MajorTypeAB.KMajor - override_kernel_type = KernelType.Kernel1D1D - yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad - yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad - yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + reset_seed() + + # Forward + for m in m_fwd_list: + for i in range(len(nk_list)): + n, k = nk_list[i] + out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float + yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, 
override_major, False, torch.bfloat16 # Dgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + m_group_list = [(4, 8192), (8, 4096)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] for kernel_type in get_kernel_types(dtype): - for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): - for major_a, major_b in get_major_ab(False, get_arch_major() > 9): - yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): + yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) max_m = 4096 + m_group_list = [(6, 1024), (32, 192), (32, 50)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] for kernel_type in get_kernel_types(dtype): - for enable_overlap in (False, True): - for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 64), (16, 32)): - for n, k in ((4096, 7168), (7168, 2048), ): - yield kernel_type, enable_overlap, num_groups, max_m, m, n, k - - -def enumerate_k_grouped_contiguous(): - # Only K-major is supported for SM90 - major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \ + 
for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + for enable_overlap in (False, True): + reset_seed() + for num_groups, m in m_group_list: + for n, k in n_k_list: + yield kernel_type, quant_config, enable_overlap, num_groups, max_m, m, n, k, use_psum_layout + + +def enumerate_k_grouped_contiguous(dtype: torch.dtype): + # Only K-major is supported for SM90 FP8 + major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 and dtype == torch.float8_e4m3fn \ else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) # Must with FP32 accumulation and 1D1D kernels for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 @@ -156,11 +219,46 @@ def enumerate_transpose(): yield mn + delta, k +def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + if is_fp4: + x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) + else: + x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) + return x + + +def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + num_groups, mn, k = x.size() + if is_fp4: + x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \ + torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8), + torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_i_fp4 = per_token_cast_to_fp4(x[i], 
use_ue8m0=use_ue8m0, gran_k=gran_k) + x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1]) + x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) + else: + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \ + else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) + return x + + def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, accumulate: bool, out_dtype: torch.dtype, kernel_type: KernelType, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + quant_config: Optional[QuantConfig] = None): a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ @@ -172,25 +270,28 @@ def generate_normal(m: int, n: int, k: int, a = a if major_a.is_k_major() else a.T.contiguous().T b = b if major_b.is_k_major() else b.T.contiguous().T return a, b, c, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, + use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate)) - a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) - b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if 
kernel_type.is_1d1d() and accumulate \ - else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) - a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) - b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) - return a_fp8, b_fp8, c, d, ref_d + return a, b, c, d, ref_d def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] m = sum(aligned_ms) a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \ + else torch.empty(m, device='cuda', dtype=torch.int32) d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) @@ -198,54 +299,69 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): actual_end = start + actual_m aligned_end = start + aligned_m - m_indices[start:actual_end] = i - m_indices[actual_end:aligned_end] = -1 - ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t() + if use_psum_layout: + grouped_layout[i] = actual_end + else: + grouped_layout[start: actual_end] = i + grouped_layout[actual_end: aligned_end] = -1 + a[actual_end: aligned_end] = 0 + ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t() start = aligned_end - ref_d = 
torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) if use_bf16: b = b if major_b.is_k_major() else b.mT.contiguous().mT - return m, a, b, m_indices, d, ref_d + return m, a, b, grouped_layout, d, ref_d assert major_a.is_k_major() - a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) - b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return m, a, b, grouped_layout, d, ref_d + + +def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): + num_groups, max_m, _ = x.size() + x_psum = torch.empty_like(x).view(num_groups * max_m, -1) + last_psum_m = 0 for i in range(num_groups): - b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].mT.contiguous().mT, b_fp8[1]) - return m, a_fp8, b_fp8, m_indices, d, ref_d + x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m] + last_psum_m = align(psum_m[i], 128) + return x_psum def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, - use_ue8m0: bool = False, use_bf16: bool = False, enable_overlap: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + enable_overlap: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) ref_d = torch.einsum('gmk,gnk->gmn', a, b) masked_m = 
torch.empty((num_groups, ), device='cuda', dtype=torch.int) + psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j] assert masked_m.amax().item() <= max_m if use_bf16: - return a, b, masked_m, d, ref_d + return a, b, masked_m, psum_m, d, ref_d - a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) - b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) - for i in range(num_groups): - a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) - b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + quant_config = QuantConfig() if quant_config is None else quant_config + a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) max_signal_size = num_groups * ceil_div(max_m, 64) signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None - return a_fp8, b_fp8, masked_m, d, ref_d, signal + return a, b, masked_m, psum_m, d, ref_d, signal -def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool): +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], + use_ue8m0: bool = False, use_bf16: bool = False): assert get_mk_alignment_for_contiguous_layout() % 128 == 0 k = sum(ks) @@ -261,6 +377,10 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: Majo 
ref_d[i] = c[i] + (a[start:end].T @ b[start:end]) start = end + if use_bf16: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + return k, a, b, c, d, ref_d + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) diff --git a/tests/test_attention.py b/tests/test_attention.py index 1baa80f1..b26cf673 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,12 +1,18 @@ +import dataclasses import random import torch -from typing import Tuple +from typing import Tuple, List import deep_gemm -from deep_gemm.testing import bench_kineto, calc_diff, count_bytes +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major, + test_filter +) from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8 -from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB +from generators import generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]): @@ -107,137 +113,169 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, return logits, cost +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 10) def test_mqa_logits(): print('Testing FP8 MQA Logits:') num_heads, head_dim = 64, 128 for seq_len in (2048, 4096): - for seq_len_kv in (4096, 8192, 16384, 32768, 65536, 131072): - for disable_cp in (False, True): - q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) - kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) - weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) - - if disable_cp: - ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') - ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) - else: - ks, ke = generate_cp_test_data(seq_len, seq_len_kv) - - q_fp8 
= q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) - logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) - - do_check = (seq_len_kv < 32768) - if do_check: - ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - - ref_neginf_mask = (ref_logits == float('-inf')) - neginf_mask = (logits == float('-inf')) - assert torch.equal(neginf_mask, ref_neginf_mask) - - ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) - logits = logits.masked_fill(neginf_mask, 0) - diff = calc_diff(logits, ref_logits) - assert diff < 1e-3, f"{diff=}" - else: - ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) - - tflops = 2 * ref_cost * num_heads * head_dim / 1e12 - t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), - ('fp8_mqa_logits', 'clean_logits')) - clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) - print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' - f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' - f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s | ' - f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s') + for compressed_logits in (False, True): + for seq_len_kv in (4096, 8192): + for disable_cp in (False, True): + q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) + else: + ks, ke = generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = 
per_custom_dims_cast_to_fp8(kv, (0, ), False) + + if compressed_logits: + max_seqlen_k = (ke - ks).max().item() + logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False) + assert logits.size() == (seq_len, max_seqlen_k) + tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda') + for i in range(seq_len): + tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]] + logits = tmp + else: + logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + do_check = (seq_len_kv < 32768) + if do_check: + ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + ref_neginf_mask = (ref_logits == float('-inf')) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f'{diff=}' + else: + ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) + + tflops = 2 * ref_cost * num_heads * head_dim / 1e12 + if compressed_logits: + t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False), 'fp8_mqa_logits') + else: + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), ('fp8_mqa_logits', 'clean_logits')) + clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) + print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' + f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='') + # noinspection PyUnboundLocalVariable + print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not compressed_logits else '') print() def 
ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, - max_model_len: int): + max_model_len: int, is_context_lens_2d: bool): batch_size, next_n, heads, dim = q.size() num_block, block_size, _, dim = kv_cache.size() logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) context_lens = context_lens.tolist() for i in range(batch_size): context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') + q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if is_context_lens_2d \ + else torch.arange(context_len - next_n, context_len, device='cuda') weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() - for block_rk in range(ceil_div(context_len, block_size)): - block_idx = block_tables[i][block_rk] - qx, kx = q[i], kv_cache[block_idx] - k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') - mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) - s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + + num_blocks = (context_len + block_size - 1) // block_size + block_idxs = block_tables[i][:num_blocks] + kv_slice = kv_cache[block_idxs] # [num_blocks, block_size, kv_heads, dim] + kx = kv_slice.permute(2, 3, 0, 1).reshape(kv_slice.size(2), dim, -1) # [kv_heads, dim, total_tokens] + qx = q[i].transpose(0, 1) # q[i]: [next_n, heads, dim] -> [heads, next_n, dim] + s = torch.matmul(qx, kx).to(logits.dtype) # [heads, next_n, dim] @ [1, dim, total_tokens] -> [heads, 
next_n, total_tokens] + + total_len = num_blocks * block_size + k_offsets = torch.arange(0, total_len, device=q.device) + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], s, float('-inf')) # mask shape: [1, next_n, total_tokens] + s = torch.relu(s) * weight_slice[..., None] # weight_slice: [heads, next_n] -> [heads, next_n, 1] + s = s.sum(dim=0) # [next_n, total_tokens] + logits[i * next_n:(i + 1) * next_n, :total_len] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + return logits def test_paged_mqa_logits(): print('Testing FP8 Paged MQA Logits:') max_model_len = 111 * 1000 - for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]: - for heads, index_dim in [(64, 128)]: - for avg_kv in (8192, 32768): - num_blocks, blocksize = max_model_len * 3, 64 - - q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16) - kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16) - weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32) - - context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32) - max_block_len = (context_lens.max().item() + blocksize - 1) // blocksize * blocksize - block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32) - - counter = 0 - block_idx_pool = list(range(num_blocks)) - random.shuffle(block_idx_pool) - for i in range(batch_size): - ctx_len = context_lens[i].item() - for j in range(ceil_div(ctx_len, blocksize)): - block_tables[i][j] = block_idx_pool[counter] - counter += 1 - - q_fp8 = q.to(torch.float8_e4m3fn) - kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms()) - logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, 
block_tables, schedule_metadata, max_model_len, clean_logits=True) - - ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len) - positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) - row_indices = torch.arange(batch_size * next_n, device='cuda') // next_n - next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n - ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)) - - neginf_mask = (logits == float('-inf')) - assert torch.equal(neginf_mask, ref_neginf_mask) - - logits = logits.masked_fill(neginf_mask, 0) - ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) - diff = calc_diff(logits, ref_logits) - assert diff < 1e-3, f"{diff=}" - - sum_lens = sum(context_lens.to(torch.int64)) - tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12 - input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4 - output_bytes = sum_lens * next_n * 4 - t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True), - ('fp8_paged_mqa_logits', 'clean_logits')) - clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) - print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' - f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, ' - f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | ' - f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s') + for is_context_lens_2d in (False, True): + for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]: + for heads, index_dim in [(64, 128)]: + for avg_kv in (8192, 32768): + num_blocks, blocksize = max_model_len * 3, 64 + + q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', 
dtype=torch.bfloat16) + kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16) + weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32) + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32) + context_lens_list = context_lens.tolist() + max_block_len = (max(context_lens_list) + blocksize - 1) // blocksize * blocksize + block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32) + + counter, block_idx_pool = 0, torch.randperm(num_blocks, device='cuda', dtype=torch.int32) + for i in range(batch_size): + num_blocks = ceil_div(context_lens_list[i], blocksize) + block_tables[i][:num_blocks] = block_idx_pool[counter: counter+num_blocks] + counter += num_blocks + + ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len, is_context_lens_2d) + positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) + + if is_context_lens_2d: + context_lens_2d = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() + context_lens_2d[:, next_n-1] = context_lens + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens_2d, blocksize, deep_gemm.get_num_sms()) + logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False) + ref_neginf_mask = ~(positions < context_lens_2d.view(-1).unsqueeze(1)) + else: + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms()) + logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True) + row_indices = torch.arange(batch_size * next_n, 
device='cuda') // next_n + next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n + ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + logits = logits.masked_fill(ref_neginf_mask, 0) + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + sum_lens = sum(context_lens.to(torch.int64)) + tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12 + input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4 + output_bytes = sum_lens * next_n * 4 + if is_context_lens_2d: + t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False), + 'fp8_paged_mqa_logits') + else: + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True), + ('fp8_paged_mqa_logits', 'clean_logits')) + clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) + print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, ' + f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='') + # noinspection PyUnboundLocalVariable + print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not is_context_lens_2d else '') print() + + if __name__ == '__main__': - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True torch.manual_seed(0) random.seed(0) diff --git a/tests/test_bf16.py b/tests/test_bf16.py index 34e0b42e..1a3b0467 100644 --- a/tests/test_bf16.py +++ b/tests/test_bf16.py @@ -1,5 +1,7 @@ 
-import torch +import copy +import numpy as np import random +import torch import deep_gemm from deep_gemm.testing import ( @@ -7,19 +9,16 @@ calc_diff, count_bytes ) from generators import ( - get_arch_major, - enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal, - generate_m_grouped_contiguous, generate_m_grouped_masked + get_arch_major, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous ) def test_gemm() -> None: print('Testing GEMM:') - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): - # TODO: support accumulation for SM90 BF16 GEMM - if get_arch_major() == 9 and accumulate: - continue - + scores = [] + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -34,46 +33,49 @@ def test_gemm() -> None: assert a.is_contiguous() and b.is_contiguous() getattr(deep_gemm, func_name)(a, b, d, c=c) diff = calc_diff(d, ref_d) - assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + assert diff < 1e-5, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' - f'{t * 
1e6:5.0f} us | ' + f'{t * 1e6:7.1f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') - print() + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") def test_m_grouped_gemm_contiguous() -> None: print('Testing m-grouped contiguous GEMM:') - for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16): + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' for test_alias in (False, True): - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" if test_alias: assert major_a.is_k_major() b = b if major_b.is_k_major() else b.mT assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, m_indices) - d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) diff = calc_diff(d, ref_d) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, 
grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices) + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): ' f'{t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') @@ -84,35 +86,89 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16): - # Test correctness - for i in range(10): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) - deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_bf16=True, use_psum_layout=use_psum_layout) + if use_psum_layout: + a_psum = layout_masked_to_psum(a, psum_m) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, + use_psum_layout=True, 
expected_m_for_psum_layout=expected_m_per_group) + else: + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + test_func() for j in range(num_groups): - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() - # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) + +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_bf16=True) + new_ks_tensor 
= torch.tensor(new_ks, dtype=torch.int, device='cuda') + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {ks=}, {diff:.7f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_bf16=True) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c) - # Test performance with fixed shapes - valid_m = masked_m.sum().item() t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): ' + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') print() def test_cublaslt_gemm() -> None: print('Testing cuBLASLt GEMM:') - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -121,9 +177,10 @@ def test_cublaslt_gemm() -> None: a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) deep_gemm.cublaslt_gemm_nt(a, b, d, c=c) diff = calc_diff(d, ref_d) - assert diff < 5e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, 
{out_dtype=})' + assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' - t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,) + t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True) + t = t_nvjet + t_gemv + t_gemm print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' f'{t * 1e6:5.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' @@ -132,18 +189,16 @@ def test_cublaslt_gemm() -> None: if __name__ == '__main__': - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True torch.manual_seed(0) random.seed(0) print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_gemm() - # TODO: support SM100 - if get_arch_major() == 9: + if get_arch_major() >= 9: + test_gemm() test_m_grouped_gemm_contiguous() test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() test_cublaslt_gemm() diff --git a/tests/test_einsum.py b/tests/test_einsum.py index a97e9d0c..0fcaae60 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -3,8 +3,13 @@ import deep_gemm from deep_gemm.testing import ( - bench, bench_kineto, - calc_diff, count_bytes + bench_kineto, + calc_diff, count_bytes, + get_arch_major, test_filter +) +from deep_gemm.utils.math import ( + ceil_div, + per_block_cast_to_fp8, per_channel_cast_to_fp8, per_token_cast_to_fp8 ) @@ -44,8 +49,8 @@ def test_bmk_bnk_mn() -> None: def test_bhr_hdr_bhd(): print('Testing "bhr, hdr -> bhd":') - for b in (128, 4096, 8192): - for h, r, d in [(128, 512, 128)]: + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) y = fy[:, :, :r] @@ -54,18 +59,20 @@ def test_bhr_hdr_bhd(): deep_gemm.einsum('bhr,hdr->bhd', x, 
y, z) assert calc_diff(z, ref_z) < 1e-10 - t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'nvjet', suppress_kineto_output=True) + t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', f'{t * 1e6:4.0f} us | ' - f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | ' - f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s') + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') print() def test_bhd_hdr_bhr(): print('Testing "bhd, hdr -> bhr":') - for b in (128, 4096, 8192): - for h, r, d in [(128, 512, 128)]: + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) y = fy[:, :, :r] @@ -74,17 +81,102 @@ def test_bhd_hdr_bhr(): deep_gemm.einsum('bhd,hdr->bhr', x, y, z) assert calc_diff(z, ref_z) < 1e-10 - t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'nvjet', suppress_kineto_output=True) + t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True): + print('Testing FP8 "bhr, hdr -> bhd":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, r), device='cuda', 
dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, r), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, r), x_fp8[1].view(b, h, ceil_div(r, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_hdr_bhr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, hdr -> bhr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, d), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, d), x_fp8[1].view(b, h, ceil_div(d, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, r), 
device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z), 'fp8_gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', f'{t * 1e6:4.0f} us | ' - f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | ' - f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s') + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_bhr_hdr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, bhr -> hdr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + z_0 = torch.randn((h, d, r), device='cuda', dtype=torch.float32) * 10 + ref_z = z_0 + torch.einsum('bhd,bhr->hdr', x, y) + + x_fp8 = per_channel_cast_to_fp8(x.view(b, -1), use_ue8m0=use_ue8m0) + y_fp8 = per_channel_cast_to_fp8(y.view(b, -1), use_ue8m0=use_ue8m0) + x_fp8 = (x_fp8[0].view(b, h, d), x_fp8[1].view(ceil_div(b, 128), h, d)) + y_fp8 = (y_fp8[0].view(b, h, r), y_fp8[1].view(ceil_div(b, 128), h, r)) + z = z_0.clone() + deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)), 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z, z)) / t / 1e9:4.0f} GB/s') print() if __name__ == '__main__': - 
torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True torch.manual_seed(0) random.seed(0) @@ -95,3 +187,7 @@ def test_bhd_hdr_bhr(): if nvjet_accessable(): test_bhr_hdr_bhd() test_bhd_hdr_bhr() + + test_fp8_bhr_hdr_bhd() + test_fp8_bhd_hdr_bhr() + test_fp8_bhd_bhr_hdr() diff --git a/tests/test_fp8_fp4.py b/tests/test_fp8_fp4.py new file mode 100644 index 00000000..f7e3e1c4 --- /dev/null +++ b/tests/test_fp8_fp4.py @@ -0,0 +1,207 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major +) + +from generators import ( + KernelType, get_ue8m0_usage, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous +) + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate)) + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else (a[0].T, 
a[1].T) + b = b if major_b.is_k_major() else (b[0].T, b[1].T) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), + 'fp8_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) \ + if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not 
use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + func_name = f"m_grouped_fp8_fp4_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
+ for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + if use_psum_layout: + a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m)) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast, + use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + else: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * 
valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'{kernel_opt}, psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ + else deep_gemm.k_grouped_fp8_gemm_tn_contiguous + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +if 
__name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() diff --git a/tests/test_hyperconnection.py b/tests/test_hyperconnection.py new file mode 100644 index 00000000..24faf22c --- /dev/null +++ b/tests/test_hyperconnection.py @@ -0,0 +1,57 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + test_filter, + bench_kineto, + calc_diff, count_bytes +) +from deep_gemm.utils import align +from generators import get_arch_major + + +@test_filter(lambda: get_arch_major() >= 9) +def test_hc_prenorm_gemm() -> None: + # Needs TF32 precision for PyTorch GEMMs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print('Testing hyperconnection prenorm GEMM:') + for m in (13, 137, 4096, 8192): + for n, k in [(24, 28672), (24, 7680), (24, 7168)]: + for num_splits in [None, 16]: + a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((n, k), dtype=torch.float, device='cuda') + d = torch.empty((m, n), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m, n), dtype=torch.float, device='cuda') + s = torch.empty((m, ), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m), dtype=torch.float, device='cuda') + deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits) + final_d = d if num_splits is None else d.sum(0) + final_s = s if num_splits is None else s.sum(0) + + ref_d = a.float() @ b.T + ref_s = a.float().square().sum(-1) + + diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s)) + assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}' + + t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True) + print(f' > Perf 
(m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d, s) / 1e9 / t:4.0f} GB/s') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_hc_prenorm_gemm() diff --git a/tests/test_layout.py b/tests/test_layout.py index 42d7208b..7875733a 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -1,7 +1,6 @@ -import time import torch import random -from deep_gemm.testing import bench_kineto, count_bytes, calc_diff +from deep_gemm.testing import bench_kineto, count_bytes from deep_gemm.utils import ( align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, @@ -12,7 +11,6 @@ ) from generators import ( - enumerate_transpose, enumerate_sf_layout, enumerate_k_grouped_sf_layout ) @@ -107,8 +105,6 @@ def test_k_grouped_sf_layout_kernels() -> None: if __name__ == '__main__': - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True torch.manual_seed(1) random.seed(1) diff --git a/tests/test_legacy.py b/tests/test_legacy.py new file mode 100644 index 00000000..4456799f --- /dev/null +++ b/tests/test_legacy.py @@ -0,0 +1,90 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + enumerate_m_grouped_contiguous, enumerate_k_grouped_contiguous, + generate_m_grouped_contiguous, generate_k_grouped_contiguous, +) + +def test_m_grouped_gemm_contiguous_tl() -> None: + print('Testing m-grouped contiguous Triton GEMM:') + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for expand in (False, True): + for test_alias in (False, True): + m, a, b, m_indices, d, 
ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + func_name = f"{'a_fused_' if expand else ''}m_grouped_bf16_gemm_{major_opt.lower() if test_alias else 'nt'}_contiguous_tl" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + if expand: + m_row_indices = torch.arange(0, m, dtype=torch.int32, device='cuda') + getattr(deep_gemm.legacy, func_name)(a, b, d, (m_indices, m_row_indices)) + else: + getattr(deep_gemm.legacy, func_name)(a, b, d, m_indices) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.m_grouped_bf16_gemm_nt_contiguous_tl(a, b, d, m_indices) + + t = bench_kineto(test_func, 'm_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous_tl() -> None: + print('Testing k-grouped contiguous Triton GEMM:') + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for fused_operand in ('a', 'b'): + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + func_name = f"{fused_operand}_fused_k_grouped_bf16_gemm_{major_opt.lower()}_contiguous_tl" + k_indices = torch.arange(0, 
k, dtype=torch.int32, device='cuda') + k_start = torch.empty(len(ks), dtype=torch.int32, device='cuda') + k_end = torch.empty(len(ks), dtype=torch.int32, device='cuda') + for i, group_k in enumerate(ks): + k_start[i] = k_end[i-1] if i > 0 else 0 + k_end[i] = k_start[i] + group_k + getattr(deep_gemm.legacy, func_name)(a, b, c, (k_indices, k_start, k_end), True) + diff = calc_diff(c, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}' + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a, b, c, (k_indices, k_start, k_end), True) + + t = bench_kineto(test_func, 'b_fused_k_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_m_grouped_gemm_contiguous_tl() + test_k_grouped_gemm_contiguous_tl() diff --git a/tests/test_sanitizer.py b/tests/test_sanitizer.py new file mode 100644 index 00000000..b063e6c4 --- /dev/null +++ b/tests/test_sanitizer.py @@ -0,0 +1,78 @@ +import argparse +import importlib +import inspect +import os +import subprocess +import sys + +import deep_gemm + + +# Single test template +script_dir = os.path.dirname(os.path.abspath(__file__)) +test_template = """ +import random +import sys +import torch + +# Necessary for `generators.py` +sys.path.append('{script_dir}') + +torch.manual_seed(0) +random.seed(0) + +from tests.{module_name} import {func_name} +{func_name}() +""" + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--funcs', 
type=str, default='all') + parser.add_argument('--tools', type=str, default='memcheck,synccheck') + args = parser.parse_args() + + if args.funcs != 'all': + funcs = [] + for name in [x.strip() for x in args.funcs.split(',')]: + module_name, func_name = name.split('.') + funcs.append((module_name, func_name)) + else: + # Get all test functions except those related to cuBLAS + files = [f for f in os.listdir(script_dir) if f.endswith('.py')] + exclude_files = ['test_sanitizer.py', 'generators.py'] + funcs = [ + (module_name, name) + for module_name in [os.path.splitext(f)[0] for f in files if f not in exclude_files] + for name, obj in inspect.getmembers(importlib.import_module(module_name)) + if inspect.isfunction(obj) and name.startswith('test') and 'test_filter' not in name + ] + tools = [x.strip() for x in args.tools.split(',')] + + env = os.environ.copy() + env['CUDA_LAUNCH_BLOCKING'] = '1' + env['DG_JIT_PTXAS_CHECK'] = '1' + env['DG_USE_NVIDIA_TOOLS'] = '1' + env['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' + env['TORCH_SHOW_CPP_STACKTRACES'] = '1' + + print(f'Library path: {deep_gemm.__path__}') + for module_name, func_name in funcs: + for tool in tools: + cmd = [ + '/usr/local/cuda/bin/compute-sanitizer', + f'--tool={tool}', + '--target-processes=application-only', + '--destroy-on-device-error=context', + '--force-blocking-launches', + '--check-api-memory-access=no', + '--kernel-name-exclude', 'kns=nvjet', + 'python', + '-c', + test_template.format(module_name=module_name, func_name=func_name, script_dir=script_dir) + ] + print(f'\n{"=" * 60}') + print(f'Running {module_name}.{func_name} with compute-sanitizer {tool}') + result = subprocess.run(cmd, env=env) + if result.returncode != 0: + sys.exit(result.returncode) diff --git a/third-party/cutlass b/third-party/cutlass index a49a78ff..f3fde583 160000 --- a/third-party/cutlass +++ b/third-party/cutlass @@ -1 +1 @@ -Subproject commit a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 +Subproject commit 
f3fde58372d33e9a5650ba7b80fc48b3b49d40c8