diff --git a/.github/actions/locate-vcvarsall-and-setup-env/action.yml b/.github/actions/locate-vcvarsall-and-setup-env/action.yml index e174f384caa94..c4fdc48a7bd63 100644 --- a/.github/actions/locate-vcvarsall-and-setup-env/action.yml +++ b/.github/actions/locate-vcvarsall-and-setup-env/action.yml @@ -14,7 +14,7 @@ runs: steps: - name: Setup VCPKG - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' diff --git a/.github/actions/macos-ci-setup/action.yml b/.github/actions/macos-ci-setup/action.yml index e170ccf50a0ac..b3b95b855526f 100644 --- a/.github/actions/macos-ci-setup/action.yml +++ b/.github/actions/macos-ci-setup/action.yml @@ -62,15 +62,7 @@ runs: run: | XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ inputs.xcode_version }}.app/Contents/Developer" sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - name: Export GitHub Actions cache environment variables - if: ${{ inputs.use_cache }} - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - + - name: Install python dependencies shell: bash working-directory: ${{ github.workspace }} diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 8df0064e06a1d..b788bb792b23d 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -27,7 +27,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false @@ -37,7 +37,7 @@ jobs: ndk-version: 28.0.13004108 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -66,20 +66,13 @@ jobs: set_var("BuildConfigOs", config["os"]) shell: python working-directory: ${{ github.workspace }} - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - + - name: 1a. Build onnxruntime run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1306224" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1306224" + echo "Binary size threshold in bytes: 1436672" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then @@ -107,8 +100,6 @@ jobs: -e BUILD_ID=${{ github.run_id }} \ -e BUILD_REASON=${{ github.event_name }} \ -e BUILD_BRANCH=${{ github.ref }} \ - -e ACTIONS_CACHE_URL \ - -e ACTIONS_RUNTIME_TOKEN \ -e RUNNER_TEMP=/build \ ${{ steps.build_docker_image_step.outputs.full-image-name }} \ bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt && \ @@ -121,7 +112,7 @@ jobs: android_nnapi_ep: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v4 @@ -131,7 +122,7 @@ jobs: architecture: x64 - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -145,13 +136,6 @@ jobs: with: ndk-version: 28.0.13004108 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: NNAPI EP, Build, Test on Android Emulator run: >- python3 tools/ci_build/build.py @@ -203,7 +187,7 @@ jobs: name: Android CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Use jdk 17 uses: actions/setup-java@v4 @@ -217,13 +201,6 @@ jobs: with: ndk-version: 28.0.13004108 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: CPU EP, Build and Test run: >- python3 tools/ci_build/build.py diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index c875573eb3537..30f832f67c5ee 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -12,7 +12,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - name: Check out a copy of the repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index a0188b864d849..efe580c1b3b0c 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 91e42583d361f..0e5ea60f61402 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -15,7 +15,7 @@ jobs: name: "Validation" runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: gradle/actions/wrapper-validation@v4 concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 75f5d02fd3720..0d2046b980783 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -20,9 +20,17 @@ jobs: runs-on: macos-14 steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 + with: + vcpkg-version: '2025.06.13' + vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc + cmake-version: '3.31.8' + cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 + add-cmake-to-path: 'true' + disable-terrapin: 'true' - name: Use Xcode ${{ env.XCODE_VERSION }} shell: bash run: | @@ -30,13 +38,6 @@ jobs: XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.XCODE_VERSION }}.app/Contents/Developer" sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: (CPU, CoreML, XNNPACK EPs) Build onnxruntime for iOS x86_64 and run tests using simulator shell: bash run: | diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 16c9008f3675f..e4dc0c8bdbe26 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: name: Optional Lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: misspell # Check spellings as well uses: reviewdog/action-misspell@v1 with: @@ -42,9 +42,9 @@ jobs: contents: read security-events: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: # Use the version configured in target-version of [tool.black] section in pyproject.toml. python-version: "3.10" @@ -116,7 +116,7 @@ jobs: name: Lint JavaScript runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-node@v4 with: node-version: 20 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 40195ebdf37f2..c30a8cb023f50 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -41,7 +41,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: recursive @@ -51,12 +51,12 @@ jobs: node-version: "22" - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.12" architecture: ${{ env.buildArch }} - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -64,12 +64,6 @@ jobs: cmake-hash: '42395e20b10a8e9ef3e33014f9a4eed08d46ab952e02d2c1bbc8f6133eca0d7719fb75680f9bbff6552f20fcd1b73d86860f7f39388d631f98fb6f622b37cf04' add-cmake-to-path: 'true' disable-terrapin: 'true' - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Build (simd + threads) run: | diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index f4ee8a7c27cd0..9a9dace777c83 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -48,9 +48,9 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -93,7 +93,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 7532d363b19eb..92cdbb70e9858 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -29,21 +29,15 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 with: node-version: 20 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -53,7 +47,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-and-prep-ort-files@v0.0.9 - name: Upload Test Data Artifact uses: actions/upload-artifact@v4 @@ -72,7 +66,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 @@ -80,7 +74,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -90,15 +84,8 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Run Build 2 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -113,7 +100,7 @@ jobs: --enable_training_ops - name: Run Build 2 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -138,20 +125,14 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 with: node-version: 20 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -161,7 +142,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.9 with: reduced-ops-config-file: required_ops.ort_models.config enable-custom-ops: 'true' @@ -178,20 +159,14 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 with: node-version: 20 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -200,7 +175,7 @@ jobs: add-cmake-to-path: 'true' disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.9 with: reduced-ops-config-file: required_ops_and_types.ort_models.config enable-type-reduction: 'true' @@ -216,20 +191,14 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 with: node-version: 20 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.7 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 with: vcpkg-version: '2025.06.13' vcpkg-hash: '735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc' @@ -239,7 +208,7 @@ jobs: disable-terrapin: 'true' - name: Build Full ORT and Prepare Test Files - uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-minimal-ort-and-run-tests@v0.0.9 with: globally_allowed_types: 'bool,float,int8_t,uint8_t' enable-type-reduction: 'true' @@ -256,7 +225,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 @@ -264,7 +233,7 @@ jobs: node-version: 20 - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -274,15 +243,9 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Run Build 5 (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -295,7 +258,7 @@ jobs: --minimal_build extended - name: Run Build 5 (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -307,7 +270,7 @@ jobs: --use_binskim_compliant_compile_flags --minimal_build extended - name: Run Build 5 (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -329,12 +292,12 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -344,13 +307,6 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: gen config shell: bash run: | @@ -358,7 +314,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6a (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -374,7 +330,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6a (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -391,7 +347,7 @@ jobs: - name: Run Build 6a (Test) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -416,7 +372,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false @@ -427,7 +383,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -437,15 +393,8 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Run Build 6b (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -464,7 +413,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6b (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -492,7 +441,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false @@ -503,7 +452,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -513,12 +462,6 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: gen config shell: bash run: | @@ -526,7 +469,7 @@ jobs: touch ${{ runner.temp }}/.test_data/include_no_operators.config - name: Run Build 6c (Update) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -545,7 +488,7 @@ jobs: --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF - name: Run Build 6c (Build) - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} @@ -575,7 +518,7 @@ jobs: id-token: write # If using OIDC for ACR login steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - uses: actions/setup-node@v4 @@ -588,7 +531,7 @@ jobs: path: ${{ runner.temp }}/.test_data/ - name: Get Docker Image using Action - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -604,13 +547,6 @@ jobs: ndk-version: 28.0.13004108 # Use default android-sdk-root if not specified - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Run Build 7 (Using docker run) shell: bash run: | @@ -636,7 +572,7 @@ jobs: --volume $ANDROID_HOME:/android_home \ --volume $NDK_HOME_REALPATH:/ndk_home \ -e ALLOW_RELEASED_ONNX_OPSET_ONLY=1 \ - -e NIGHTLY_BUILD=1 -e ACTIONS_CACHE_URL -e ACTIONS_RUNTIME_TOKEN -e RUNNER_TEMP=/build \ + -e NIGHTLY_BUILD=1 -e RUNNER_TEMP=/build \ ${{ steps.build_docker_image_step.outputs.full-image-name }} \ bash -c "python3 -m pip install -r /onnxruntime_src/tools/ci_build/requirements/pybind/requirements.txt \ && python3 /onnxruntime_src/tools/ci_build/build.py \ diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index a7d3f5ec0f5fd..043eb0b218e2f 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -48,11 +48,11 @@ jobs: packages: read steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 # --- Build the Docker image needed for testing --- - name: Build Docker Image for Testing - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -95,7 +95,7 @@ jobs: # So build.py --build_dir build/Release inside the container correctly finds the artifacts. - name: Test ONNX Runtime id: test_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: Release diff --git a/.github/workflows/linux_webgpu.yml b/.github/workflows/linux_webgpu.yml index 08789489b12a3..f7161754895c5 100644 --- a/.github/workflows/linux_webgpu.yml +++ b/.github/workflows/linux_webgpu.yml @@ -51,7 +51,7 @@ jobs: # - name: Checkout code # uses: actions/checkout@v4 - # - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + # - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 # id: build_docker_image_step # with: # dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -91,7 +91,7 @@ jobs: # - name: Test ONNX Runtime # id: test_step - # uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + # uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 # with: # docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} # build_config: Release diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 9cc1604d71e68..af2b36c870201 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -53,7 +53,8 @@ jobs: runs-on: macos-15 env: - xcode_version: 16 + xcode_version: 16.4 + simulator_runtime_version: 18.5 strategy: matrix: @@ -63,7 +64,15 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 + with: + vcpkg-version: '2025.06.13' + vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc + cmake-version: '3.31.8' + cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 + add-cmake-to-path: 'true' + disable-terrapin: 'true' - name: macOS CI pipeline prepare steps uses: ./.github/actions/macos-ci-setup @@ -90,6 +99,8 @@ jobs: --apple_deploy_target=15.1 \ --apple_sysroot=iphonesimulator \ --osx_arch=${{ matrix.target_arch }} + env: + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: ${{ env.simulator_runtime_version }} Objective-C-StaticAnalysis: runs-on: macos-14 @@ -101,7 +112,15 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 + with: + vcpkg-version: '2025.06.13' + vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc + cmake-version: '3.31.8' + cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 + add-cmake-to-path: 'true' + disable-terrapin: 'true' - name: macOS CI pipeline prepare steps uses: ./.github/actions/macos-ci-setup diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index edb1764d44ec1..281538336b0c1 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -61,7 +61,15 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 + with: + vcpkg-version: '2025.06.13' + vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc + cmake-version: '3.31.8' + cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 + add-cmake-to-path: 'true' + disable-terrapin: 'true' - name: macOS CI pipeline prepare steps uses: ./.github/actions/macos-ci-setup diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index bb8a0638afea2..1d76a9ba413ed 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -24,9 +24,9 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Setup Rust diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index 6d3e593d8694e..fb4e92715a723 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate C/C++ API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install doxygen and dependencies run: | sudo apt update diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 7cca0969a168b..42d1bdc295785 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -24,7 +24,7 @@ jobs: env: DOCFXVERSION: 2.62.2 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install DocFX run: | dotnet tool update -g docfx diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index d04669a13aab7..c107535473786 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate Java docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up JDK 11 uses: actions/setup-java@v4 with: diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index a6749b42adc35..4bfda98945878 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -23,7 +23,7 @@ jobs: name: Generate JS API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Setup Node.js uses: actions/setup-node@v4 with: diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index deef64f73f15a..7f1611fdff315 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -23,7 +23,15 @@ jobs: name: Generate Objective-C API docs runs-on: macos-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 + - uses: microsoft/onnxruntime-github-actions/setup-build-tools@v0.0.9 + with: + vcpkg-version: '2025.06.13' + vcpkg-hash: 735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc + cmake-version: '3.31.8' + cmake-hash: 99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8 + add-cmake-to-path: 'true' + disable-terrapin: 'true' - name: Install Jazzy run: | diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index d03c9a407d54f..4baa8a0f5c272 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -24,7 +24,7 @@ jobs: name: Generate Python API docs runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install tools run: | sudo apt-get update diff --git a/.github/workflows/reusable_linux_build.yml b/.github/workflows/reusable_linux_build.yml index af24e3a3d901a..1a9c0e0a72031 100644 --- a/.github/workflows/reusable_linux_build.yml +++ b/.github/workflows/reusable_linux_build.yml @@ -75,15 +75,15 @@ jobs: id-token: write steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python ${{ inputs.python_version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ inputs.python_version }} - name: Build Docker Image (${{ inputs.architecture }} / ${{ inputs.build_config }}) - uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.7 + uses: microsoft/onnxruntime-github-actions/build-docker-image@v0.0.9 id: build_docker_image_step with: dockerfile: ${{ github.workspace }}/${{ inputs.dockerfile_path }} @@ -94,16 +94,10 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); # ------------- Update Step (CMake Generation) ------------- - name: Generate Build Files (CMake) (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: update_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -115,7 +109,7 @@ jobs: # ------------- Build Step (Compilation) ------------- - name: Build ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: build_step - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} @@ -128,7 +122,7 @@ jobs: - name: Test ONNX Runtime (${{ inputs.architecture }} / ${{ inputs.build_config }}) id: test_step if: inputs.run_tests == true - uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.7 + uses: microsoft/onnxruntime-github-actions/run-build-script-in-docker@v0.0.9 with: docker_image: ${{ steps.build_docker_image_step.outputs.full-image-name }} build_config: ${{ inputs.build_config }} diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 0133e4994e5e9..616c2c6db8a8d 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -22,7 +22,7 @@ jobs: commit_sha: ${{ steps.extract_commit.outputs.commit_sha }} steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: true diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index fcbef760d4626..0ea8b3ee33644 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false diff --git a/.github/workflows/windows_build_x64_asan.yml b/.github/workflows/windows_build_x64_asan.yml index 42ecf84369b6f..05fd4acd4de9a 100644 --- a/.github/workflows/windows_build_x64_asan.yml +++ b/.github/workflows/windows_build_x64_asan.yml @@ -19,12 +19,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -34,13 +34,6 @@ jobs: with: architecture: x64 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and Test (Combined) shell: cmd run: | diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 18ff55506d401..0b1bf59733349 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -21,12 +21,12 @@ jobs: name: Windows GPU CUDA CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -152,7 +152,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -163,7 +163,7 @@ jobs: name: build-artifacts path: ${{ runner.temp }}\build - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index e9bccab6fae66..639f57d2c6a48 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -27,12 +27,12 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 # Fetch all history for all tags and branches submodules: 'none' - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -82,13 +82,6 @@ jobs: run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ github.workspace }}\RelWithDebInfo shell: cmd - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Set OnnxRuntimeBuildDirectory shell: pwsh run: | diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 96289e65502d9..395ccfbe70244 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -31,13 +31,13 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none - name: Setup Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 #Keep x64, because the original pipeline is for x64 @@ -47,13 +47,6 @@ jobs: with: architecture: x64 - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Download OpenVINO Toolkit v2025.2.0 env: OpenVINOVersion: 2025.2.0 diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml new file mode 100644 index 0000000000000..4c08d543cefd9 --- /dev/null +++ b/.github/workflows/windows_qnn_x64.yml @@ -0,0 +1,82 @@ +name: Windows x64 QNN CI Pipeline + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} + cancel-in-progress: true + +jobs: + build_test_qnn_ep: + name: Windows x64 QNN CI Pipeline (${{ matrix.QnnLibKind }}) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + timeout-minutes: 120 + strategy: + matrix: + QnnLibKind: [shared_lib, static_lib] + env: + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Download QNN SDK + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/qnnsdk/qnn-v2.37.1.250807 . + dir + shell: pwsh + + - name: Set QNN_SDK_ROOT environment variable + shell: pwsh + run: | + $qnn_sdk_path = Join-Path $env:RUNNER_TEMP "qnn-v2.37.1.250807" + echo "QNN_SDK_ROOT=$qnn_sdk_path" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "QNN SDK Root: $qnn_sdk_path" + dir $qnn_sdk_path + + - name: Build and Test + shell: cmd + run: | + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --build_dir ${{ runner.temp }}\build --cmake_generator "Visual Studio 17 2022" --build_java --build_shared_lib --use_qnn ${{ matrix.QnnLibKind }} --qnn_home %QNN_SDK_ROOT% --use_binskim_compliant_compile_flags --update --build --test --enable_onnx_tests --parallel + + - name: Run ONNX Tests + shell: cmd + working-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + run: | + .\onnx_test_runner -j 1 -e qnn -i "backend_path|%QNN_SDK_ROOT%\lib\x86_64-windows-msvc\QnnCpu.dll" ${{ github.workspace }}\cmake\external\onnx\onnx\backend\test\data\node + + - name: Run float32 model tests + shell: cmd + working-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + run: | + rem This step assumes the model data exists at C:\data\float32_models on the runner + if exist C:\data\float32_models ( + .\onnx_test_runner -j 1 -e qnn -i "backend_path|%QNN_SDK_ROOT%\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models + ) else ( + echo "Skipping float32 model tests: C:\data\float32_models not found." + ) diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index dbc138e57a3ec..de6fa1529bcb1 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -21,12 +21,12 @@ jobs: name: Windows GPU TensorRT CI Pipeline runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -157,7 +157,7 @@ jobs: timeout-minutes: 300 runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 submodules: 'none' @@ -168,7 +168,7 @@ jobs: name: build-artifacts path: ${{ runner.temp }}\build - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 996e0d816d51a..e1a8c28f5a1ad 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -34,13 +34,13 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none - name: Setup Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.12" architecture: x64 @@ -101,13 +101,6 @@ jobs: path: ${{ github.workspace }}/js/test/ key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and Test shell: pwsh run: | @@ -162,13 +155,13 @@ jobs: timeout-minutes: 300 steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none - name: Setup Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.12" architecture: x64 @@ -183,13 +176,6 @@ jobs: shell: cmd working-directory: ${{ github.workspace }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Generate onnxruntime.sln shell: pwsh run: | @@ -222,13 +208,13 @@ jobs: ONNXRUNTIME_TEST_GPU_DEVICE_ID: "0" steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 submodules: none - name: Setup Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.12" architecture: x64 @@ -283,13 +269,6 @@ jobs: shell: cmd working-directory: ${{ github.workspace }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build shell: pwsh run: | diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index f4c865efe52f1..187633e5f3548 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -82,13 +82,6 @@ jobs: path: ${{ github.workspace }}/js/test/ key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and Test shell: pwsh run: | diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index cf4e725d9495e..03ac8386c780b 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -82,13 +82,6 @@ jobs: path: ${{ github.workspace }}/js/test/ key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and Test shell: pwsh run: | diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index 76a6203c4dc76..edfb83ac1beaa 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -76,13 +76,6 @@ jobs: run: | nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and test shell: pwsh run: | diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index f95706764d345..30b64006bbdd1 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -76,13 +76,6 @@ jobs: run: | nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build shell: pwsh run: | diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index e4ee10b691984..49179ff499798 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x64 @@ -74,14 +74,7 @@ jobs: - name: NuGet restore shell: cmd run: | - nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - name: Build and Test shell: pwsh diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index 4652757c1d292..0528141d965fa 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -18,12 +18,12 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: submodules: false - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' architecture: x86 # x86 Python @@ -77,13 +77,6 @@ jobs: run: | nuget restore ${{ github.workspace }}\packages.config -PackagesDirectory ${{ github.workspace }}\build\RelWithDebInfo -ConfigFile ${{ github.workspace }}\NuGet.config - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Build and Test shell: pwsh run: | diff --git a/.gitignore b/.gitignore index 4d0a1205b7c19..b25763334f227 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # build, distribute, and bins (+ python proto bindings) +build.*/ build build_*/ .build_debug/* diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b0941b4d0c922..baf21745e0e40 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -29,6 +29,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 @@ -97,7 +98,8 @@ option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF) -option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) option(onnxruntime_USE_AVX "Use AVX instructions" OFF) @@ -693,6 +695,7 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_LEAN_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + set(onnxruntime_USE_FPA_INTB_GEMM OFF) endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) @@ -705,6 +708,11 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) + message( STATUS "FpA IntB Gemm unsupported for CUDA compiler version < 12.0") + set(onnxruntime_USE_FPA_INTB_GEMM OFF) + endif() + if (WIN32) message( STATUS "Lean Attention unsupported in Windows") set(onnxruntime_USE_LEAN_ATTENTION OFF) @@ -733,6 +741,11 @@ if (onnxruntime_USE_CUDA) message( STATUS "Enable memory efficient attention for CUDA EP") list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) endif() + + if (onnxruntime_USE_FPA_INTB_GEMM) + message( STATUS "Enable FpA IntB Gemm for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_FPA_INTB_GEMM=1) + endif() endif() if (onnxruntime_USE_CUDA_INTERFACE AND (NOT onnxruntime_USE_CUDA)) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index f76ad642447ba..3189b64349898 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -336,7 +336,13 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - set(CPUINFO_SUPPORTED TRUE) + # There's an error when linking with cpuinfo on arm64ec with a vcpkg build (--use_vcpkg). + # TODO Fix it and then re-enable cpuinfo on arm64ec. + if (onnxruntime_target_platform STREQUAL "ARM64EC") + set(CPUINFO_SUPPORTED FALSE) + else() + set(CPUINFO_SUPPORTED TRUE) + endif() elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " @@ -597,10 +603,6 @@ if(NOT (onnx_FOUND OR ONNX_FOUND)) # building ONNX from source endif() endif() -if (onnxruntime_RUN_ONNX_TESTS) - add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS) -endif() - if(onnxruntime_ENABLE_DLPACK) message(STATUS "dlpack is enabled.") diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 5dcc2b2628bf4..d927489372e7c 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -14,7 +14,7 @@ set(onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/check_intel.h" "${ONNXRUNTIME_ROOT}/core/platform/check_intel.cc" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.h" - "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.cc" + "${ONNXRUNTIME_ROOT}/core/platform/device_discovery_common.cc" "${ONNXRUNTIME_ROOT}/core/platform/env.h" "${ONNXRUNTIME_ROOT}/core/platform/env.cc" "${ONNXRUNTIME_ROOT}/core/platform/env_time.h" @@ -32,18 +32,30 @@ set(onnxruntime_common_src_patterns if(WIN32) list(APPEND onnxruntime_common_src_patterns - "${ONNXRUNTIME_ROOT}/core/platform/windows/*.h" - "${ONNXRUNTIME_ROOT}/core/platform/windows/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/debug_alloc.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/debug_alloc.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/dll_load_error.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/dll_load_error.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env_time.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/hardware_core_enumerator.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/hardware_core_enumerator.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/stacktrace.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/telemetry.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/telemetry.h" "${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.h" "${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.cc" ) else() list(APPEND onnxruntime_common_src_patterns - "${ONNXRUNTIME_ROOT}/core/platform/posix/*.h" - "${ONNXRUNTIME_ROOT}/core/platform/posix/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/env_time.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/env.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/stacktrace.cc" ) + # logging files if (onnxruntime_USE_SYSLOG) list(APPEND onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/posix/logging/*.h" @@ -51,7 +63,7 @@ else() ) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") + if (ANDROID) list(APPEND onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/android/logging/*.h" "${ONNXRUNTIME_ROOT}/core/platform/android/logging/*.cc" @@ -66,6 +78,21 @@ else() endif() endif() +# platform-specific device discovery files +if (WIN32) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/windows/device_discovery.cc") +elseif (LINUX) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/linux/device_discovery.cc") +elseif (APPLE) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/apple/device_discovery.cc") +else() + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/device_discovery_default.cc") +endif() + if(onnxruntime_target_platform STREQUAL "ARM64EC") if (MSVC) link_directories("$ENV{VCINSTALLDIR}/Tools/MSVC/$ENV{VCToolsVersion}/lib/ARM64EC") @@ -216,8 +243,6 @@ endif() if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) # Link cpuinfo if supported - # Using it mainly in ARM with Android. - # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. if (CPUINFO_SUPPORTED) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 24cecf07e8e36..3530ab03c822a 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -108,6 +108,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) set(mlas_platform_preprocess_srcs @@ -429,12 +430,16 @@ else() ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() endif() set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} @@ -785,12 +790,6 @@ if (WIN32) endif() endif() -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - if (NOT onnxruntime_BUILD_SHARED_LIB) install(TARGETS onnxruntime_mlas EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 495ff093326ad..8cb5dcf95155a 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -42,14 +30,14 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx) if(MSVC) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) @@ -62,6 +50,15 @@ target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs) endif() + set(CMAKE_REQUIRED_LIBRARIES migraphx::c) + + check_symbol_exists(migraphx_onnx_options_set_external_data_path + "migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + + if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1) + endif() + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training) target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training) @@ -71,15 +68,39 @@ endif() if(CMAKE_SYSTEM_NAME STREQUAL "Windows") - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - else() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) + foreach(file migraphx-hiprtc-driver.exe migraphx.dll migraphx_c.dll migraphx_cpu.dll migraphx_device.dll migraphx_gpu.dll migraphx_onnx.dll migraphx_tf.dll) + set(_source "${AMD_MIGRAPHX_HOME}/bin/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _migraphx_targets ${_target}) + endif() + endforeach() + set(MIGRAPHX_LIB_FILES ${_migraphx_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_migraphx_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) + get_property(_amdhip64_location TARGET hip::amdhip64 PROPERTY IMPORTED_LOCATION_RELEASE) + cmake_path(GET _amdhip64_location PARENT_PATH _hipsdk_path) + foreach(file amd_comgr0602.dll amd_comgr0604.dll amd_comgr0700.dll hiprtc0602.dll hiprtc0604.dll hiprtc0700.dll hiprtc-builtins0602.dll hiprtc-builtins0604.dll hiprtc-builtins0700.dll) + set(_source "${_hipsdk_path}/${file}") + if(EXISTS "${_source}") + add_custom_command(TARGET onnxruntime_providers_migraphx + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${_source} $) + set(_target "$/${file}") + list(APPEND _hipsdk_targets ${_target}) + endif() + endforeach() + set(HIPSDK_LIB_FILES ${_hipsdk_targets} CACHE INTERNAL "" FORCE) + install(FILES ${_hipsdk_targets} + DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + + install(TARGETS onnxruntime_providers_migraphx + EXPORT onnxruntime_providers_migraphxTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c5c85dff96411..ae976abe62fd8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -740,6 +740,21 @@ if (onnxruntime_USE_OPENVINO) ) endif() +if (onnxruntime_USE_MIGRAPHX) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${MIGRAPHX_LIB_FILES} + $/onnxruntime/capi/) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${HIPSDK_LIB_FILES} + $/onnxruntime/capi/) + endif() +endif() + if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3ec3c6ee1d5ae..f81a7a9726b76 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -5,6 +5,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_INCLUDE_DIR}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.cc" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.h" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.cc" ) if (onnxruntime_ENABLE_TRAINING_APIS) @@ -22,7 +24,7 @@ endif() # which is not enabled for any minimal builds. if (onnxruntime_MINIMAL_BUILD) file(GLOB autoep_srcs - "${ONNXRUNTIME_ROOT}/core/session/ep_*.*" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.*" ) set(onnxruntime_session_src_exclude diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 96e513c8a7bc9..6aad71e40b2a8 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -607,7 +607,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_COREML) @@ -688,9 +687,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) @@ -1230,6 +1226,12 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${onnxruntime_perf_test_src_patterns} ) onnxruntime_add_executable(onnxruntime_perf_test ${onnxruntime_perf_test_src} ${ONNXRUNTIME_ROOT}/core/platform/path_lib.cc) + + # ABSL_FLAGS_STRIP_NAMES is set to 1 by default to disable flag registration when building for Android, iPhone, and "embedded devices". + # See the issue: https://github.com/abseil/abseil-cpp/issues/1875 + # We set it to 0 for all builds to be able to use ABSL flags for onnxruntime_perf_test. + target_compile_definitions(onnxruntime_perf_test PRIVATE ABSL_FLAGS_STRIP_NAMES=0) + if(MSVC) target_compile_options(onnxruntime_perf_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1252,7 +1254,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_EXTERNAL_LIBRARIES} - ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) + absl::flags absl::flags_parse ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) if(NOT WIN32) if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) @@ -1272,7 +1274,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32) endif() else() - target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs}) + target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs}) endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch b/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch new file mode 100644 index 0000000000000..23ceeb8f758cc --- /dev/null +++ b/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch @@ -0,0 +1,22 @@ +diff --git a/include/cpuinfo.h b/include/cpuinfo.h +index f1d35d4..9e454d2 100644 +--- a/include/cpuinfo.h ++++ b/include/cpuinfo.h +@@ -18,7 +18,7 @@ + #define CPUINFO_ARCH_X86 1 + #endif + +-#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) ++#if defined(__x86_64__) || defined(__x86_64) || (defined(_M_X64) && !defined(_M_ARM64EC)) || (defined(_M_AMD64) && !defined(_M_ARM64EC)) + #define CPUINFO_ARCH_X86_64 1 + #endif + +@@ -26,7 +26,7 @@ + #define CPUINFO_ARCH_ARM 1 + #endif + +-#if defined(__aarch64__) || defined(_M_ARM64) ++#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) + #define CPUINFO_ARCH_ARM64 1 + #endif + diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index e61308bf643b4..917fd29a8d28b 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -9,6 +9,8 @@ vcpkg_from_github( REF 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 SHA512 b94ccbfa886221d6bb16513d074675af0a72928a9dd9485dcacdc1124a8a60aacbbe91913a1579e766dfb024f0be1d52eeead40342004ff0238a8b94a095ed08 HEAD_REF master + PATCHES + patch_cpuinfo_h_for_arm64ec.patch ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs index c348184658e7e..bde39d9c6e6cc 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -4,6 +4,7 @@ namespace Microsoft.ML.OnnxRuntime { using System; + using System.Diagnostics; using System.Runtime.InteropServices; /// @@ -22,18 +23,19 @@ public enum OrtCompileApiFlags : uint /// This class is used to set options for model compilation, and to produce a compiled model using those options. /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. /// - public class OrtModelCompilationOptions : SafeHandle + public class OrtModelCompilationOptions : IDisposable { /// /// Create a new OrtModelCompilationOptions object from SessionOptions. /// + /// By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel() + /// to enable graph optimizations. /// SessionOptions instance to read settings from. public OrtModelCompilationOptions(SessionOptions sessionOptions) - : base(IntPtr.Zero, true) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions( - OrtEnv.Instance().Handle, sessionOptions.Handle, out handle)); + OrtEnv.Instance().Handle, sessionOptions.Handle, out _handle)); } /// @@ -41,7 +43,7 @@ public OrtModelCompilationOptions(SessionOptions sessionOptions) /// public void CompileModel() { - NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle)); + NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, _handle)); } @@ -53,7 +55,7 @@ public void SetInputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(_handle, platformPath)); } /// @@ -65,7 +67,7 @@ public void SetInputModelFromBuffer(byte[] buffer) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer( - handle, buffer, (UIntPtr)buffer.Length)); + _handle, buffer, (UIntPtr)buffer.Length)); } /// @@ -76,7 +78,7 @@ public void SetOutputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(_handle, platformPath)); } @@ -91,7 +93,7 @@ public void SetOutputModelExternalInitializersFile(string filePath, ulong thresh var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile( - handle, platformPath, new UIntPtr(threshold))); + _handle, platformPath, new UIntPtr(threshold))); } // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure. @@ -106,7 +108,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer( - handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); + _handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); } /// @@ -117,7 +119,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, public void SetEpContextEmbedMode(bool embed) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(_handle, embed)); } /// @@ -127,26 +129,379 @@ public void SetEpContextEmbedMode(bool embed) public void SetFlags(OrtCompileApiFlags flags) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(_handle, (uint)flags)); } - internal IntPtr Handle => handle; + /// + /// Sets information related to EP context binary file. The Ep uses this information to decide the + /// location and context binary file name when compiling with both the input and output models + /// stored in buffers. + /// + /// Path to the model directory. + /// The name of the model. + public void SetEpContextBinaryInformation(string outputDirectory, string modelName) + { + var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory); + var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation( + _handle, platformOutputDirectory, platformModelName)); + } + + /// + /// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified. + /// + /// The graph optimization level to set. + public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel( + _handle, graphOptimizationLevel)); + } + + /// + /// Delegate to write/save a buffer containing ONNX model bytes to a custom destination. The delegate + /// may be called repeatedly until the entire output model has been written out. Each call to the delegate + /// is expected to consume the entire buffer. + /// + /// The buffer to write out. + /// + public delegate void WriteBufferToDestinationDelegate(ReadOnlySpan buffer); + + /// + /// Sets a delegate that is called by ORT to write out the output model's serialized ONNX bytes. + /// The provided delegate may be called repeatedly until the entire output model has been written out. + /// Each call to the delegate is expected to consume/handle the entire input buffer. + /// + /// The delegate called by ORT to write out the model. + public void SetOutputModelWriteDelegate(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _writeBufferToDestinationDelegateState = + new DelegateResources( + new WriteBufferToDestinationConnector(writeBufferDelegate), + new NativeMethods.DOrtWriteBufferToDestinationDelegate( + WriteBufferToDestinationConnector.WriteBufferToDestinationDelegateWrapper)); + + IntPtr funcPtr = _writeBufferToDestinationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelWriteFunc( + _handle, + funcPtr, + _writeBufferToDestinationDelegateState.GetConnectorHandleAsPointer())); + } + + /// + /// Delegate called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The initializer's name. + /// The readonly OrtValue instance containing the data, type, and + /// shape of the initializer. + /// May be null. If the initializer is originally stored externally, + /// this contains the file path, file offset, and data size. Otherwise, this is null. + /// A new OrtExternalInitializerInfo indicating the new location of the initializer. + /// Returns null if the initializer should be stored within the generated compiled model. + /// The return value may be null. + /// + public delegate OrtExternalInitializerInfo GetInitializerLocationDelegate( + string initializerName, + IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation); + + /// + /// Sets a delegate that is called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The delegate called by ORT for every initializer. + public void SetOutputModelGetInitializerLocationDelegate( + GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _getInitializerLocationDelegateState?.Dispose(); + _getInitializerLocationDelegateState = + new DelegateResources( + new GetInitializerLocationConnector(getInitializerLocationDelegate), + new NativeMethods.DOrtGetInitializerLocationDelegate( + GetInitializerLocationConnector.GetInitializerLocationDelegateWrapper)); + + IntPtr funcPtr = _getInitializerLocationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + _handle, + funcPtr, + _getInitializerLocationDelegateState.GetConnectorHandleAsPointer())); + } + + #region Delegate helpers + /// + /// Class to bridge the C# and native worlds for the "write buffer to destination" delegate + /// + private class WriteBufferToDestinationConnector + { + private readonly WriteBufferToDestinationDelegate _userDelegate; + + internal WriteBufferToDestinationConnector(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _userDelegate = writeBufferDelegate; + } + + public static IntPtr WriteBufferToDestinationDelegateWrapper(IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes) + { + try + { + + WriteBufferToDestinationConnector connector = (WriteBufferToDestinationConnector) + GCHandle.FromIntPtr(state).Target; + ReadOnlySpan bufferSpan; + + unsafe + { + // NOTE: A Span can only view 2GB of data. This is fine because ORT does not write out + // chunks that large. However, if we ever need to, the solution is to just write a loop here + // that repeatedly calls the delegate with smaller chunks of data. + bufferSpan = new ReadOnlySpan(buffer.ToPointer(), checked((int)bufferNumBytes)); + } + + connector._userDelegate(bufferSpan); + } + catch (Exception ex) + { + var error = $"The C# WriteBufferToDestination delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } + + /// + /// Class to bridge the C# and native worlds for the "get initializer location" delegate + /// + private class GetInitializerLocationConnector + { + private readonly GetInitializerLocationDelegate _userDelegate; + + internal GetInitializerLocationConnector(GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _userDelegate = getInitializerLocationDelegate; + } + + public static IntPtr GetInitializerLocationDelegateWrapper( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ originalInitializerLocation, + out IntPtr /* OrtExternalInitializerInfo** */ newInitializerLocationOutput) + { + newInitializerLocationOutput = IntPtr.Zero; + + try + { + + GetInitializerLocationConnector connector = (GetInitializerLocationConnector)GCHandle. + FromIntPtr(state).Target; + string utf8InitializerName = NativeOnnxValueHelper.StringFromNativeUtf8(initializerName); + IReadOnlyOrtValue readOnlyInitializerValue = new OrtValue(initializerValue, owned: false); + IReadOnlyExternalInitializerInfo readOnlyOriginalInitializerLocation = null; + + if (originalInitializerLocation != IntPtr.Zero) + { + readOnlyOriginalInitializerLocation = new OrtExternalInitializerInfo( + originalInitializerLocation, ownsHandle: false); + } + // Call user's delegate, which may return the new location of the initializer. + OrtExternalInitializerInfo newInitializerLocation = connector._userDelegate( + utf8InitializerName, readOnlyInitializerValue, readOnlyOriginalInitializerLocation); + + if (newInitializerLocation != null) + { + // Delegate returned info about a new location for the initializer. + // Can't guarantee that the new external info returned by user's delegate is not referenced + // by other C# code. ORT expects to own the new external info, so create a copy here and + // give it to ORT. + string newFilePath = newInitializerLocation.GetFilePath(); + byte[] newFilePathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(newFilePath); + + IntPtr status = NativeMethods.OrtCreateExternalInitializerInfo( + newFilePathBytes, + newInitializerLocation.GetFileOffset(), + (UIntPtr)newInitializerLocation.GetByteSize(), + out newInitializerLocationOutput); + + if (status != IntPtr.Zero) + { + return status; + } + } + else + { + // User's delegate did not return a new location for the initializer. ORT will store initializer + // within the generated compiled model. + newInitializerLocationOutput = IntPtr.Zero; + } + } + catch (Exception ex) + { + var error = $"The C# GetInitializerLocation delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } /// - /// Indicates whether the native handle is invalid. + /// Disposable class that stores resources for a delegate provided by the user. /// - public override bool IsInvalid => handle == IntPtr.Zero; + /// The type of the connector class + /// (e.g., WriteBufferToDestinationConnector) + /// The type of the native delegate. + private class DelegateResources : IDisposable + where Connector : class + where Delegate : class + { + public DelegateResources(Connector connector, Delegate @delegate) + { + _connector = connector; + _delegate = @delegate; + _connectorHandle = GCHandle.Alloc(_connector); + _delegateHandle = GCHandle.Alloc(_delegate); + } + internal IntPtr GetFunctionPointerForDelegate() + { + return Marshal.GetFunctionPointerForDelegate(_delegate); + } + + internal IntPtr GetConnectorHandleAsPointer() + { + return GCHandle.ToIntPtr(_connectorHandle); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + // Dispose other children disposables. We have none. + } + + if (_connectorHandle.IsAllocated) + { + _connectorHandle.Free(); + _connector = null; + } + + if (_delegateHandle.IsAllocated) + { + _delegateHandle.Free(); + _delegate = null; + } + + _disposed = true; + } + + ~DelegateResources() + { + Dispose(false); + } + + private Connector _connector = null; + private Delegate _delegate = null; + private GCHandle _connectorHandle = default; + private GCHandle _delegateHandle = default; + private bool _disposed = false; + } + #endregion + + #region IDispose implementation /// - /// Release the native instance of OrtModelCompilationOptions. + /// IDispose implementation. /// - /// true - protected override bool ReleaseHandle() + public void Dispose() { - NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle); - handle = IntPtr.Zero; - return true; + Dispose(true); + GC.SuppressFinalize(this); } + + /// + /// IDispose implementation + /// + /// True if Dispose() has been called by the user-side code. False if + /// called by the runtime from inside the finalizer. + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _getInitializerLocationDelegateState?.Dispose(); + } + + Debug.Assert(_handle != IntPtr.Zero); + NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(_handle); + _handle = IntPtr.Zero; + _disposed = true; + } + + /// + /// Finalizer that releases the native handle if not already released by Dispose(). + /// + ~OrtModelCompilationOptions() + { + Dispose(false); + } + #endregion + + /// + /// Handle to the native OrtModelCompilationOptions object. + /// + private IntPtr _handle; + + /// + /// True if this OrtModelCompilationOptions instance has already been disposed. + /// + private bool _disposed = false; + + /// + /// Stores delegate state for the "write buffer to destination" delegate. + /// + private DelegateResources + _writeBufferToDestinationDelegateState = null; + + /// + /// Stores delegate state for the "get initializer location" delegate. + /// + private DelegateResources + _getInitializerLocationDelegateState = null; } -} \ No newline at end of file +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs index 098a18b7444cf..2467475b6b189 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs @@ -23,8 +23,8 @@ internal enum ErrorCode ModelLoaded = 8, NotImplemented = 9, InvalidGraph = 10, - ShapeInferenceNotRegistered = 11, - RequirementNotRegistered = 12, + ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL + RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED } /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 3edc25b307a21..84020d84c9e73 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -21,6 +21,10 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; public IntPtr CompileModel; public IntPtr ModelCompilationOptions_SetFlags; + public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation; + public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel; + public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc; + public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; } internal class NativeMethods @@ -101,6 +105,37 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile uint flags); public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputDirectory, + byte[] /* const ORTCHAR_T* */ modelName); + public DOrtModelCompilationOptions_SetEpContextBinaryInformation + OrtModelCompilationOptions_SetEpContextBinaryInformation; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel( + IntPtr /* OrtModelCompilationOptions* */ options, + GraphOptimizationLevel graphOptimizationLevel); + public DOrtModelCompilationOptions_SetGraphOptimizationLevel + OrtModelCompilationOptions_SetGraphOptimizationLevel; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelWriteFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtWriteBufferDelegate */ writeFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelWriteFunc + OrtModelCompilationOptions_SetOutputModelWriteFunc; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtHandleInitializerDataDelegate */ handleInitializerFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { @@ -161,6 +196,27 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi _compileApi.ModelCompilationOptions_SetFlags, typeof(DOrtModelCompilationOptions_SetFlags)); + OrtModelCompilationOptions_SetEpContextBinaryInformation = + (DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextBinaryInformation, + typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation)); + + OrtModelCompilationOptions_SetGraphOptimizationLevel = + (DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel, + typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel)); + + OrtModelCompilationOptions_SetOutputModelWriteFunc = + (DOrtModelCompilationOptions_SetOutputModelWriteFunc)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelWriteFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelWriteFunc)); + + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc = + (DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)Marshal. + GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)); + } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 8cca2b42e987a..53880308da261 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -368,6 +368,89 @@ public struct OrtApi public IntPtr EpDevice_Device; public IntPtr GetEpApi; public IntPtr GetTensorSizeInBytes; + + public IntPtr AllocatorGetStats; + + public IntPtr CreateMemoryInfo_V2; + public IntPtr MemoryInfoGetDeviceMemType; + public IntPtr MemoryInfoGetVendorId; + + public IntPtr ValueInfo_GetValueProducer; + public IntPtr ValueInfo_GetValueNumConsumers; + public IntPtr ValueInfo_GetValueConsumers; + public IntPtr ValueInfo_GetInitializerValue; + public IntPtr ValueInfo_GetExternalInitializerInfo; + public IntPtr ValueInfo_IsRequiredGraphInput; + public IntPtr ValueInfo_IsOptionalGraphInput; + public IntPtr ValueInfo_IsGraphOutput; + public IntPtr ValueInfo_IsConstantInitializer; + public IntPtr ValueInfo_IsFromOuterScope; + public IntPtr Graph_GetName; + public IntPtr Graph_GetModelPath; + public IntPtr Graph_GetOnnxIRVersion; + public IntPtr Graph_GetNumOperatorSets; + public IntPtr Graph_GetOperatorSets; + public IntPtr Graph_GetNumInputs; + public IntPtr Graph_GetInputs; + public IntPtr Graph_GetNumOutputs; + public IntPtr Graph_GetOutputs; + public IntPtr Graph_GetNumInitializers; + public IntPtr Graph_GetInitializers; + public IntPtr Graph_GetNumNodes; + public IntPtr Graph_GetNodes; + public IntPtr Graph_GetParentNode; + public IntPtr Graph_GetGraphView; + public IntPtr Node_GetId; + public IntPtr Node_GetName; + public IntPtr Node_GetOperatorType; + public IntPtr Node_GetDomain; + public IntPtr Node_GetSinceVersion; + public IntPtr Node_GetNumInputs; + public IntPtr Node_GetInputs; + public IntPtr Node_GetNumOutputs; + public IntPtr Node_GetOutputs; + public IntPtr Node_GetNumImplicitInputs; + public IntPtr Node_GetImplicitInputs; + public IntPtr Node_GetNumAttributes; + public IntPtr Node_GetAttributes; + public IntPtr Node_GetAttributeByName; + public IntPtr Node_GetTensorAttributeAsOrtValue; + public IntPtr OpAttr_GetType; + public IntPtr OpAttr_GetName; + public IntPtr Node_GetNumSubgraphs; + public IntPtr Node_GetSubgraphs; + public IntPtr Node_GetGraph; + public IntPtr Node_GetEpName; + public IntPtr ReleaseExternalInitializerInfo; + public IntPtr ExternalInitializerInfo_GetFilePath; + public IntPtr ExternalInitializerInfo_GetFileOffset; + public IntPtr ExternalInitializerInfo_GetByteSize; + + public IntPtr GetRunConfigEntry; + + public IntPtr EpDevice_MemoryInfo; + + public IntPtr CreateSharedAllocator; + public IntPtr GetSharedAllocator; + public IntPtr ReleaseSharedAllocator; + + public IntPtr GetTensorData; + + public IntPtr GetSessionOptionsConfigEntries; + + public IntPtr SessionGetMemoryInfoForInputs; + public IntPtr SessionGetMemoryInfoForOutputs; + public IntPtr SessionGetEpDeviceForInputs; + + public IntPtr CreateSyncStreamForEpDevice; + public IntPtr SyncStream_GetHandle; + public IntPtr ReleaseSyncStream; + + public IntPtr CopyTensors; + + public IntPtr Graph_GetModelMetadata; + public IntPtr GetModelCompatibilityForEpDevices; + public IntPtr CreateExternalInitializerInfo; } internal static class NativeMethods @@ -704,6 +787,36 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); + + OrtReleaseExternalInitializerInfo = + (DOrtReleaseExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseExternalInitializerInfo, + typeof(DOrtReleaseExternalInitializerInfo)); + + OrtExternalInitializerInfo_GetFilePath = + (DOrtExternalInitializerInfo_GetFilePath)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFilePath, + typeof(DOrtExternalInitializerInfo_GetFilePath)); + + OrtExternalInitializerInfo_GetFileOffset = + (DOrtExternalInitializerInfo_GetFileOffset)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFileOffset, + typeof(DOrtExternalInitializerInfo_GetFileOffset)); + + OrtExternalInitializerInfo_GetByteSize = + (DOrtExternalInitializerInfo_GetByteSize)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetByteSize, + typeof(DOrtExternalInitializerInfo_GetByteSize)); + + OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetModelCompatibilityForEpDevices, + typeof(DOrtGetModelCompatibilityForEpDevices)); + + OrtCreateExternalInitializerInfo = + (DOrtCreateExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.CreateExternalInitializerInfo, + typeof(DOrtCreateExternalInitializerInfo)); + } internal class NativeLib @@ -2296,6 +2409,70 @@ out IntPtr lora_adapter public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); #endif public static DOrtGetCompileApi OrtGetCompileApi; + + /// + /// Delegate called by ORT to write a buffer (ONNX model bytes) to a custom destination (e.g., file or stream). + /// + /// State that was provided in when the delegate was registered. + /// The buffer to write. + /// The size of the buffer in bytes. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtWriteBufferToDestinationDelegate( + IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes + ); + + /// + /// Function called by ORT to allow user to specify how an initializer should be saved while compiling + /// a model, that is, either written to an external file or stored within the model. ORT calls this function + /// for every initializer. + /// + /// State that was provided when the delegate was registered. + /// The initializer's name. + /// The OrtValue containing the initializer's data, type, and shape + /// The original initializer's location in an external file, or NULL. + /// Output parameter set to a new OrtExternalInitializerInfo instance + /// indicating the location where the function implementation stored the initializer data. If the function + /// implementation sets `newExternalInfo` to NULL, ORT stores the initializer within the generated model. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetInitializerLocationDelegate( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ externalInfo, + out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseExternalInitializerInfo(IntPtr /* OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateExternalInitializerInfo( + byte[] /* const ORTCHAR_T* */ filePath, + long /* int64_t */ fileOffset, + UIntPtr /* size_t */ byteSize, + out IntPtr /* OrtExternalInitializerInfo** */ outInfo); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const ORTCHAR_T* */ DOrtExternalInitializerInfo_GetFilePath( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate long /* int64_t */ DOrtExternalInitializerInfo_GetFileOffset( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate UIntPtr /* size_t */ DOrtExternalInitializerInfo_GetByteSize( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + public static DOrtReleaseExternalInitializerInfo OrtReleaseExternalInitializerInfo; + public static DOrtCreateExternalInitializerInfo OrtCreateExternalInitializerInfo; + public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath; + public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset; + public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize; #endregion #region Auto EP API related @@ -2456,6 +2633,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; + /// + /// Validate compiled model compatibility for the provided EP devices. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices( + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, + byte[] /* const char* */ compatibility_info, + out int /* OrtCompiledModelCompatibility */ out_status); + + public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices; + /// /// Add execution provider devices to the session options. /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index fc14be00ee47b..4611428ea12ef 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -150,6 +150,45 @@ internal static byte[] GetPlatformSerializedString(string str) else return StringToZeroTerminatedUtf8(str); } + + /// + /// Converts a null-terminated path string that is pointed to by the given IntPtr handle into + /// a C# UTF-16 string. + /// + /// A path string on Windows is utf-16, but utf-8 on other operating systems. + /// + /// + internal static string StringFromNativePathString(IntPtr strPtr) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (strPtr == IntPtr.Zero) + { + return string.Empty; + } + + // Get length of utf16 string by checking for two 0 bytes in a row. + int length = 0; + while (Marshal.ReadInt16(strPtr, length * 2) != 0) + { + length += 1; + } + + if (length == 0) + { + return string.Empty; + } + + unsafe + { + return System.Text.Encoding.Unicode.GetString((byte*)strPtr, length * 2); + } + } + else + { + return StringFromNativeUtf8(strPtr); + } + } } // Guards an array of disposable objects on stack and disposes them in reverse order diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 5c70808b82be1..052d5899b52c0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -7,6 +7,21 @@ namespace Microsoft.ML.OnnxRuntime { + /// + /// Represents the compatibility status of a pre-compiled model with one or more execution provider devices. + /// + /// + /// This enum is used to determine whether a pre-compiled model can be used with specific execution providers + /// and devices, or if recompilation is needed. + /// + public enum OrtCompiledModelCompatibility + { + EP_NOT_APPLICABLE = 0, + EP_SUPPORTED_OPTIMAL = 1, + EP_SUPPORTED_PREFER_RECOMPILATION = 2, + EP_UNSUPPORTED = 3, + } + /// /// Delegate for logging function callback. /// Supply your function and register it with the environment to receive logging callbacks via @@ -361,6 +376,31 @@ public string[] GetAvailableProviders() } } + /// + /// Validate a compiled model's compatibility information for one or more EP devices. + /// + /// The list of EP devices to validate against. + /// The compatibility string from the precompiled model to validate. + /// OrtCompiledModelCompatibility enum value denoting the compatibility status + public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + IReadOnlyList epDevices, string compatibilityInfo) + { + if (epDevices == null || epDevices.Count == 0) + throw new ArgumentException("epDevices must be non-empty", nameof(epDevices)); + + var devicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; ++i) + { + devicePtrs[i] = epDevices[i].Handle; + } + + var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtGetModelCompatibilityForEpDevices( + devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status)); + return (OrtCompiledModelCompatibility)status; + } + /// /// Get/Set log level property of OrtEnv instance diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs new file mode 100644 index 0000000000000..aca16e939ce21 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Diagnostics; + using System.Runtime.InteropServices; + + /// + /// Class to that stores information about the file location where an "external" initializer is stored. + /// + /// + public class OrtExternalInitializerInfo : SafeHandle, IReadOnlyExternalInitializerInfo + { + // Set to false when constructed with an externally managed constant handle owned by ORT. + private readonly bool _ownsHandle = true; + + /// + /// Create a new OrtExternalInitializerInfo instance. + /// + /// The path to the file that stores the initializer data. + /// The byte offset in the file where the data is stored. + /// The size of the data (in bytes) within the file. + public OrtExternalInitializerInfo(string filePath, long fileOffset, long byteSize) + : base(IntPtr.Zero, ownsHandle: true) + { + var platformFilePath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateExternalInitializerInfo(platformFilePath, fileOffset, (UIntPtr)byteSize, out handle)); + _ownsHandle = true; + } + + /// + /// Create a new OrtExternalInitializerInfo instance from an existing native OrtExternalInitializerInfo handle. + /// + /// Native OrtExternalInitializerInfo handle. + /// True if the OrtExternalInitializerInfo instance owns the native handle. + /// Defaults to false. + internal OrtExternalInitializerInfo(IntPtr constHandle, bool ownsHandle = false) + : base(IntPtr.Zero, ownsHandle) + { + Debug.Assert(constHandle != IntPtr.Zero); + SetHandle(constHandle); + _ownsHandle = ownsHandle; + } + + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + public string GetFilePath() + { + IntPtr filePathPtr = NativeMethods.OrtExternalInitializerInfo_GetFilePath(handle); + if (filePathPtr == IntPtr.Zero) + { + return string.Empty; + } + + return NativeOnnxValueHelper.StringFromNativePathString(filePathPtr); + } + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. + public long GetFileOffset() + { + return NativeMethods.OrtExternalInitializerInfo_GetFileOffset(handle); + } + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + public long GetByteSize() + { + UIntPtr byteSize = NativeMethods.OrtExternalInitializerInfo_GetByteSize(handle); + return checked((long)byteSize); + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtExternalInitializerInfo if we own it. + /// + /// true on success and false on error. + protected override bool ReleaseHandle() + { + if (!_ownsHandle) + { + // Return false to indicate an error. + // ReleaseHandle() should not be called on a const handle that this class does not own. + return false; + } + + NativeMethods.OrtReleaseExternalInitializerInfo(handle); + handle = IntPtr.Zero; + return true; + } + } + + /// + /// Interface for all readonly methods implemented by OrtExternalInitializerInfo. + /// + public interface IReadOnlyExternalInitializerInfo + { + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + string GetFilePath(); + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. + long GetFileOffset(); + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + long GetByteSize(); + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 01ee3aa5ae753..d848c63450ec1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -33,6 +33,147 @@ public enum OnnxValueType ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN) } + /// + /// Interface for all readonly methods implemented by OrtValue. + /// + public interface IReadOnlyOrtValue + { + /// + /// Get the ONNX value type for the OrtValue (e.g., OnnxValueType.ONNX_TYPE_TENSOR). + /// + /// OnnxValueType + OnnxValueType OnnxType { get; } + + /// + /// Returns true if OrtValue contains a tensor + /// + /// true if tensor + bool IsTensor { get; } + + /// + /// Returns true if OrtValue contains a sparse tensor + /// + /// true if sparse tensor + bool IsSparseTensor { get; } + + /// + /// Returns type information about the contained OnnxValue. + /// + /// a disposable instance of OrtTypeInfo + OrtTypeInfo GetTypeInfo(); + + /// + /// Obtains Tensor And Type Information from the OrtValue iff it contains a tensor. + /// Valid only for OrtValues that contain a tensor. + /// + /// A disposable instance of OrtTensorTypeAndShapeInfo + OrtTensorTypeAndShapeInfo GetTensorTypeAndShape(); + + /// + /// Returns the size of the tensor data in bytes. + /// + /// size of the tensor data in bytes + long GetTensorSizeInBytes(); + + /// + /// Returns OrtMemoryInfo iff this OrtValue contains a tensor or a sparse tensor. + /// + /// OrtMemoryInfo that describes the underlying memory allocation + /// + OrtMemoryInfo GetTensorMemoryInfo(); + + /// + /// Returns a ReadOnlySpan over tensor native buffer that + /// provides a read-only view. + /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). + /// + /// + /// ReadOnlySpan + /// + ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged; + +#if NET8_0_OR_GREATER + /// + /// Returns a ReadOnlyTensorSpan over tensor native buffer that + /// provides a read-only view. + /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). + /// + /// + /// ReadOnlySpan + /// + [Experimental("SYSLIB5001")] + SystemNumericsTensors.ReadOnlyTensorSpan GetTensorDataAsTensorSpan() where T : unmanaged; +#endif + + /// + /// Valid for composite ML types like map, sequence. + /// Returns 2 for map (keys, values) and N for sequence, where N is the number of elements + /// in the sequence. + /// + /// Element count + int GetValueCount(); + + /// + /// For non tensors return OrtValue element at the specified index. + /// For maps only indices 0 and 1 are valid. For sequences, [0..N) are valid. + /// See GetValueCount() to determine the valid range. + /// + /// + /// allocator to use + /// OrtValue disposable instance that points to the corresponding element of the composite type + OrtValue GetValue(int index, OrtAllocator allocator); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// ReadOnlyMemory{char} backed by a managed char[]. Its lifespan is not + /// tied to the native buffer of OrtValue. + ReadOnlyMemory GetStringElementAsMemory(int index); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// copy/convert UTF-8 into a UTF-16 string and return it. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// UTF-16 string instance + string GetStringElement(int index); + + /// + /// Get a span over the native memory of the string tensor element. + /// The span is valid as long as the OrtValue is valid. + /// + /// This is useful if you want to perform your own UTF-8 decoding or + /// you do not care about decoding. + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat element index + /// ReadOnlySpan over UTF-8 bytes of the string tensor element + ReadOnlySpan GetStringElementAsSpan(int index); + + /// + /// Convenience method to obtain all string tensor elements as a string array. + /// + /// string[] + /// + string[] GetStringTensorAsArray(); + } + /// /// Represents a disposable OrtValue. /// This class exposes a native instance of OrtValue. @@ -44,7 +185,7 @@ public enum OnnxValueType /// disposed properly, the pinned memory will continue to be pinned and interfere /// with GC operation. /// - public class OrtValue : IOrtValueOwner, IDisposable + public class OrtValue : IOrtValueOwner, IDisposable, IReadOnlyOrtValue { // OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around. // this exists only when we deal with compose ML types. @@ -52,11 +193,20 @@ public class OrtValue : IOrtValueOwner, IDisposable private IntPtr _handle; private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory private bool _disposed; + private bool _owned = true; - internal OrtValue(IntPtr handle) + /// + /// Constructs OrtValue from a native handle. If `owned` is true, the OrtValue instance takes + /// ownership of the native handle and disposes it when the OrtValue instance is disposed. + /// + /// The native OrtValue handle. + /// True if this class instance owns the handle. If false, the handle + /// will not be released. Defaults to true. + internal OrtValue(IntPtr handle, bool owned = true) { _handle = handle; InitOnnxType(); + _owned = owned; } /// @@ -1464,7 +1614,10 @@ protected virtual void Dispose(bool disposing) } Debug.Assert(_handle != IntPtr.Zero); - NativeMethods.OrtReleaseValue(_handle); + if (_owned) + { + NativeMethods.OrtReleaseValue(_handle); + } _handle = IntPtr.Zero; _disposed = true; } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs index bf576b54d8b45..fe2cab57658c8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -21,102 +21,249 @@ public class CompileApiTests [Fact] public void BasicUsage() { - var so = new SessionOptions(); - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (var sessionOptions = new SessionOptions()) { - // mainly checking these don't throw which ensures all the plumbing for the binding works. - compileOptions.SetInputModelPath("model.onnx"); - compileOptions.SetOutputModelPath("compiled_model.onnx"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); - compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); - compileOptions.SetEpContextEmbedMode(true); + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC); - } + } - // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer - using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - compileOptions.SetInputModelFromBuffer(model); + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); - // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. - // Due to that we need to allocate an IntPtr and UIntPtr here. - IntPtr bytePtr = new IntPtr(); - UIntPtr bytesSize = new UIntPtr(); - var allocator = OrtAllocator.DefaultInstance; - compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. + IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx"); - compileOptions.CompileModel(); + compileOptions.CompileModel(); - Assert.NotEqual(IntPtr.Zero, bytePtr); - Assert.NotEqual(UIntPtr.Zero, bytesSize); + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); - byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; - Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); - // Check the compiled model is valid - using (var session = new InferenceSession(compiledBytes, so)) - { - Assert.NotNull(session); + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, sessionOptions)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); } - allocator.FreeMemory(bytePtr); - } + // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate + // any compiled EPContext nodes, so expect an ORT_FAIL error. + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "should_not_generate.onnx"; + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); - // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate - // any compiled EPContext nodes, so expect an ORT_FAIL error. - using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "should_not_generate.onnx"; - compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); - compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("Unable to compile any nodes", ex.Message); + } - // compile should fail + Assert.False(File.Exists(output_model_file)); // Output file should not be generated. + } + + // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. + var outputModelFile = "squeezenet_ctx.onnx"; try { - compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFile); + compileOptions.CompileModel(); + Assert.True(File.Exists(outputModelFile)); + + // Try to compile again with flag that prevents replacing an existing file. + // Expect failure. + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("exists already", ex.Message); + } + } } - catch (OnnxRuntimeException ex) + finally { - Assert.Contains("Unable to compile any nodes", ex.Message); + if (File.Exists(outputModelFile)) + { + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFile); + } } - - Assert.False(File.Exists(output_model_file)); // Output file should not be generated. } + } + + [Fact] + public void WriteOutModelWithDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_write_delegate_ctx.onnx"; - // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (FileStream fs = new FileStream(outputModelFilePath, FileMode.Create, FileAccess.Write, FileShare.None, + 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "squeezenet_ctx.onnx"; + void BasicWriteBufferDelegate(ReadOnlySpan buffer) + { + Assert.True(buffer.Length > 0); + fs.Write(buffer.ToArray(), 0, buffer.Length); // Write it out to a file + } // Compile and generate an output model. compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetOutputModelWriteDelegate(BasicWriteBufferDelegate); compileOptions.CompileModel(); - Assert.True(File.Exists(output_model_file)); + Assert.True(File.Exists(outputModelFilePath)); + } + } - // Try to compile again with flag that prevents replacing an existing file. - // Expect failure. - compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + [Fact] + public void BasicGetInitializerLocationDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_handle_initializer_delegate_ctx.onnx"; + var initializersFilePath = "squeezenet_handle_initializer_delegate_ctx.bin"; - // compile should fail - try + try + { + using (FileStream fs = new FileStream(initializersFilePath, FileMode.Create, FileAccess.Write, + FileShare.None, 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { + // Custom delegate that stores large initializers in a new file. + OrtExternalInitializerInfo BasicHandleInitializer( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + var byteSize = initializerValue.GetTensorSizeInBytes(); + if (byteSize <= 64) + { + // Keep small initializers stored within model. + return null; + } + + long byteOffset = fs.Position; + ReadOnlySpan dataSpan = initializerValue.GetTensorDataAsSpan(); + fs.Write(dataSpan.ToArray(), 0, dataSpan.Length); // Write it out to a file + + // Return the data's new location. + return new OrtExternalInitializerInfo(initializersFilePath, byteOffset, byteSize); + } + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(BasicHandleInitializer); compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + Assert.True(File.Exists(outputModelFilePath)); } - catch (OnnxRuntimeException ex) + } + finally + { + if (File.Exists(outputModelFilePath)) { - Assert.Contains("exists already", ex.Message); + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFilePath); } + } + } + + [Fact] + public void GetInitializerLocationDelegateThatReusesExternalInitializers() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("conv_qdq_external_ini.onnx"); + var outputModelFilePath = "conv_qdq_external_ini.reuse.ctx.onnx"; + bool reusedExternalInitializers = false; + + try + { + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // Custom delegate that reuses the original external initializer file. + OrtExternalInitializerInfo ReuseExternalInitializers( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + if (originalInitializerLocation != null) + { + reusedExternalInitializers = true; // For test assertion only + string originalFilePath = originalInitializerLocation.GetFilePath(); + long originalFileOffset = originalInitializerLocation.GetFileOffset(); + long originalByteSize = originalInitializerLocation.GetByteSize(); + + Assert.True(originalFilePath.Length > 0); + Assert.True(originalFileOffset >= 0); + Assert.True(originalByteSize > 0); - if (File.Exists(output_model_file)) + // This initializer comes from an external file. Reuse it for compiled model. + return new OrtExternalInitializerInfo(originalFilePath, originalFileOffset, originalByteSize); + } + + // Otherwise, embed initializers that were not originally external. + return null; + } + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(ReuseExternalInitializers); + compileOptions.CompileModel(); + + Assert.True(File.Exists(outputModelFilePath)); + Assert.True(reusedExternalInitializers); + } + } + finally + { + if (File.Exists(outputModelFilePath)) { - File.Delete(output_model_file); + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFilePath); } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs new file mode 100644 index 0000000000000..103fe5bc10106 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using Xunit; +using System.Collections.Generic; + +public class EpCompatibilityTests +{ + private readonly OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private IReadOnlyList GetDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + return epDevices; + } + + [Fact] + public void GetEpCompatibility_InvalidArgs() + { + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info")); + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List(), "info")); + } + + [Fact] + public void GetEpCompatibility_SingleDeviceCpuProvider() + { + var devices = GetDevices(); + var someInfo = "arbitrary-compat-string"; + + // Use CPU device + var cpu = devices.First(d => d.EpName == "CPUExecutionProvider"); + Assert.NotNull(cpu); + var selected = new List { cpu }; + var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo); + + // CPU defaults to not applicable in this scenario + Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status); + } +} +#endif diff --git a/csharp/testdata/conv_qdq_external_ini.bin b/csharp/testdata/conv_qdq_external_ini.bin new file mode 100644 index 0000000000000..89eea0dba1fa4 Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.bin differ diff --git a/csharp/testdata/conv_qdq_external_ini.onnx b/csharp/testdata/conv_qdq_external_ini.onnx new file mode 100644 index 0000000000000..c53e1f3ad4d9b Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.onnx differ diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..cbfc38068ac2a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3079,6 +3079,17 @@ This version of the operator has been available since version 1 of the 'com.micr Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). + + The SwiGLU (Swish-Gated Linear Unit) activation function is like: + g = xW + b + l = xV + c + G = clamp(g, max=limit) + L = clamp(l, min=-limit, max=limit) + swiglu = G * sigmoid(alpha * G) * (L + beta) + where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters. + When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs. + When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size. + When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row. #### Version @@ -3088,12 +3099,20 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
activation_alpha : float
+
Alpha parameter used in activation function.
+
activation_beta : float
+
Beta parameter used in activation function.
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
Whether to normalize routing weights
+
swiglu_fusion : int
+
0: not fused, 1: fused and interleaved. 2: fused and not interleaved.
+
swiglu_limit : float
+
The limit used to clamp in SwiGLU. No clamp when limit is not provided.
use_sparse_mixer : int
Whether to use sparse mixer
@@ -3106,15 +3125,15 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
-
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T
-
3D input tensor with shape (num_experts, inter_size, hidden_size)
+
3D input tensor with shape (num_experts, hidden_size, inter_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T
-
3D optional input tensor with shape (num_experts, hidden_size, inter_size)
+
3D optional input tensor with shape (num_experts, inter_size, hidden_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -3129,8 +3148,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
-
Constrain input and output types to float or float16 tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
@@ -4522,14 +4541,22 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
activation_alpha : float
+
Alpha parameter used in activation function.
+
activation_beta : float
+
Beta parameter used in activation function.
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
expert_weight_bits : int
Number of bits used in quantized weights. Default is 4 bits
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
Whether to normalize routing weights
+
swiglu_fusion : int
+
0: not fused, 1: fused and interleaved. 2: fused and not interleaved.
+
swiglu_limit : float
+
The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.
use_sparse_mixer : int
Whether to use sparse mixer
@@ -4542,20 +4569,20 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc1_scales : T
-
2D input tensor with shape (num_experts, inter_size)
+
3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.
+
fc1_scales : T2
+
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
-
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-
fc2_scales : T
+
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits
+
fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
-
3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc3_scales (optional) : T
+
3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
+
fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -4571,10 +4598,12 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16)
-
Constrain input and output types to float or float16 tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
T1 : tensor(uint8)
Constrain weights type to uint8 tensors.
+
T2 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain scales type to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b70e5da8b3e4..660c63d056335 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -562,6 +562,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(float), tensor(float16)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| @@ -937,6 +938,7 @@ Do not modify directly.* |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| @@ -949,7 +951,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -957,7 +959,7 @@ Do not modify directly.* |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index adfd341451aed..820d140ccaabc 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } #endif constexpr size_t kMaxStrLen = 4096; diff --git a/include/onnxruntime/core/common/parse_string.h b/include/onnxruntime/core/common/parse_string.h index 6345b2a55490d..5f88d490b3415 100644 --- a/include/onnxruntime/core/common/parse_string.h +++ b/include/onnxruntime/core/common/parse_string.h @@ -35,13 +35,30 @@ template std::enable_if_t, bool> TryParseStringWithClassicLocale(std::string_view str, T& value) { T parsed_value{}; - const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), parsed_value); - if (ec != std::errc{}) { + std::from_chars_result conversion_result{}; + if constexpr (std::is_integral_v && std::is_unsigned_v) { + // For unsigned integral types, also handle hex values, i.e., those beginning with "0x". + // std::from_chars() does not accept the "0x" prefix. + const bool has_hex_prefix = str.size() >= 2 && + str[0] == '0' && + (str[1] == 'x' || str[1] == 'X'); + + if (has_hex_prefix) { + str = str.substr(2); + } + + const int base = has_hex_prefix ? 16 : 10; + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value, base); + } else { + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value); + } + + if (conversion_result.ec != std::errc{}) { return false; } - if (ptr != str.data() + str.size()) { + if (conversion_result.ptr != str.data() + str.size()) { return false; } diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index da9735aa4e418..8cf6420f2d0f7 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/include/onnxruntime/core/common/string_helper.h b/include/onnxruntime/core/common/string_helper.h index 1304303132d5a..c0b331cb8e9a8 100644 --- a/include/onnxruntime/core/common/string_helper.h +++ b/include/onnxruntime/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 7df3368ad4e0b..f54f4a5a6f1ef 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -36,6 +36,7 @@ class GraphOptimizerRegistry; #include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +#include "core/session/onnxruntime_c_api.h" struct OrtEpDevice; struct OrtRunOptions; @@ -179,7 +180,12 @@ class IExecutionProvider { /** Get the device id of current execution provider */ - virtual int GetDeviceId() const { return default_device_.Id(); }; + virtual int GetDeviceId() const { return default_device_.Id(); } + + /** + * Get the OrtDevice the execution provider was registered with. + */ + const OrtDevice& GetDevice() const { return default_device_; } /** Get execution provider's configuration options. @@ -317,6 +323,29 @@ class IExecutionProvider { virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs); + /** + * Get the compatibility info for a compiled model. + * + * The execution provider determines this value, which denotes the compatibility of the compiled model with the EP. + * This is stored in the model metadata under a key associated with the EP type. + */ + virtual std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + // graph_viewer and model_metadata are not used in the default implementation. + ORT_UNUSED_PARAMETER(graph_viewer); + // Default implementation returns empty string + return std::string(); + } + + /** + * Validate the compatibility of a compiled model with this execution provider. + */ + virtual common::Status ValidateCompiledModelCompatibilityInfo(const std::string& /*compatibility_info*/, + OrtCompiledModelCompatibility& model_compatibility) const { + // Default implementation indicates this EP does not support model compatibility validation + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return Status::OK(); + } + #endif void SetLogger(const logging::Logger* logger) { diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 375f0a4dc8dd2..e59a803d97629 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ } +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \ + provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3 + +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \ + class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name) \ + .SetDomain(domain) \ + .SinceVersion(ver) \ + .Provider(provider) \ + .Build(), \ + static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ + } + #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \ provider##_##name##_##domain##_ver##startver##_##endver##_##type diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 536d641b4eef9..fea970b84fd84 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -150,6 +150,13 @@ struct OrtDevice { return alignment < other.alignment; } + bool EqualIgnoringAlignment(const OrtDevice& other) const { + return device_type == other.device_type && + memory_type == other.memory_type && + vendor_id == other.vendor_id && + device_id == other.device_id; + } + private: // Device type. int32_t device_type : 8; diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index d930b2289170d..1be81e77064d2 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -13,18 +13,14 @@ struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor // use string for name, so we could have customized allocator in execution provider. - const char* name = nullptr; + std::string name; OrtMemType mem_type = OrtMemTypeDefault; OrtAllocatorType alloc_type = OrtInvalidAllocator; OrtDevice device; - constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), - OrtMemType mem_type_ = OrtMemTypeDefault) -#if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__)) - // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5 - __attribute__((nonnull)) -#endif - : name(name_), + OrtMemoryInfo(std::string name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), + OrtMemType mem_type_ = OrtMemTypeDefault) + : name(std::move(name_)), mem_type(mem_type_), alloc_type(type_), device(device_) { @@ -39,7 +35,7 @@ struct OrtMemoryInfo { if (device != other.device) return device < other.device; - return strcmp(name, other.name) < 0; + return name < other.name; } // This is to make OrtMemoryInfo a valid key in hash tables @@ -68,7 +64,7 @@ inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) { return left.mem_type == other.mem_type && left.alloc_type == other.alloc_type && left.device == other.device && - strcmp(left.name, other.name) == 0; + left.name == other.name; } inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); } diff --git a/include/onnxruntime/core/framework/provider_options_utils.h b/include/onnxruntime/core/framework/provider_options_utils.h index 5967fb91523d0..badb7320ea49e 100644 --- a/include/onnxruntime/core/framework/provider_options_utils.h +++ b/include/onnxruntime/core/framework/provider_options_utils.h @@ -83,12 +83,24 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddValueParser( const std::string& name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + + template + ProviderOptionsParser& AddValueParser( + std::string_view name, ValueParserType value_parser) { ORT_ENFORCE( value_parsers_.emplace(name, ValueParser{value_parser}).second, "Provider option \"", name, "\" already has a value parser."); return *this; } + template + ProviderOptionsParser& AddValueParser( + const char* name, ValueParserType value_parser) { + return AddValueParser(std::string_view{name}, value_parser); + } + /** * Adds a parser for a particular provider option value which converts a * value to the right type and assigns it to the given reference. @@ -104,13 +116,25 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToReference( const std::string& name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + + template + ProviderOptionsParser& AddAssignmentToReference( + std::string_view name, ValueType& dest) { return AddValueParser( name, - [&dest](const std::string& value_str) -> Status { + [&dest](std::string_view value_str) -> Status { return ParseStringWithClassicLocale(value_str, dest); }); } + template + ProviderOptionsParser& AddAssignmentToReference( + const char* name, ValueType& dest) { + return AddAssignmentToReference(std::string_view{name}, dest); + } + /** * Adds a parser for a particular provider option value which maps an * enumeration name to a value and assigns it to the given reference. @@ -128,6 +152,12 @@ class ProviderOptionsParser { template ProviderOptionsParser& AddAssignmentToEnumReference( const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + + template + ProviderOptionsParser& AddAssignmentToEnumReference( + std::string_view name, const EnumNameMapping& mapping, EnumType& dest) { return AddValueParser( name, [&mapping, &dest](const std::string& value_str) -> Status { @@ -135,6 +165,12 @@ class ProviderOptionsParser { }); } + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const char* name, const EnumNameMapping& mapping, EnumType& dest) { + return AddAssignmentToEnumReference(std::string_view{name}, mapping, dest); + } + /** * Parses the given provider options. */ diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index e164f23b8fc35..9a0708d72b4f8 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -740,7 +740,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi common::Status InjectExternalInitializedTensors(const InlinedHashMap& external_initializers); /** This function takes externally provided files in memory for initializers with external - * data and replaces graph initializers with its content. + * data and replaces main graph initializers with its content. */ common::Status InjectExternalInitializersFromFilesInMemory( const InlinedHashMap>& external_initializer_files); @@ -1220,7 +1220,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #endif #if !defined(ORT_MINIMAL_BUILD) - /** Gets the GraphProto representation of this Graph only. */ + /** Gets the GraphProto representation of this Graph only. + * This does not remove in-memory tags for graph initializers. + * Use ToGraphProto() const to get a GraphProto that can be serialized externally. + */ const ONNX_NAMESPACE::GraphProto& ToGraphProto(); /// @@ -1244,6 +1247,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::filesystem::path& model_file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. + /// Output parameter set to the serialized onnx::GraphProto. + /// A status indicating success or an error. + common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const; + /** Gets the ISchemaRegistry instances being used with this Graph. */ IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const; @@ -1439,6 +1454,27 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return Resolve(default_options); } + /// + /// This function converts all the graph TensorProto initializers into OrtValues + /// and creates a in-memory external data reference for each OrtValue. + /// + /// + Status ConvertInitializersIntoOrtValues(); + + /** + * @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto. + * + * This function converts specified TensorProto initializers in the graph into OrtValues and + * creates in-memory external data references for each OrtValue. It then updates the provided + * GraphProto with the modified initializers. + * + * @param iterators Span of iterators pointing to the initializers and the order that should be processed + * @param output_graph_proto The GraphProto to be updated with the modified initializers + * @return Status Returns a Status object indicating success or any errors that occurred during conversion + */ + Status RegenerateInitializersAndReplaceInMemory(gsl::span iterators, + ONNX_NAMESPACE::GraphProto& output_graph_proto) const; + const std::unordered_set& GetOuterScopeNodeArgNames() const noexcept { return outer_scope_node_arg_names_; } @@ -1595,20 +1631,25 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /// This function is used by ToGraphProto() to ensure in-memory external data references /// don't leak externally since they are non-standard. /// - /// It handles two scenarios: - /// - When GraphSynchronizationNeeded() is false: GraphProto is simply copied + /// It is used when GraphSynchronizationNeeded() is false: GraphProto is simply copied /// from graph_proto_ by ToGraphProto(). This copy includes both main graph /// and subgraph initializers. This function examines all initializers /// and inlines any in-memory data references. - /// - When GraphSynchronizationNeeded() is true: ToGraphProto() generates a new GraphProto - /// using ToGraphProtoInternal(). This doesn't transfer main graph initializers, which are - /// copied and inlined by ToGraphProto() itself. This function processes only the subgraph initializers - /// as needed. /// /// The GraphProto to process - /// Whether to process the main graph initializers - /// Status indicating success or failure /// - Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto, bool process_main) const; + /// Status indicating success or failure + Status ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const; + + /// + /// This function replaces all of the initializers within output_graph_proto + /// from this Graph instance. All in memory initializers are regenerated and inlined. + /// This is necessary even if the graph_proto_ is already up to date because initializers() may + /// contain obsolete initializers that are no longer in use due to optimizations and contain obsolete + /// references to OrtValues that may no longer be around (since we like appending rather than replacing). + /// + /// Destination GraphProto to receive the updated initializers. + /// Status indicating success or failure. + Status RegenerateInitializersAndReplaceInMemory(ONNX_NAMESPACE::GraphProto& output_graph_proto) const; /// /// This function traverses the graph bottom up and externalizes @@ -1635,6 +1676,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::ostream& external_stream, int64_t& external_offset) const; + Status ToGraphProtoWithCustomInitializerHandlingImpl(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const; #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index 088db79a7e005..8ef4fdb66e1e6 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -31,7 +31,7 @@ struct IndexedSubGraph { std::string domain; ///< Domain of customized SubGraph/FunctionProto int since_version; ///< Since version of customized SubGraph/FunctionProto. - ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto. + ONNX_NAMESPACE::OperatorStatus status{ONNX_NAMESPACE::OperatorStatus::STABLE}; ///< Status of customized SubGraph/FunctionProto. std::vector inputs; ///< Inputs of customized SubGraph/FunctionProto. std::vector outputs; ///< Outputs of customized SubGraph/FunctionProto. diff --git a/include/onnxruntime/core/graph/model_saving_options.h b/include/onnxruntime/core/graph/model_saving_options.h index 6c041ec96a035..06c1b1ac6475f 100644 --- a/include/onnxruntime/core/graph/model_saving_options.h +++ b/include/onnxruntime/core/graph/model_saving_options.h @@ -9,36 +9,30 @@ class PrepackedWeightsForGraph; // These options affect how the model initializers are written to the external file. // This includes options to align external initializer offset. -// For models running on CPU, ORT will try to use mmap to load external -// initializers. To use mmap, external initializer need to be offset aligned. +// ORT will try to use mmap to load external initializers. +// // ORT saves external initializers into single data file, each initializer is // accessed with offset(start position of initializer) and length(byte length of -// initializer) of the data file. To use mmap, each offset need to be aligned -// which means offset need to divisible by allocation granularity(64KB for -// windows and 4K for other OSes). With align_offset to true, ORT will align -// offset for large initializer when save ONNX model with external data file. +// initializer) of the data file. With align_offset to true, ORT will align +// offset for large initializer (larger than align_threshold) +// when save ONNX model with external data file. It will align then to +// on_disk_alignment value. struct ModelSavingOptions { explicit ModelSavingOptions(size_t size_threshold) : initializer_size_threshold(size_threshold) {} // Minimal initializer size in bytes to be externalized on disk size_t initializer_size_threshold; - // Offset will always be page aligned and allocation granularity aligned for - // mmap support. This is done by padding previous tensor data with zeros - // keeping same length. + // Offset will always be aligned for mmap support. + // This is done by padding previous tensor data with zeros keeping same length. bool align_offset = false; // Alignment threshold for size of data. // Having a low threshold will waste file space for small initializers. // Only when tensor's data size is > the page_align_threshold it will be force // aligned. Default to 1MB. int64_t align_threshold = 1048576; - // The allocation Granularity for mmap() support. - // Typically 64KB for Windows & 4KB for other OSes. Default to 64KB. -#ifdef _WIN32 - int64_t allocation_granularity = 65536; -#else - int64_t allocation_granularity = 4096; -#endif + // Alignment factor for big tensors (bigger than align_threshold). Defaults to 4K. + int64_t on_disk_alignment = 4096; // Force embed all external initializer into the Onnx file // Used for EPContext model generation while some nodes fallback on CPU which has external data dependency bool force_embed_external_ini = false; diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 11cc6f131dab3..026fc3b2dc0a0 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -31,10 +31,10 @@ constexpr const char* kDetailedBuildLog = "nv_detailed_build_log"; constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; -constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; -constexpr const char* kONNXBytestream = "nv_onnx_bytestream"; -constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size"; +constexpr const char* kCudaGraphEnable = "enable_cuda_graph"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; +constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; +constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path"; } // namespace provider_option_names namespace run_option_names { diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 0d920ab7dac89..e2b2aff2011fe 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -203,415 +203,331 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, #define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ + Ort::Status _status{(fn)}; \ + if (!_status.IsOK()) { \ + return _status; \ } \ } while (0) #define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) + ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ +#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{msg, ORT_FAIL}; \ + } \ } while (0) namespace OrtEpUtils { -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::GraphProto& graph_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; + try { + Ort::ConstGraph ort_graph{&graph}; + // + // Set GraphProto metadata + // + auto graph_name = ort_graph.GetName(); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + std::vector graph_inputs = ort_graph.GetInputs(); + std::vector graph_outputs = ort_graph.GetOutputs(); + + for (const auto& ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + for (const auto& ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&value_infos, + &initializer_value_infos](Ort::ConstValueInfo ort_value_info, + /*out*/ std::optional& value_name_out) { + auto value_name = ort_value_info.GetName(); + + if (value_name_out) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = ort_value_info.IsRequiredGraphInput(); + bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput(); + bool is_graph_output = ort_value_info.IsGraphOutput(); + bool is_constant_initializer = ort_value_info.IsConstantInitializer(); + bool is_from_outer_scope = ort_value_info.IsFromOuterScope(); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, ort_value_info); + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo. + } + }; + + std::vector nodes = ort_graph.GetNodes(); + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. + for (const auto& ort_node : nodes) { + onnx::NodeProto* node_proto = graph_proto.add_node(); + + std::string node_name = ort_node.GetName(); + std::string node_domain = ort_node.GetDomain(); + std::string node_op_type = ort_node.GetOperatorType(); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + // Handle node attributes + std::vector ort_attrs = ort_node.GetAttributes(); + for (const auto& attr : ort_attrs) { + OrtOpAttrType attr_type = attr.GetType(); if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } - if (!attr_type_status.IsOK()) { - // Unsupported attribute type. - return attr_type_status; - } - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto)); } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + // Handle node subgraphs + std::vector ort_subgraphs = ort_node.GetSubgraphs(); + for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) { onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - attr_proto->set_name(subgraph_attr_name); attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { + // Handle node inputs + std::vector ort_inputs = ort_node.GetInputs(); + for (const auto& vi : ort_inputs) { + if (vi == nullptr) { // missing optional input. node_proto->add_input(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_input(*value_name); } - } - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + // Handle implicit inputs to this node. + std::vector ort_implicit_inputs = ort_node.GetImplicitInputs(); + for (const auto& vi : ort_implicit_inputs) { + assert(vi != nullptr); + std::optional value_name; + collect_value_info(vi, value_name); } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { + // Handle node outputs + std::vector ort_outputs = ort_node.GetOutputs(); + for (const auto& vi : ort_outputs) { + if (vi == nullptr) { // missing optional output. node_proto->add_output(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_output(*value_name); } } - } - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. - std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const auto& [value_name, value_info] : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); } - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + // Add initializers to GraphProto as TensorProto objects. + for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + Ort::ConstValue ort_value{nullptr}; + ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value)); - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; + assert(ort_value.IsTensor()); + const void* data = ort_value.GetTensorRawData(); + const size_t data_bytes = ort_value.GetTensorSizeInBytes(); - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - length_entry->set_key("length"); - length_entry->set_value(std::to_string(data_bytes)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; } -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. - size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } + try { + // Check that OrtGraph is a top-level graph (no parent node). + Ort::ConstGraph ort_graph{&graph}; + Ort::ConstNode parent_node = ort_graph.GetParentNode(); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = ort_graph.GetOnnxIRVersion(); + model_proto.set_ir_version(ir_version); + + // Set operator sets. + std::vector op_sets = ort_graph.GetOperatorSets(); + ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set."); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (const auto& op_set : op_sets) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(op_set.domain); + operator_set->set_version(op_set.version); + } - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func)); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + } catch (const Ort::Exception& ex) { + return Ort::Status(ex); + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_EP_FAIL); + } return Ort::Status{nullptr}; } -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + try { + Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); + ONNXType ort_onnx_type = ort_type_info.GetONNXType(); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType ort_elem_type = ort_type_shape.GetElementType(); - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + size_t num_dims = ort_type_shape.GetDimensionsCount(); + std::vector ort_dims = ort_type_shape.GetShape(); - elem_type = ort_elem_type; - dims = std::move(ort_dims); + elem_type = ort_elem_type; + dims = std::move(ort_dims); - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_EP_FAIL}; } - return Ort::Status{nullptr}; } // Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - std::vector ort_dims; std::vector ort_dim_syms; ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; @@ -620,9 +536,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, ort_elem_type, ort_dims, ort_dim_syms)); - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); + value_info_proto.set_name(ort_value_info.GetName()); onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); @@ -652,116 +566,149 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) { + try { + std::string attr_name = attr.GetName(); + attr_proto.set_name(attr_name); - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); + OrtOpAttrType attr_type = attr.GetType(); - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); - - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + int64_t i_val = 0; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto.set_i(i_val); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + std::vector i_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals)); + auto* ints = attr_proto.mutable_ints(); + ints->Assign(i_vals.begin(), i_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + float f_val = 0.0f; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + std::vector f_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals)); + auto* floats = attr_proto.mutable_floats(); + floats->Assign(f_vals.begin(), f_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + std::string* str = attr_proto.mutable_s(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + std::vector result; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result)); + auto* strs = attr_proto.mutable_strings(); + strs->Assign(result.begin(), result.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + Ort::Value tensor; + ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor)); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } - auto* strs = attr_proto.mutable_strings(); + auto shape = type_shape_info.GetShape(); - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } - while (at < end) { - char* str_begin = at; + const void* data = tensor.GetTensorRawData(); + const size_t data_bytes = tensor.GetTensorSizeInBytes(); - while (*at && at < end) { - at++; - } + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. - } + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); } - - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 7e49275e59b8b..59ca1a1df762e 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,7 +20,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" struct OrtThreadingOptions; @@ -106,6 +106,15 @@ class Environment { return shared_allocators_; } + /** + * Returns an AllocatorPtr for a shared IAllocator based allocator if it matches the memory info. + * The OrtMemoryInfo name and whether it's an arena or device allocator is ignored in the lookup, as is the + * alignment. + * The user calling this function is not expected to know the alignment, and we expect the allocator instance to be + * created with a valid alignment for the device. + */ + AllocatorPtr GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const; + /** * Removes registered allocator that was previously registered for sharing between multiple sessions. */ @@ -171,7 +180,7 @@ class Environment { std::unique_ptr inter_op_thread_pool_; bool create_global_thread_pools_{false}; - std::mutex mutex_; + mutable std::mutex mutex_; // shared allocators from various sources. // CreateAndRegisterAllocator[V2]: IAllocator allocators created by ORT @@ -190,23 +199,6 @@ class Environment { using OrtAllocatorUniquePtr = std::unique_ptr>; - // if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with - // OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator. - // we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in - // shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an - // OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently. - // we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_. - // - // TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator - // or an OrtAllocator instance to reduce the indirection a little. - // with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the - // IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_. - // - // Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena - // implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a - // cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source. - std::unordered_map> arena_ort_allocators_; - #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2f0e4aa7ce108..8561de9c8c3b9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -264,6 +264,7 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, ORT_MODEL_LOAD_CANCELED, ORT_MODEL_REQUIRES_COMPILATION, + ORT_NOT_FOUND, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -275,6 +276,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, ORT_OP_ATTR_GRAPH, + ORT_OP_ATTR_TENSOR, } OrtOpAttrType; //! @} @@ -531,6 +533,57 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Function called by ORT to write a buffer to a custom destination (e.g., file, stream, etc.). + * + * \param state Opaque pointer holding the user's state. + * \param buffer The buffer to write. + * \param buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either + * written to an external file or stored within the model. ORT calls this function for every initializer when + * generating a model. + * + * If the function implementation sets the `new_external_info` output parameter to NULL, ORT stores the initializer data + * within the generated model. + * + * Otherwise, if the function implementation sets `new_external_info` to a valid OrtExternalInitializerInfo instance, + * ORT assumes that this function stores the initializer data in a file. In this case, ORT configures the model's + * initializer to point to the location specified by the `new_external_info` output parameter. + * + * \param[in] state Opaque pointer holding the user's state. + * \param[in] initializer_name The initializer's name as a null-terminated string. + * \param[in] initializer_value OrtValue containing the initializer's data, type, and shape. + * \param[in] external_info If the initializer is originally stored in an external file, `external_info` contains + * the file path, file offset, and the data's byte size within the file. Otherwise, + * `external_info` is NULL if the initializer is not originally stored in a file. + * \param[out] new_external_info Output parameter set to a new OrtExternalInitializerInfo instance indicating the + * location where the function implementation stored the initializer data. + * The function implementation must use `OrtApi::CreateExternalInitializerInfo()` to + * create the instance. + * If the function implementation sets `new_external_info` to NULL, + * ORT stores the initializers within the model. + * + * \note ORT takes ownership of the `new_external_info` output parameter. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtGetInitializerLocationFunc)( + _In_ void* state, + _In_ const char* initializer_name, + _In_ const OrtValue* initializer_value, + _In_opt_ const OrtExternalInitializerInfo* external_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** new_external_info); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -752,13 +805,13 @@ typedef struct OrtMIGraphXProviderOptions { int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true + int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, nonzero = true const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true + int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true const char* migraphx_load_model_path; // migraphx model path name - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false + bool migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. @@ -774,6 +827,7 @@ typedef struct OrtMIGraphXProviderOptions { */ int migraphx_arena_extend_strategy; + // This is the legacy struct and don't add new fields here. } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options @@ -899,6 +953,16 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ +/* + * Public enum for compiled model compatibility across EPs. + */ +typedef enum OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, + OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, + OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, + OrtCompiledModelCompatibility_EP_UNSUPPORTED, +} OrtCompiledModelCompatibility; + struct OrtApi { /// \name OrtStatus /// @{ @@ -5826,7 +5890,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, + ORT_API2_STATUS(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); /** \brief Get the parent node for the given graph, if any exists. @@ -5846,14 +5910,13 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference * the same underlying graph. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6032,6 +6095,11 @@ struct OrtApi { * Typical usage sets this to the result of Node_GetNumAttributes(). An error status is * returned if `num_attributes` is less than the number of node attributes. * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. @@ -6043,14 +6111,35 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute_name The name of the attribute - * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists. + * For an unset optional attribute, `attribute` is set to NULL and a non-error status is + * returned. For an invalid attribute name, `attribute` is set to NULL and an error status with + * code ORT_NOT_FOUND is returned. + * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); + + /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. + * + * \param[in] attribute The OrtOpAttr instance. + * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. + Must be freed with OrtApi::ReleaseValue. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * @@ -6440,6 +6529,55 @@ struct OrtApi { _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + + /** \brief Get ::OrtModelMetadata from an ::OrtGraph + * + * \param[in] graph The OrtGraph instance. + * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + + /** \brief Validate a compiled model's compatibility information for one or more EP devices. + * + * \param[in] ep_devices The EP devices to validate against (e.g., from GetEpDevices). + * All devices must belong to the same execution provider. + * \param[in] num_ep_devices The number of EP devices provided. + * \param[in] compatibility_info The compatibility info string produced when the model was compiled. + * \param[out] out_status The resulting compatibility status for the EP devices. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); + + /// \name OrtExternalInitializerInfo + /// @{ + + /** \brief Creates an OrtExternalInitializerInfo instance. + * + * \param[in] filepath The relative path to the file that stores the initializer's data. ORT copies this path string. + * \param[in] file_offset The byte offset where the initializer's data is stored within the file. + * \param[in] byte_size The size in bytes of the initializer's data within the file. + * \param[out] out Output parameter set to the new OrtExternalInitializerInfo instance. + * Must be released by calling ReleaseExternalInitializerInfo(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset, + _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); + + /// @} }; /* @@ -7007,6 +7145,9 @@ struct OrtCompileApi { * ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling * CompileModel. * + * \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use + * ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations. + * * \param[in] env OrtEnv object. * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions. * \param[out] out The created OrtModelCompilationOptions instance. @@ -7163,7 +7304,7 @@ struct OrtCompileApi { * \since Version 1.23. */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, - size_t flags); + uint32_t flags); /** Sets information related to EP context binary file. * @@ -7182,6 +7323,56 @@ struct OrtCompileApi { _In_ OrtModelCompilationOptions* model_compile_options, _In_ const ORTCHAR_T* output_directory, _In_ const ORTCHAR_T* model_name); + + /** Set the graph optimization level. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] graph_optimization_level The graph optimization level. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ GraphOptimizationLevel graph_optimization_level); + + /** \brief Sets a OrtWriteBufferFunc function that is called by ORT to write out the output model's serialized + * ONNX bytes. + * + * The provided write function may be called repeatedly until then entire output model has been written out. Each call + * to the write function is expected to consume the entire input buffer. + * + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_func The OrtWriteBufferFunc function called by ORT when writing out the model. + * \param[in] state Opaque state passed as the first argument to OrtWriteBufferFunc. Can be NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state); + + /** \brief Sets a OrtGetInitializerLocationFunc function that is called by ORT for every initializer in the generated + * model. Allows implementer to specify whether initializers should be stored within the model or externally. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] get_initializer_location_func The OrtGetInitializerLocationFunc function called by ORT when + * to determine the location of the initializer. + * \param[in] state Opaque state passed as the first argument to OrtGetInitializerLocationFunc. Can be NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..9fa7915679f62 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -52,6 +52,7 @@ namespace Ort { * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort() */ struct Exception : std::exception { + Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {} Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} OrtErrorCode GetOrtErrorCode() const { return code_; } @@ -549,34 +550,43 @@ namespace detail { inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); } ORT_DEFINE_RELEASE(Allocator); -ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(CustomOpDomain); -ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(Env); -ORT_DEFINE_RELEASE(RunOptions); +ORT_DEFINE_RELEASE(ExternalInitializerInfo); +ORT_DEFINE_RELEASE(Graph); +ORT_DEFINE_RELEASE(IoBinding); +ORT_DEFINE_RELEASE(KernelInfo); +ORT_DEFINE_RELEASE(KeyValuePairs); ORT_DEFINE_RELEASE(LoraAdapter); +ORT_DEFINE_RELEASE(MemoryInfo); +ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(Model); +ORT_DEFINE_RELEASE(ModelMetadata); +ORT_DEFINE_RELEASE(Node); +ORT_DEFINE_RELEASE(Op); +ORT_DEFINE_RELEASE(OpAttr); +ORT_DEFINE_RELEASE(PrepackedWeightsContainer); +ORT_DEFINE_RELEASE(RunOptions); ORT_DEFINE_RELEASE(Session); ORT_DEFINE_RELEASE(SessionOptions); -ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); ORT_DEFINE_RELEASE(SequenceTypeInfo); -ORT_DEFINE_RELEASE(MapTypeInfo); +ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(SyncStream); +ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); +ORT_DEFINE_RELEASE(ThreadingOptions); ORT_DEFINE_RELEASE(TypeInfo); ORT_DEFINE_RELEASE(Value); -ORT_DEFINE_RELEASE(ModelMetadata); -ORT_DEFINE_RELEASE(IoBinding); -ORT_DEFINE_RELEASE(ArenaCfg); -ORT_DEFINE_RELEASE(Status); -ORT_DEFINE_RELEASE(OpAttr); -ORT_DEFINE_RELEASE(Op); -ORT_DEFINE_RELEASE(KernelInfo); ORT_DEFINE_RELEASE(ValueInfo); -ORT_DEFINE_RELEASE(Node); -ORT_DEFINE_RELEASE(Graph); -ORT_DEFINE_RELEASE(Model); -ORT_DEFINE_RELEASE(KeyValuePairs) + ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); +// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, +// but the struct has V2 in its name to indicate that it is the second version of the options. +inline void OrtRelease(OrtTensorRTProviderOptionsV2* ptr) { GetApi().ReleaseTensorRTProviderOptions(ptr); } +inline void OrtRelease(OrtCUDAProviderOptionsV2* ptr) { GetApi().ReleaseCUDAProviderOptions(ptr); } + #undef ORT_DEFINE_RELEASE #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT @@ -628,6 +638,7 @@ struct Base { } constexpr operator contained_type*() const noexcept { return p_; } + constexpr contained_type& operator*() const noexcept { return *p_; } /// \brief Relinquishes ownership of the contained C object pointer /// The underlying object is not destroyed @@ -672,6 +683,7 @@ struct Base> { } constexpr operator contained_type*() const noexcept { return p_; } + constexpr contained_type& operator*() const noexcept { return *p_; } protected: contained_type* p_{}; @@ -692,11 +704,17 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; struct EpDevice; +struct ExternalInitializerInfo; struct Graph; struct Model; struct Node; struct ModelMetadata; struct TypeInfo; +struct PrepackedWeightsContainer; +struct Session; +struct SessionOptions; +struct SyncStream; +struct TensorRTProviderOptions; struct Value; struct ValueInfo; @@ -711,14 +729,12 @@ using AllocatedStringPtr = std::unique_ptr; * constructors to construct an instance of a Status object from exceptions. */ struct Status : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used - explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. - explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception - explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception - Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. + Status() = default; // Same as with std::nullptr_t. But can be used in re-sizable containers and represent success. + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. + explicit Status(const Exception&); ///< Creates status instance out of exception + explicit Status(const std::exception&); ///< Creates status instance out of exception + Status(const char* message, OrtErrorCode code); ///< Creates status instance out of null-terminated string message. std::string GetErrorMessage() const; OrtErrorCode GetErrorCode() const; bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. @@ -754,6 +770,98 @@ struct ThreadingOptions : detail::Base { ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); }; +/** \brief The TensorRTOptions (V2) + * + * Used to pass options to TRT EP + */ +struct TensorRTProviderOptions : detail::Base { + TensorRTProviderOptions(std::nullptr_t) {} + /// \brief Wraps OrtApi::CreateTensorRTProviderOptionsV2 + TensorRTProviderOptions(); + ///< Wrapper around OrtApi::UpdateTensorRTProviderOptions + void Update(const std::unordered_map& options); + ///< Wrapper around OrtApi::UpdateTensorRTProviderOptions + void UpdateWithValue(const char* key, void* value); + + ///< Wrapper around OrtApi::GetTensorRTProviderOptionsByName + void* GetOptionByName(const char* name) const; + ///< Wrapper around OrtApi::GetTensorRTProviderOptionsAsString + std::string GetTensorRTProviderOptionsAsString() const; +}; + +/** \brief The CUDAProviderOptions (V2) + * + * Used to pass options to CUDA EP + */ +struct CUDAProviderOptions : detail::Base { + CUDAProviderOptions(std::nullptr_t) {} + /// \brief Wraps OrtApi::CreateCUDAProviderOptions + CUDAProviderOptions(); + ///< Wrapper around OrtApi::UpdateCUDAProviderOptions + void Update(const std::unordered_map& options); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsAsString + std::string GetCUDAProviderOptionsAsString() const; + ///< Wrapper around OrtApi::UpdateCUDAProviderOptionsWithValue + void UpdateWithValue(const char* key, void* value); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsByName + void* GetOptionByName(const char* name) const; +}; + +/** \brief The PrepackedWeightsContainer + * + * Create only and pass to Ort::Session constructor for multiple sessions + * to share pre-packed weights. + */ +struct PrepackedWeightsContainer : detail::Base { + using Base = detail::Base; + ///< No instance is created + explicit PrepackedWeightsContainer(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit PrepackedWeightsContainer(OrtPrepackedWeightsContainer* p) : Base{p} {} + /// \brief Wraps OrtApi::CreatePrepackedWeightsContainer + PrepackedWeightsContainer(); +}; + +namespace detail { +template +struct ConstExternalInitializerInfoImpl : Base { + using B = Base; + using B::B; + + // Wraps OrtApi::ExternalInitializerInfo_GetFilePath + const std::basic_string GetFilePath() const; + // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset + int64_t GetFileOffset() const; + // Wraps OrtApi::ExternalInitializerInfo_GetByteSize + size_t GetByteSize() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstExternalInitializerInfo = + detail::ConstExternalInitializerInfoImpl>; + +/** \brief Wrapper around ::OrtExternalInitializerInfo + * + */ +struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl { + using Base = detail::ConstExternalInitializerInfoImpl; + using Base::Base; + + explicit ExternalInitializerInfo(std::nullptr_t) {} + explicit ExternalInitializerInfo(OrtExternalInitializerInfo* p) + : detail::ConstExternalInitializerInfoImpl{p} {} + + ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; } + + ///< Wraps OrtApi::CreateExternalInitializerInfo + ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size); + + ///< Wrapper around CreateExternalInitializerInfo that does not throw an exception. + static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out); +}; + namespace detail { template struct KeyValuePairsImpl : Ort::detail::Base { @@ -793,6 +901,111 @@ struct KeyValuePairs : detail::KeyValuePairsImpl { ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } }; +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; + + std::string GetAllocatorName() const; ///< Wrapper MemoryInfoGetName + OrtAllocatorType GetAllocatorType() const; ///< Wrapper MemoryInfoGetType + int GetDeviceId() const; ///< Wrapper MemoryInfoGetId + OrtMemoryInfoDeviceType GetDeviceType() const; ///< Wrapper MemoryInfoGetDeviceType + OrtMemType GetMemoryType() const; ///< Wrapper MemoryInfoGetMemType + OrtDeviceMemoryType GetDeviceMemoryType() const; ///< Wrapper MemoryInfoGetDeviceMemType + uint32_t GetVendorId() const; ///< Wrapper MemoryInfoGetVendorId + + template + bool operator==(const MemoryInfoImpl& o) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; + +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type); ///< Wrapper around CreateMemoryInfo_V2 + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; + +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. +/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed. +/// +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&) noexcept; + MemoryAllocation& operator=(MemoryAllocation&&) noexcept; + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; + + void* Alloc(size_t size); + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + ConstMemoryInfo GetInfo() const; + + /** \brief Function that returns the statistics of the allocator. + * + * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. + */ + KeyValuePairs GetStats() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); +}; + +/** \brief Wrapper around ::OrtAllocator + * + */ + +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; + +using UnownedAllocator = detail::AllocatorImpl>; + +/** \brief Wrapper around ::OrtSyncStream + * + */ +struct SyncStream : detail::Base { + explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API + void* GetHandle() const; ///< Wraps SyncStream_GetHandle +}; + namespace detail { template struct HardwareDeviceImpl : Ort::detail::Base { @@ -823,6 +1036,8 @@ struct EpDeviceImpl : Ort::detail::Base { ConstKeyValuePairs EpMetadata() const; ConstKeyValuePairs EpOptions() const; ConstHardwareDevice Device() const; + ConstMemoryInfo GetMemoryInfo(OrtDeviceMemoryType memory_type) const; ///< Wraps EpDevice_MemoryInfo + SyncStream CreateSyncStream(ConstKeyValuePairs stream_options = {}) const; /// Wraps EpDevice_CreateSyncStream }; } // namespace detail @@ -842,6 +1057,16 @@ struct EpDevice : detail::EpDeviceImpl { ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); }; +/** \brief Validate a compiled model's compatibility for one or more EP devices. + * + * Throws on error. Returns the resulting compatibility status. + * /// \param ep_devices The EP devices to check compatibility against. + * /// \param compatibility_info The compatibility string from the precompiled model to validate. + */ +OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info); + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. @@ -877,10 +1102,28 @@ struct Env : detail::Base { const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + Env& RegisterAllocator(OrtAllocator* allocator); ///< Wraps OrtApi::RegisterAllocator + + Env& UnregisterAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::UnregisterAllocator + + UnownedAllocator CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options); ///< Wraps OrtApi::CreateSharedAllocator + + // Result may be nullptr + UnownedAllocator GetSharedAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::GetSharedAllocator + + void ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type); ///< Wraps OrtApi::ReleaseSharedAllocator + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary std::vector GetEpDevices() const; + + Status CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors }; /** \brief Custom Op Domain @@ -1018,8 +1261,6 @@ struct CustomOpConfigs { * Wraps ::OrtSessionOptions object and methods */ -struct SessionOptions; - namespace detail { // we separate const-only methods because passing const ptr to non-const methods // is only discovered when inline methods are compiled which is counter-intuitive @@ -1077,6 +1318,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::vector& external_initializer_file_buffer_array, const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory + SessionOptionsImpl& AppendExecutionProvider_CPU(int use_arena); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CPU SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM @@ -1159,11 +1401,23 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + ModelCompilationOptions& SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc + ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation - ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags + ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags + + ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. @@ -1264,6 +1518,10 @@ struct ConstSessionImpl : Base { std::vector GetOutputNames() const; std::vector GetOverridableInitializerNames() const; + std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs + std::vector GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs + std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1427,37 +1685,6 @@ struct Session : detail::SessionImpl { UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } }; -namespace detail { -template -struct MemoryInfoImpl : Base { - using B = Base; - using B::B; - - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemoryInfoDeviceType GetDeviceType() const; - OrtMemType GetMemoryType() const; - - template - bool operator==(const MemoryInfoImpl& o) const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstMemoryInfo = detail::MemoryInfoImpl>; - -/** \brief Wrapper around ::OrtMemoryInfo - * - */ -struct MemoryInfo : detail::MemoryInfoImpl { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); - ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } -}; - namespace detail { template struct TensorTypeAndShapeInfoImpl : Base { @@ -1686,7 +1913,7 @@ struct ConstValueImpl : Base { /// /// const pointer to data, no copies made template - const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorData /// /// /// Returns a non-typed pointer to a tensor contained data. @@ -1956,7 +2183,7 @@ struct Value : detail::ValueImpl { using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used Value(Value&&) = default; Value& operator=(Value&&) = default; @@ -2121,67 +2348,6 @@ struct Value : detail::ValueImpl { #endif // !defined(DISABLE_SPARSE_TENSORS) }; -/// -/// Represents native memory allocation coming from one of the -/// OrtAllocators registered with OnnxRuntime. -/// Use it to wrap an allocation made by an allocator -/// so it can be automatically released when no longer needed. -/// -struct MemoryAllocation { - MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); - ~MemoryAllocation(); - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - MemoryAllocation(MemoryAllocation&&) noexcept; - MemoryAllocation& operator=(MemoryAllocation&&) noexcept; - - void* get() { return p_; } - size_t size() const { return size_; } - - private: - OrtAllocator* allocator_; - void* p_; - size_t size_; -}; - -namespace detail { -template -struct AllocatorImpl : Base { - using B = Base; - using B::B; - - void* Alloc(size_t size); - MemoryAllocation GetAllocation(size_t size); - void Free(void* p); - ConstMemoryInfo GetInfo() const; - - /** \brief Function that returns the statistics of the allocator. - * - * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. - */ - KeyValuePairs GetStats() const; -}; - -} // namespace detail - -/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime - * - */ -struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { - explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - AllocatorWithDefaultOptions(); -}; - -/** \brief Wrapper around ::OrtAllocator - * - */ -struct Allocator : detail::AllocatorImpl { - explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - Allocator(const Session& session, const OrtMemoryInfo*); -}; - -using UnownedAllocator = detail::AllocatorImpl>; - namespace detail { namespace binding_utils { // Bring these out of template @@ -2244,21 +2410,58 @@ struct ArenaCfg : detail::Base { * See docs/C_API.md for details on what the following parameters mean and how to choose these values */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); + + /** + * Wraps Ort::CreateArenaCfgV2 + * See C API for details on what the following parameters mean and how to choose these values + */ + explicit ArenaCfg(const std::unordered_map& arena_config); }; // // Custom OPs (only needed to implement custom OPs) // +namespace detail { +// Need to define a templated ConstOpAttr with const members +template +struct ConstOpAttrImpl : Base { + using B = detail::Base; + using B::B; + + // Wraps OrtApi::OpAttr_GetName + std::string GetName() const; + // Wraps OrtApi::OpAttr_GetType + OrtOpAttrType GetType() const; + + // Wraps OrtApi::ReadAttr for a single value + // This does not support Tensor Attribute + // Call GetTensorAttributeAsOrtValue() instead. + template + Status GetValue(R& out) const; + + // Wraps OrtApi::ReadAttr for an array of values + template + Status GetValueArray(std::vector& out) const; + // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue + Status GetTensorAttributeAsOrtValue(Value&) const; +}; +} // namespace detail + +using ConstOpAttr = detail::ConstOpAttrImpl>; + /// /// This struct provides life time management for custom op attribute /// -struct OpAttr : detail::Base { - using Base = detail::Base; +struct OpAttr : detail::ConstOpAttrImpl { + using Base = detail::ConstOpAttrImpl; using Base::Base; + OpAttr() = default; // Enable storing it in the container for resize() explicit OpAttr(std::nullptr_t) {} OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); + + ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; } }; /** @@ -2604,7 +2807,7 @@ struct ShapeInferContext { Strings GetAttrStrings(const char* attr_name); private: - const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + ConstOpAttr GetAttrHdl(const char* attr_name) const; const OrtApi* ort_api_; OrtShapeInferContext* ctx_; std::vector input_shapes_; @@ -2755,48 +2958,114 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; +// Forward declaration to resolve circular dependency +// on ConstNode +struct ValueInfoConsumerProducerInfo; + namespace detail { template -struct ValueInfoImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstValueInfoImpl : Base { + using B = Base; using B::B; - std::string Name() const; + /// < A wrapper around OrtApi::GetValueInfoName + std::string GetName() const; + /// < A wrapper around OrtApi::GetValueInfoTypeInfo ConstTypeInfo TypeInfo() const; + ///< Wraps OrtApi::ValueInfo_GetProducerNode + ValueInfoConsumerProducerInfo GetProducerNode() const; + /// < A wrapper around OrtApi::ValueInfo_GetValueConsumers + std::vector GetConsumers() const; + /// < A wrapper around OrtApi::ValueInfo_GetInitializerValue + Status GetInitializer(ConstValue& value) const; + /// < A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo + Status GetExternalInitializerInfo(ExternalInitializerInfo& info) const; + /// < A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput + bool IsRequiredGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput + bool IsOptionalGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsGraphOutput + bool IsGraphOutput() const; + /// < A wrapper around OrtApi::ValueInfo_IsConstantInitializer + bool IsConstantInitializer() const; + /// < A wrapper around OrtApi::ValueInfo_IsFromOuterScope + bool IsFromOuterScope() const; }; } // namespace detail // Const object holder that does not own the underlying object -using ConstValueInfo = detail::ValueInfoImpl>; +using ConstValueInfo = detail::ConstValueInfoImpl>; /** \brief Wrapper around ::OrtValueInfo * */ -struct ValueInfo : detail::ValueInfoImpl { +struct ValueInfo : detail::ConstValueInfoImpl { + ValueInfo() = default; // Same thing as with nullptr explicit ValueInfo(std::nullptr_t) {} ///< No instance is created /// Take ownership of a pointer created by C API - explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl{p} {} +#if !defined(ORT_MINIMAL_BUILD) // Create ValueInfo for a tensor explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); - +#endif ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } }; +// Forward declaration +struct AttrNameSubgraph; + namespace detail { +// Forward decl template -struct NodeImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstGraphImpl; + +template +struct ConstNodeImpl : Base { + using B = Base; using B::B; + + // GetInputs() const; + // GetOutputs() const; + // GetImplicitInputs() const; + // GetAttributes() const; + // GetSubgraphs() const; + // > GetGraph() const; + // >; + /** \brief Wrapper around ::OrtNode * */ -struct Node : detail::NodeImpl { - explicit Node(std::nullptr_t) {} ///< No instance is created - explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API +struct Node : detail::ConstNodeImpl { + Node() = default; // Same thing as with nullptr + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : ConstNodeImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) Node(const std::string& operator_name, const std::string& operator_domain, @@ -2823,21 +3092,78 @@ struct Node : detail::NodeImpl { #endif // !defined(ORT_MINIMAL_BUILD) }; +// Return struct for some of ValueInfo APIs. +// Must be declared after ConstNode is available. +struct ValueInfoConsumerProducerInfo { + ConstNode node; + // either producer output or consumer output index + // producer is unsigned only, output can be -1 + int64_t index; +}; + +// Represents a return value for Graph::GetOperatorSets() +struct OperatorSet { + std::string domain; + int64_t version; +}; + namespace detail { template -struct GraphImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstGraphImpl : Base { + using B = Base; + using B::B; + + // GetModelPath() const; + // GetOperatorSets() const; + // GetInputs() const; + // GetOutputs() const; + // GetInitializers() const; + // GetNodes() const; + // & nodes) const; + // +struct GraphImpl : ConstGraphImpl { + using B = ConstGraphImpl; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // & inputs); + // & outputs); + // >; + +// Return value for Node API +// Must be declared after ConstGraph +struct AttrNameSubgraph { + std::string attr_name; + ConstGraph sub_graph; +}; + /** \brief Wrapper around ::OrtGraph * */ @@ -2845,24 +3171,26 @@ struct Graph : detail::GraphImpl { explicit Graph(std::nullptr_t) {} ///< No instance is created explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + // -struct ModelImpl : Ort::detail::Base { +struct ModelImpl : detail::Base { using B = Ort::detail::Base; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // >; +using UnownedModel = detail::ModelImpl>; /** \brief Wrapper around ::OrtModel * @@ -2874,10 +3202,9 @@ struct Model : detail::ModelImpl { explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + //< Wraps GetModelEditorApi().CreateModel() explicit Model(const std::vector& opsets); #endif - - ConstModel GetConst() const { return ConstModel{this->p_}; } }; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 705f17c5d6f43..59979189eed0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -56,15 +56,15 @@ inline void ThrowOnError(const Status& st) { inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { } -inline Status::Status(const std::exception& e) noexcept { +inline Status::Status(const std::exception& e) { p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); } -inline Status::Status(const Exception& e) noexcept { +inline Status::Status(const Exception& e) { p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); } -inline Status::Status(const char* message, OrtErrorCode code) noexcept { +inline Status::Status(const char* message, OrtErrorCode code) { p_ = GetApi().CreateStatus(code, message); } @@ -296,6 +296,16 @@ inline OrtMemType MemoryInfoImpl::GetMemoryType() const { return type; } +template +inline OrtDeviceMemoryType MemoryInfoImpl::GetDeviceMemoryType() const { + return GetApi().MemoryInfoGetDeviceMemType(this->p_); +} + +template +inline uint32_t MemoryInfoImpl::GetVendorId() const { + return GetApi().MemoryInfoGetVendorId(this->p_); +} + template template inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { @@ -316,6 +326,12 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); } +inline MemoryInfo::MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type) { + ThrowOnError(GetApi().CreateMemoryInfo_V2(name, device_type, vendor_id, device_id, mem_type, alignment, + allocator_type, &this->p_)); +} + namespace detail { template inline std::vector ConstIoBindingImpl::GetOutputNames() const { @@ -404,20 +420,7 @@ inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { std::vector result; - size_t owned = 0; size_t output_count = 0; - // Lambda to release the buffer when no longer needed and - // make sure that we destroy all instances on exception - auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { - if (buffer) { - while (owned < output_count) { - auto* p = buffer + owned++; - GetApi().ReleaseValue(*p); - } - allocator->Free(allocator, buffer); - } - }; - using Ptr = std::unique_ptr; OrtValue** output_buffer = nullptr; ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); @@ -425,12 +428,11 @@ inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, Ort return result; } - Ptr buffer_g(output_buffer, free_fn); + std::unique_ptr buffer_g(output_buffer, AllocatedFree(allocator)); result.reserve(output_count); for (size_t i = 0; i < output_count; ++i) { result.emplace_back(output_buffer[i]); - ++owned; } return result; } @@ -446,6 +448,18 @@ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); } +inline ArenaCfg::ArenaCfg(const std::unordered_map& arena_config) { + std::vector keys; + std::vector values; + keys.reserve(arena_config.size()); + values.reserve(arena_config.size()); + for (const auto& kv : arena_config) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second); + } + ThrowOnError(GetApi().CreateArenaCfgV2(keys.data(), values.data(), arena_config.size(), &p_)); +} + inline ThreadingOptions::ThreadingOptions() { ThrowOnError(GetApi().CreateThreadingOptions(&p_)); } @@ -485,6 +499,114 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom return *this; } +inline TensorRTProviderOptions::TensorRTProviderOptions() { + ThrowOnError(GetApi().CreateTensorRTProviderOptions(&this->p_)); +} + +inline void TensorRTProviderOptions::Update(const std::unordered_map& options) { + std::vector keys; + std::vector values; + keys.reserve(options.size()); + values.reserve(options.size()); + for (const auto& kv : options) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second.c_str()); + } + ThrowOnError(GetApi().UpdateTensorRTProviderOptions(p_, keys.data(), values.data(), options.size())); +} + +inline void TensorRTProviderOptions::UpdateWithValue(const char* key, void* value) { + ThrowOnError(GetApi().UpdateTensorRTProviderOptionsWithValue(p_, key, value)); +} + +inline void* TensorRTProviderOptions::GetOptionByName(const char* name) const { + void* value = nullptr; + ThrowOnError(GetApi().GetTensorRTProviderOptionsByName(p_, name, &value)); + return value; +} + +inline std::string TensorRTProviderOptions::GetTensorRTProviderOptionsAsString() const { + AllocatorWithDefaultOptions allocator; + char* options_str = nullptr; + ThrowOnError(GetApi().GetTensorRTProviderOptionsAsString(p_, allocator, &options_str)); + std::unique_ptr options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + +inline CUDAProviderOptions::CUDAProviderOptions() { + ThrowOnError(GetApi().CreateCUDAProviderOptions(&this->p_)); +} + +inline void CUDAProviderOptions::Update(const std::unordered_map& options) { + std::vector keys; + std::vector values; + keys.reserve(options.size()); + values.reserve(options.size()); + for (const auto& kv : options) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second.c_str()); + } + ThrowOnError(GetApi().UpdateCUDAProviderOptions(p_, keys.data(), values.data(), options.size())); +} + +inline std::string CUDAProviderOptions::GetCUDAProviderOptionsAsString() const { + AllocatorWithDefaultOptions allocator; + char* options_str = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsAsString(p_, allocator, &options_str)); + std::unique_ptr options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + +inline void CUDAProviderOptions::UpdateWithValue(const char* key, void* value) { + ThrowOnError(GetApi().UpdateCUDAProviderOptionsWithValue(p_, key, value)); +} + +inline void* CUDAProviderOptions::GetOptionByName(const char* name) const { + void* value = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsByName(p_, name, &value)); + return value; +} + +inline PrepackedWeightsContainer::PrepackedWeightsContainer() { + ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_)); +} + +namespace detail { + +template +inline const std::basic_string ConstExternalInitializerInfoImpl::GetFilePath() const { + return GetApi().ExternalInitializerInfo_GetFilePath(this->p_); +} + +template +inline int64_t ConstExternalInitializerInfoImpl::GetFileOffset() const { + return GetApi().ExternalInitializerInfo_GetFileOffset(this->p_); +} + +template +inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const { + return GetApi().ExternalInitializerInfo_GetByteSize(this->p_); +} +} // namespace detail + +inline ExternalInitializerInfo::ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, + size_t byte_size) { + ThrowOnError(GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &this->p_)); +} + +inline Status ExternalInitializerInfo::Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out) { + OrtExternalInitializerInfo* info = nullptr; + OrtStatus* status = GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &info); + if (status != nullptr) { + return Status{status}; + } + + out = ExternalInitializerInfo(info); + + return Status{nullptr}; +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -547,6 +669,10 @@ inline void KeyValuePairs::Remove(const char* key) { GetApi().RemoveKeyValuePair(this->p_, key); } +inline void* SyncStream::GetHandle() const { + return GetApi().SyncStream_GetHandle(this->p_); +} + namespace detail { template inline OrtHardwareDeviceType HardwareDeviceImpl::Type() const { @@ -597,6 +723,19 @@ template inline ConstHardwareDevice EpDeviceImpl::Device() const { return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_)); } + +template +inline ConstMemoryInfo EpDeviceImpl::GetMemoryInfo(OrtDeviceMemoryType memory_type) const { + const auto* mem_info = GetApi().EpDevice_MemoryInfo(this->p_, memory_type); + return ConstMemoryInfo{mem_info}; +} + +template +inline SyncStream EpDeviceImpl::CreateSyncStream(ConstKeyValuePairs stream_options) const { + OrtSyncStream* stream = nullptr; + ThrowOnError(GetApi().CreateSyncStreamForEpDevice(this->p_, stream_options, &stream)); + return SyncStream{stream}; +} } // namespace detail inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, @@ -676,6 +815,16 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, return *this; } +inline Env& Env::RegisterAllocator(OrtAllocator* allocator) { + ThrowOnError(GetApi().RegisterAllocator(p_, allocator)); + return *this; +} + +inline Env& Env::UnregisterAllocator(const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().UnregisterAllocator(p_, mem_info)); + return *this; +} + inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path) { ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str())); @@ -703,6 +852,41 @@ inline std::vector Env::GetEpDevices() const { return devices; } +inline Status Env::CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const { + if (src_tensors.size() != dst_tensors.size()) { + return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT); + } + if (src_tensors.empty()) { + return Status(nullptr); + } + + const OrtValue* const* src_tensors_ptr = reinterpret_cast(src_tensors.data()); + OrtValue* const* dst_tensors_ptr = reinterpret_cast(dst_tensors.data()); + OrtStatus* status = GetApi().CopyTensors(p_, src_tensors_ptr, dst_tensors_ptr, stream, src_tensors.size()); + return Status(status); +} + +inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options) { + OrtAllocator* p; + ThrowOnError(GetApi().CreateSharedAllocator(p_, ep_device, mem_type, allocator_type, allocator_options, &p)); + return UnownedAllocator{p}; +} + +inline UnownedAllocator Env::GetSharedAllocator(const OrtMemoryInfo* mem_info) { + OrtAllocator* p; + ThrowOnError(GetApi().GetSharedAllocator(p_, mem_info, &p)); + return UnownedAllocator{p}; +} + +inline void Env::ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type) { + ThrowOnError(GetApi().ReleaseSharedAllocator(p_, ep_device, mem_type)); +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } @@ -711,6 +895,26 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) { ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); } +inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info) { + if (ep_devices.empty()) { + ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT); + } + + std::vector ptrs; + ptrs.reserve(ep_devices.size()); + for (const auto& d : ep_devices) ptrs.push_back(d); + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + ThrowOnError(GetApi().GetModelCompatibilityForEpDevices( + reinterpret_cast(ptrs.data()), + ptrs.size(), + compatibility_info, + &status)); + return status; +} + inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, OrtAllocator* allocator) { OrtLoraAdapter* p; @@ -835,6 +1039,16 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI return *this; } +inline ModelCompilationOptions& +ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + this->p_, + get_initializer_location_func, + state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator, @@ -843,6 +1057,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, + void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( @@ -851,11 +1071,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( return *this; } -inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) { +inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags)); return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel( + GraphOptimizationLevel graph_optimization_level) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_, + graph_optimization_level)); + return *this; +} + namespace detail { template @@ -1056,6 +1283,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFrom return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CPU(int use_arena) { + ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); @@ -1298,9 +1531,9 @@ inline std::vector ConstSessionImpl::GetInputNames() const { input_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); - input_names.push_back(name); + input_names.emplace_back(name); allocator.Free(name); } @@ -1316,9 +1549,9 @@ inline std::vector ConstSessionImpl::GetOutputNames() const { output_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); - output_names.push_back(name); + output_names.emplace_back(name); allocator.Free(name); } @@ -1334,14 +1567,45 @@ inline std::vector ConstSessionImpl::GetOverridableInitializerNa initializer_names.reserve(num_initializers); for (size_t i = 0; i < num_initializers; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); - initializer_names.push_back(name); + initializer_names.emplace_back(name); } return initializer_names; } +template +inline std::vector ConstSessionImpl::GetMemoryInfoForInputs() const { + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); + + auto num_inputs = GetInputCount(); + std::vector mem_infos; + mem_infos.resize(num_inputs); + + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + + return mem_infos; +} + +template +inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs() const { + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); + + auto num_outputs = GetOutputCount(); + std::vector mem_infos; + mem_infos.resize(num_outputs); + + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + return mem_infos; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1363,6 +1627,19 @@ inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllo return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } +template +inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { + auto num_inputs = GetInputCount(); + std::vector input_devices; + input_devices.resize(num_inputs); + + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + + return input_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; @@ -1541,7 +1818,7 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat #if !defined(ORT_MINIMAL_BUILD) inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) { - ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model, options, &this->p_)); } // static @@ -1857,15 +2134,15 @@ inline size_t ConstValueImpl::GetTensorSizeInBytes() const { template template inline const R* ConstValueImpl::GetTensorData() const { - R* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + const R* out; + ThrowOnError(GetApi().GetTensorData(this->p_, reinterpret_cast(&out))); return out; } template inline const void* ConstValueImpl::GetTensorRawData() const { - void* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + const void* out; + ThrowOnError(GetApi().GetTensorData(this->p_, &out)); return out; } @@ -2257,6 +2534,172 @@ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); } +namespace detail { + +template +constexpr OrtOpAttrType TypeToAttrType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_INT; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOAT; +} + +template +inline constexpr OrtOpAttrType TypeToAttrsType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_INTS; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOATS; +} + +inline Status CheckAttrType(const OrtOpAttr* attr, OrtOpAttrType requested_type) { + OrtOpAttrType type; + Ort::Status status(GetApi().OpAttr_GetType(attr, &type)); + if (!status.IsOK()) return status; + if (requested_type != type) { + std::string msg = "Attribute type mismatch: expected " + std::to_string(requested_type) + + ", but got " + std::to_string(type); + return Ort::Status(msg.c_str(), OrtErrorCode::ORT_INVALID_ARGUMENT); + } + return Ort::Status{}; +} + +inline size_t GetDataSize(const OrtOpAttr* attr, OrtOpAttrType attr_type) { + size_t result{}; + // Ignore the status here because we check the data type so the error should only be about + // the size + [[maybe_unused]] Status status{GetApi().ReadOpAttr(attr, attr_type, nullptr, 0, &result)}; + return result; +} + +template +Ort::Status GetNumericValue(const OrtOpAttr* attr, T& out) { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + size_t size{}; + return Ort::Status{GetApi().ReadOpAttr(attr, TypeToAttrType(), &out, sizeof(out), &size)}; +} + +template +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, T& out) { + return GetNumericValue(attr, out); + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + // Api deficiency when it comes to value arrays. It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + constexpr auto deduced_type = TypeToAttrsType(); + auto status = CheckAttrType(attr, deduced_type); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, deduced_type); + std::vector result; + if (size > 0) { + result.resize(size / sizeof(T)); + status = Status{GetApi().ReadOpAttr( + attr, deduced_type, result.data(), size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } +}; + +// Create GetValueImpl specializations for std::string +template <> +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, std::string& out) { + // Api deficiency when it comes to value arrays. It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + std::string result; + if (size > 0) { + result.resize(size); + // some compilers in use do not support std::string::data() non-const + auto* buffer = &result[0]; + status = Status{GetApi().ReadOpAttr( + attr, OrtOpAttrType::ORT_OP_ATTR_STRING, buffer, size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + if (!status.IsOK()) return status; + + std::vector result; + size_t total_buffer_size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + if (total_buffer_size > 0) { + // Create a temporary buffer to hold the string data + std::vector buffer(total_buffer_size); + status = Status{GetApi().ReadOpAttr(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS, buffer.data(), + total_buffer_size, &total_buffer_size)}; + if (!status.IsOK()) return status; + + const char* data = buffer.data(); + const char* end = data + total_buffer_size; + while (data < end) { + result.emplace_back(data); + data += result.back().size() + 1; // Move past the null terminator + } + } + out.swap(result); + return status; + } +}; + +template +template +inline Status ConstOpAttrImpl::GetValue(R& out) const { + return GetValueImpl::GetValue(this->p_, out); +} + +template +template +inline Status ConstOpAttrImpl::GetValueArray(std::vector& out) const { + return GetValueImpl::GetValues(this->p_, out); +} + +template +inline Status ConstOpAttrImpl::GetTensorAttributeAsOrtValue(Value& out) const { + OrtValue* tensor_value = nullptr; + auto status = Status(GetApi().OpAttr_GetTensorAttributeAsOrtValue(this->p_, &tensor_value)); + if (!status.IsOK()) return status; + out = Value{tensor_value}; + return status; +} + +template +inline std::string ConstOpAttrImpl::GetName() const { + const char* name = nullptr; + ThrowOnError(GetApi().OpAttr_GetName(this->p_, &name)); + if (name != nullptr) { + return name; + } + return {}; +} + +template +inline OrtOpAttrType ConstOpAttrImpl::GetType() const { + OrtOpAttrType type; + ThrowOnError(GetApi().OpAttr_GetType(this->p_, &type)); + return type; +} +} // namespace detail + inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } @@ -2557,115 +3000,69 @@ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shap } inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); - return i; + auto attr = GetAttrHdl(attr_name); + int64_t value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting int attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } + return value; } inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); - if (status) { - size_t num_i = out / sizeof(int64_t); - ShapeInferContext::Ints ints(num_i, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); - return ints; - } else { - if (out == 0u) { - return {}; - } - return {i}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Ints result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting ints attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); - return f; + auto attr = GetAttrHdl(attr_name); + float value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting float attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } + return value; } inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); - if (status) { - size_t num_f = out / sizeof(float); - ShapeInferContext::Floats floats(num_f, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); - return floats; - } else { - if (out == 0u) { - return {}; - } - return {f}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Floats result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting floats attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); - return std::string{chars.data(), out}; - } else { - return {c}; + auto attr = GetAttrHdl(attr_name); + std::string value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting string attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return value; } inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); - ShapeInferContext::Strings strings; - char* char_st = chars.data(); - char* char_ed = char_st + out; - while (char_st < char_ed) { - strings.emplace_back(char_st); - while (*char_st != '\0') { - char_st++; - } - char_st++; - } - return strings; - } else { - if (out == 0u) { - return {}; - } - return {std::string{c}}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Strings result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting strings attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } -inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { +inline ConstOpAttr ShapeInferContext::GetAttrHdl(const char* attr_name) const { const OrtOpAttr* attr_hdl = {}; Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); - return attr_hdl; + return ConstOpAttr{attr_hdl}; } namespace detail { @@ -2679,6 +3076,136 @@ inline std::vector StringsToCharPtrs(const std::vector } } // namespace detail +namespace detail { +template +inline size_t ConstNodeImpl::GetId() const { + size_t id; + ThrowOnError(GetApi().Node_GetId(this->p_, &id)); + return id; +} + +template +inline std::string ConstNodeImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Node_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::string ConstNodeImpl::GetOperatorType() const { + const char* type; + ThrowOnError(GetApi().Node_GetOperatorType(this->p_, &type)); + return std::string(type); +} + +template +inline std::string ConstNodeImpl::GetDomain() const { + const char* domain; + ThrowOnError(GetApi().Node_GetDomain(this->p_, &domain)); + return std::string(domain); +} + +template +inline int ConstNodeImpl::GetSinceVersion() const { + int since_version; + ThrowOnError(GetApi().Node_GetSinceVersion(this->p_, &since_version)); + return since_version; +} + +template +inline std::vector ConstNodeImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetImplicitInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumImplicitInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetImplicitInputs(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetAttributes() const { + static_assert(sizeof(const OrtOpAttr*) == sizeof(ConstOpAttr), "Must be the same size"); + size_t num_attrs; + ThrowOnError(GetApi().Node_GetNumAttributes(this->p_, &num_attrs)); + std::vector attrs; + if (num_attrs > 0) { + attrs.resize(num_attrs); + ThrowOnError(GetApi().Node_GetAttributes(this->p_, reinterpret_cast(attrs.data()), num_attrs)); + } + return attrs; +} + +template +inline Status ConstNodeImpl::GetAttributeByName(const std::string& name, ConstOpAttr& out) const { + const OrtOpAttr* attr = nullptr; + auto status = Status(GetApi().Node_GetAttributeByName(this->p_, name.c_str(), &attr)); + out = ConstOpAttr{attr}; + return status; +} + +template +inline std::vector ConstNodeImpl::GetSubgraphs() const { + size_t num_graphs; + ThrowOnError(GetApi().Node_GetNumSubgraphs(this->p_, &num_graphs)); + std::vector result; + if (num_graphs > 0) { + std::vector sub_graphs(num_graphs); + std::vector attr_names(num_graphs); + ThrowOnError(GetApi().Node_GetSubgraphs(this->p_, sub_graphs.data(), num_graphs, attr_names.data())); + result.reserve(num_graphs); + for (size_t i = 0; i < num_graphs; ++i) { + result.push_back({std::string(attr_names[i]), ConstGraph{sub_graphs[i]}}); + } + } + return result; +} + +template +inline ConstGraph ConstNodeImpl::GetGraph() const { + const OrtGraph* graph; + ThrowOnError(GetApi().Node_GetGraph(this->p_, &graph)); + return ConstGraph{graph}; +} + +template +inline std::string ConstNodeImpl::GetEpName() const { + const char* name; + ThrowOnError(GetApi().Node_GetEpName(this->p_, &name)); + return std::string(name); +} + +} // namespace detail + #if !defined(ORT_MINIMAL_BUILD) // static inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, @@ -2720,90 +3247,294 @@ inline Node::Node(const std::string& operator_name, const std::string& operator_ std::vector empty_attributes; Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); } +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) -inline Graph::Graph() { - ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +namespace detail { +template +inline std::string ConstValueInfoImpl::GetName() const { + const char* p = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &p)); + return std::string(p); } -inline Model::Model(const std::vector& opsets) { - std::vector domains; - std::vector versions; - domains.reserve(opsets.size()); - versions.reserve(opsets.size()); +template +inline ConstTypeInfo ConstValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} - for (const auto& pair : opsets) { - domains.push_back(pair.first.c_str()); - versions.push_back(pair.second); +template +inline ValueInfoConsumerProducerInfo ConstValueInfoImpl::GetProducerNode() const { + ValueInfoConsumerProducerInfo info; + const OrtNode* producer; + size_t index; + ThrowOnError(GetApi().ValueInfo_GetValueProducer(this->p_, &producer, &index)); + info.node = ConstNode(producer); + info.index = static_cast(index); + return info; +} + +template +inline std::vector ConstValueInfoImpl::GetConsumers() const { + size_t num = 0; + ThrowOnError(GetApi().ValueInfo_GetValueNumConsumers(this->p_, &num)); + std::vector out; + if (num > 0) { + std::vector nodes(num); + std::vector indices(num); + ThrowOnError(GetApi().ValueInfo_GetValueConsumers(this->p_, nodes.data(), indices.data(), num)); + out.reserve(num); + for (size_t i = 0; i < num; ++i) { + out.push_back({ConstNode{nodes[i]}, indices[i]}); + } } + return out; +} - ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +template +inline Status ConstValueInfoImpl::GetInitializer(ConstValue& value) const { + const OrtValue* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetInitializerValue(this->p_, &out)); + if (!status.IsOK()) return status; + value = ConstValue{out}; + return status; } -inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { - ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +template +inline Status ConstValueInfoImpl::GetExternalInitializerInfo(ExternalInitializerInfo& info) const { + OrtExternalInitializerInfo* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetExternalInitializerInfo(this->p_, &out)); + if (!status.IsOK()) return status; + info = ExternalInitializerInfo{out}; + return status; } -#endif // !defined(ORT_MINIMAL_BUILD) -namespace detail { -template <> -inline std::string ValueInfoImpl::Name() const { - const char* name = nullptr; - ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); - return name; +template +inline bool ConstValueInfoImpl::IsRequiredGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsRequiredGraphInput(this->p_, &out)); + return out; } -template <> -inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { - const OrtTypeInfo* type_info = nullptr; - ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); - return ConstTypeInfo{type_info}; +template +inline bool ConstValueInfoImpl::IsOptionalGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsOptionalGraphInput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsGraphOutput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsGraphOutput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsConstantInitializer() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsConstantInitializer(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsFromOuterScope() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsFromOuterScope(this->p_, &out)); + return out; +} + +template +inline ModelMetadata ConstGraphImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + +template +inline std::string ConstGraphImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Graph_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::basic_string ConstGraphImpl::GetModelPath() const { + const ORTCHAR_T* path; + ThrowOnError(GetApi().Graph_GetModelPath(this->p_, &path)); + return std::basic_string(path); +} + +template +inline int64_t ConstGraphImpl::GetOnnxIRVersion() const { + int64_t version; + ThrowOnError(GetApi().Graph_GetOnnxIRVersion(this->p_, &version)); + return version; +} + +template +inline std::vector ConstGraphImpl::GetOperatorSets() const { + size_t num_opsets; + ThrowOnError(GetApi().Graph_GetNumOperatorSets(this->p_, &num_opsets)); + std::vector result; + if (num_opsets > 0) { + std::vector domains; + std::vector versions; + domains.resize(num_opsets); + versions.resize(num_opsets); + ThrowOnError(GetApi().Graph_GetOperatorSets(this->p_, domains.data(), versions.data(), num_opsets)); + result.reserve(num_opsets); + for (size_t i = 0; i < num_opsets; ++i) { + result.push_back({domains[i], versions[i]}); + } + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInitializers() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInitializers(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInitializers(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetNodes() const { + static_assert(sizeof(const OrtNode*) == sizeof(ConstNode)); + size_t num_nodes; + ThrowOnError(GetApi().Graph_GetNumNodes(this->p_, &num_nodes)); + std::vector result; + if (num_nodes > 0) { + result.resize(num_nodes); + ThrowOnError(GetApi().Graph_GetNodes(this->p_, reinterpret_cast(result.data()), num_nodes)); + } + return result; +} + +template +inline ConstNode ConstGraphImpl::GetParentNode() const { + const OrtNode* parent; + ThrowOnError(GetApi().Graph_GetParentNode(this->p_, &parent)); + return ConstNode{parent}; +} + +template +inline Graph ConstGraphImpl::GetGraphView(const std::vector& nodes) const { + OrtGraph* graph_viewer; + std::vector inputs_ptrs; + inputs_ptrs.reserve(nodes.size()); + std::transform(nodes.begin(), nodes.end(), std::back_inserter(inputs_ptrs), + [](ConstNode n) -> const OrtNode* { return n; }); + ThrowOnError(GetApi().Graph_GetGraphView(this->p_, inputs_ptrs.data(), + nodes.size(), &graph_viewer)); + return Graph{graph_viewer}; } #if !defined(ORT_MINIMAL_BUILD) -template <> -inline void GraphImpl::SetInputs(std::vector& inputs) { +template +inline void GraphImpl::SetInputs(std::vector& inputs) { std::vector inputs_ptrs; inputs_ptrs.reserve(inputs.size()); std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphInputs(this->p_, inputs_ptrs.data(), inputs_ptrs.size())); // Graph now owns the inputs std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::SetOutputs(std::vector& outputs) { +template +inline void GraphImpl::SetOutputs(std::vector& outputs) { std::vector outputs_ptrs; outputs_ptrs.reserve(outputs.size()); std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphOutputs(this->p_, outputs_ptrs.data(), outputs_ptrs.size())); // Graph now owns the outputs std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { +template +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { // Graph takes ownership of `initializer` - ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); + // On error the ownership is not transferred. + ThrowOnError(GetModelEditorApi().AddInitializerToGraph(this->p_, name.c_str(), initializer, data_is_external)); + initializer.release(); } -template <> -inline void GraphImpl::AddNode(Node& node) { +template +inline void GraphImpl::AddNode(Node& node) { // Graph takes ownership of `node` - ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); + ThrowOnError(GetModelEditorApi().AddNodeToGraph(this->p_, node.release())); } -template <> -inline void ModelImpl::AddGraph(Graph& graph) { +template +inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` - ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); + ThrowOnError(GetModelEditorApi().AddGraphToModel(this->p_, graph.release())); } #endif // !defined(ORT_MINIMAL_BUILD) } // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} +#endif + } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 620cb5fcf13cc..975f6b453a88d 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -482,18 +482,6 @@ typedef enum OrtEpDataLayout { OrtEpDataLayout_Default = OrtEpDataLayout_NCHW, } OrtEpDataLayout; -/** - * \brief Enumeration describing the compatibility state of a compiled model relative to an execution provider. - * - * \since Version 1.23. - */ -typedef enum OrtCompiledModelCompatibility { - OrtCompiledModelCompatibility_EP_NOT_APPLICABLE = 0, - OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL, - OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION, - OrtCompiledModelCompatibility_EP_UNSUPPORTED, -} OrtCompiledModelCompatibility; - /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -901,20 +889,28 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Validate the compatibility of a compiled model with the execution provider. + /** \brief Validate the compatibility of a compiled model with the execution provider factory for one or more devices. + * + * Given a compatibility info string produced during model compilation, the EP factory should determine whether the + * compiled model is compatible with the EP factory when targeting the provided hardware devices. All devices provided + * must belong to the same execution provider instance that this factory creates. * - * This function validates if a model produced with the supplied compatibility info string is supported by the underlying EP. - * The EP should check if a compiled model is compatible with the EP and set the model_compatibility parameter accordingly. + * The EP factory implementation should consider the set of devices (e.g., multi-adapter or multi-GPU scenarios) when + * evaluating compatibility and set `model_compatibility` accordingly. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] compatibility_info The compatibility information string that will be used - * \param[out] model_compatibility OrtCompiledModelCompatibility enum value describing the compatibility of the model with the EP. + * \param[in] devices Array of OrtHardwareDevice pointers that the EP would run on. All must map to this EP. + * \param[in] num_devices Number of entries in `devices`. + * \param[in] compatibility_info The compatibility information string produced when the model was compiled. + * \param[out] model_compatibility OrtCompiledModelCompatibility value describing the compatibility of the model with the EP. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(ValidateCompiledModelCompatibilityInfo, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* model_compatibility); diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index f0992f05f31e5..bbd6a43bb7a41 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -8,3 +8,11 @@ // Key for the execution provider version string. This should be available for all plugin EPs. static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; + +// Prefix for execution provider compatibility information stored in model metadata. +// Used when generating EP context models to store compatibility strings for each EP. +// Full key format: "ep_compatibility_info." +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; + +// Key for the execution provider library path (for dynamically loaded EPs) +static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..7eb5f7659a365 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -382,8 +382,8 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME // Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. -// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] -// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. +// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default] +// "Efficient": OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; // Disables model compilation during session initialization. @@ -401,3 +401,10 @@ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload // - "0": EP compile is not disabled. [DEFAULT] // - "1": EP compile is disabled. static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile"; + +// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION. +// "0": Allow execution with suboptimal performance. [DEFAULT] +// "1": Fail session creation to require recompilation for optimal performance. +// Note: UNSUPPORTED models always fail regardless of this setting. +static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = + "session.fail_on_suboptimal_compiled_model"; diff --git a/java/build.gradle b/java/build.gradle index 2d43d1ead13f0..64a31c89ad322 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -3,8 +3,7 @@ plugins { id 'maven-publish' id 'signing' id 'jacoco' - id "com.diffplug.spotless" version "6.25.0" - id "net.linguica.maven-settings" version "0.5" + id "com.diffplug.spotless" version "7.2.1" } allprojects { @@ -14,17 +13,9 @@ allprojects { } project.group = "com.microsoft.onnxruntime" -version = rootProject.file('../VERSION_NUMBER').text.trim() - // cmake runs will inform us of the build directory of the current run def cmakeBuildDir = System.properties['cmakeBuildDir'] def useCUDA = System.properties['USE_CUDA'] -def useROCM = System.properties['USE_ROCM'] - -def adoArtifact = project.findProperty('adoArtifact') -def adoAccessToken = project.findProperty('adoAccessToken') -// Only publish to ADO feed if all two properties are set -def publishToAdo = adoArtifact != null && adoAccessToken != null boolean enableTrainingApis = (System.properties['ENABLE_TRAINING_APIS'] ?: "0") == "1" def cmakeJavaDir = "${cmakeBuildDir}/java" @@ -33,21 +24,14 @@ def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni" def cmakeNativeTestDir = "${cmakeJavaDir}/native-test" def cmakeBuildOutputDir = "${cmakeJavaDir}/build" -def mavenUser = System.properties['mavenUser'] -def mavenPwd = System.properties['mavenPwd'] - def tmpArtifactId = enableTrainingApis ? project.name + "-training" : project.name -def mavenArtifactId = (useCUDA == null && useROCM == null) ? tmpArtifactId : tmpArtifactId + "_gpu" +def mavenArtifactId = (useCUDA == null) ? tmpArtifactId : tmpArtifactId + "_gpu" def defaultDescription = 'ONNX Runtime is a performance-focused inference engine for ONNX (Open Neural Network Exchange) models.' def trainingDescription = 'ONNX Runtime Training is a training and inference package for ONNX ' + '(Open Neural Network Exchange) models. This package is targeted for Learning on The Edge aka On-Device Training ' + 'See https://github.com/microsoft/onnxruntime-training-examples/tree/master/on_device_training for more details.' -// We need to have a custom settings.xml so codeql can bypass the need for settings.security.xml -mavenSettings { - userSettingsFileName = "${projectDir}/settings.xml" -} java { sourceCompatibility = JavaVersion.VERSION_17 @@ -202,16 +186,27 @@ test { systemProperties System.getProperties().subMap([ 'ENABLE_TRAINING_APIS', 'JAVA_FULL_TEST', + 'USE_ACL', + 'USE_ARMNN', + 'USE_AZURE', + 'USE_CANN', 'USE_COREML', 'USE_CUDA', 'USE_DML', 'USE_DNNL', + 'USE_MIGRAPHX', + 'USE_NNAPI', + 'USE_NV', 'USE_OPENVINO', - 'USE_ROCM', - 'USE_TENSORRT', 'USE_QNN', - 'USE_XNNPACK', + 'USE_RKNPU', + 'USE_SNPE', + 'USE_TENSORRT', + 'USE_VITISAI', + 'USE_VSINPU', 'USE_WEBGPU', + 'USE_WEBNN', + 'USE_XNNPACK', ]) testLogging { events "passed", "skipped", "failed" @@ -233,13 +228,9 @@ publishing { publications { maven(MavenPublication) { groupId = project.group - if(publishToAdo) { - artifactId = 'onnxruntime_gpu' - artifact (adoArtifact) - } else { - artifactId = mavenArtifactId - from components.java - } + artifactId = mavenArtifactId + from components.java + version = project.version pom { name = enableTrainingApis ? 'onnxruntime-training' : 'onnx-runtime' @@ -270,29 +261,6 @@ publishing { } } } - repositories { - if (publishToAdo) { - maven { - url "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/${System.getenv('ADOFeedName')}/maven/v1" - name System.getenv('ADOFeedName') - authentication { - basic(BasicAuthentication) - } - credentials { - username 'aiinfra' - password "${project.findProperty('adoAccessToken')}" - } - } - } else { - maven { - url 'https://oss.sonatype.org/service/local/staging/deploy/maven2/' - credentials { - username mavenUser - password mavenPwd - } - } - } - } } // Generates a task signMavenPublication that will // build all artifacts. @@ -300,12 +268,17 @@ signing { // Queries env vars: // ORG_GRADLE_PROJECT_signingKey // ORG_GRADLE_PROJECT_signingPassword but can be changed to properties - def signingKey = findProperty("signingKey") - def signingPassword = findProperty("signingPassword") - // Skip signing if no key is provided - if (signingKey != null && signingPassword != null) { - useInMemoryPgpKeys(signingKey, signingPassword) - sign publishing.publications.maven - sign publishing.publications.mavenAdo - } + def signingKey = findProperty("signingKey") + def signingPassword = findProperty("signingPassword") + // Skip signing if no key is provided + if (signingKey != null && signingPassword != null) { + useInMemoryPgpKeys(signingKey, signingPassword) + sign publishing.publications.maven + } +} + +tasks.named('generatePomFileForMavenPublication') { + doFirst { + println "AGENT_LOG: Generating POM for version: ${project.version}" + } } diff --git a/java/src/main/java/ai/onnxruntime/OrtException.java b/java/src/main/java/ai/onnxruntime/OrtException.java index 5ec58ea137124..06c3d3cbc770c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtException.java +++ b/java/src/main/java/ai/onnxruntime/OrtException.java @@ -81,11 +81,17 @@ public enum OrtErrorCode { /** The ONNX graph is invalid. */ ORT_INVALID_GRAPH(10), /** The ORT execution provider failed. */ - ORT_EP_FAIL(11); + ORT_EP_FAIL(11), + /** Model load was canceled. */ + ORT_MODEL_LOAD_CANCELED(12), + /** Model requires compilation. */ + ORT_MODEL_REQUIRES_COMPILATION(13), + /** Item was not found. */ + ORT_NOT_FOUND(14); private final int value; - private static final OrtErrorCode[] values = new OrtErrorCode[12]; + private static final OrtErrorCode[] values = new OrtErrorCode[15]; static { for (OrtErrorCode ot : OrtErrorCode.values()) { diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) { return 10; case ORT_EP_FAIL: return 11; + case ORT_MODEL_LOAD_CANCELED: + return 12; + case ORT_MODEL_REQUIRES_COMPILATION: + return 13; + case ORT_NOT_FOUND: + return 14; default: return -1; // Unknown error code } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index c3f9d345078fe..c202b2a9f80e0 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -693,12 +693,6 @@ public void testCUDA() throws OrtException { runProvider(OrtProvider.CUDA); } - @Test - @EnabledIfSystemProperty(named = "USE_ROCM", matches = "1") - public void testROCM() throws OrtException { - runProvider(OrtProvider.ROCM); - } - @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") public void testTensorRT() throws OrtException { @@ -725,6 +719,18 @@ public void testDNNL() throws OrtException { runProvider(OrtProvider.DNNL); } + @Test + @EnabledIfSystemProperty(named = "USE_MIGRAPHX", matches = "1") + public void testMIGRAPHX() throws OrtException { + runProvider(OrtProvider.MI_GRAPH_X); + } + + @Test + @EnabledIfSystemProperty(named = "USE_NNAPI", matches = "1") + public void testNNAPI() throws OrtException { + runProvider(OrtProvider.NNAPI); + } + @Test @EnabledIfSystemProperty(named = "USE_XNNPACK", matches = "1") public void testXNNPACK() throws OrtException { diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 7de9cfa14927d..550502cf3bc48 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -33,6 +33,7 @@ OrtCompileApiFlags, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 + OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0d5117709c18a..bfa450f4287f8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -280,6 +280,18 @@ class GQAAttentionBase { output, static_cast(present_buffer_sequence_length), nullptr); } + // Pre-allocate buffer for attention mask to avoid allocating it for every processed token + float* attention_bias_thread_fp32 = nullptr; + if (attention_bias_thread != nullptr) { + if constexpr (!std::is_same_v) { + static_assert(std::is_same_v && std::is_same_v); + + size_t bytes = attention_total_seqlen * sizeof(float); + attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + } + } + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); + // compute Softmax U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { @@ -316,9 +328,6 @@ class GQAAttentionBase { static_cast(window_size)); } else { static_assert(std::is_same_v && std::is_same_v); - size_t bytes = window_size * sizeof(float); - auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 1a737f3a9d251..34410a5f42630 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -106,6 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -271,6 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h new file mode 100644 index 0000000000000..eae96c186d471 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include + +namespace onnxruntime { +namespace contrib { + +enum class ActivationType { + Relu = 0, + Gelu = 1, + Silu = 2, + Identity = 3, + SwiGLU = 4, +}; + +class MoEBaseCPU { + protected: + MoEBaseCPU(const OpKernelInfo& op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ActivationType::Identity; + } else if (activation_type_str == "swiglu") { + activation_type_ = ActivationType::SwiGLU; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + + normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; + + use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; + if (use_sparse_mixer_) { + ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); + } + + swiglu_fusion_ = op_kernel_info.GetAttrOrDefault("swiglu_fusion", 0); + swiglu_limit_ = op_kernel_info.GetAttrOrDefault("swiglu_limit", std::numeric_limits::infinity()); + activation_alpha_ = op_kernel_info.GetAttrOrDefault("activation_alpha", 1.0f); + activation_beta_ = op_kernel_info.GetAttrOrDefault("activation_beta", 0.0f); + } + + bool normalize_routing_weights_; + bool use_sparse_mixer_; + int64_t k_; + ActivationType activation_type_; + float activation_alpha_; + float activation_beta_; + float swiglu_limit_; + int64_t swiglu_fusion_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h new file mode 100644 index 0000000000000..e494719464d20 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "core/framework/tensor_shape.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace contrib { + +enum class MoEParallelType { + None = 0, + EP = 1, + TP = 2, + EPAndTP = 3, +}; + +struct MoEParameters { + MoEParameters() = default; + + explicit MoEParameters(int64_t tensor_shards) + : tensor_shards(tensor_shards) {} + + int64_t num_rows{0}; + int64_t num_experts{0}; + int64_t local_num_experts{0}; + int64_t hidden_size{0}; + int64_t inter_size{0}; + + MoEParallelType parallel_type{MoEParallelType::None}; + int64_t tensor_shards{1}; +}; +namespace moe_helper { + +template +Status CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu) { + // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. + ASSERT_TENSOR_2D_OR_3D(input); + ASSERT_TENSOR_3D(fc1_experts_weights); + ASSERT_TENSOR_3D(fc2_experts_weights); + ASSERT_TENSOR_2D(router_probs); + + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + int64_t hidden_size = input_dims[input_dims.size() - 1]; + int64_t local_num_experts = fc1_experts_weights_dims[0]; + int64_t num_experts = router_probs_dims[1]; + int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; + + const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + + // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. + const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; + + if (legacy_shape) { + // legacy shape does not match column major memory layout. This is for backward compatibility. + CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + } else { + CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + } + + CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); + + CHECK_TENSOR_SHAPE(fc1_experts_bias, num_experts, fc1_inter_size); + CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size); + CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size); + + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + + if (fc3_experts_weights == nullptr) { + ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr); + } else { + ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales + } + + parameters.num_rows = num_rows; + parameters.num_experts = num_experts; + parameters.local_num_experts = local_num_experts; + parameters.hidden_size = hidden_size; + parameters.inter_size = inter_size; + if (num_experts == local_num_experts) { + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::None; + } else { + parameters.parallel_type = MoEParallelType::TP; + } + } else if (num_experts > local_num_experts) { + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::EP; + } else { + parameters.parallel_type = MoEParallelType::EPAndTP; + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_experts must be greater than or equal to local_num_experts, got ", num_experts, + " and ", local_num_experts); + } + + return Status::OK(); +} + +} // namespace moe_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc new file mode 100644 index 0000000000000..5c6c3b919b572 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" + +#include "core/framework/allocator.h" +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" + +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. +// The source quantized weights are stored as a row-major representation of the transposed +// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t /*block_size*/, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data) { + const float zero_point = num_bits == 8 ? 128.0f : 8.0f; + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; + // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. + const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + } + } + } +} + +template +QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) + : OpKernel(op_kernel_info), + MoEBaseCPU(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); + ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, + "Attribute 'expert_weight_bits' must be 4 or 8."); + block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); +} + +template +Status QMoECPU::Compute(OpKernelContext* context) const { + // --- 1. Get Inputs and Attributes --- + const auto* input = context->Input(0); + const auto* router_probs = context->Input(1); + const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_scales = context->Input(3); + const auto* fc1_experts_bias = context->Input(4); + const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_scales = context->Input(6); + const auto* fc2_experts_bias = context->Input(7); + const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_scales = context->Input(9); + const auto* fc3_experts_bias = context->Input(10); + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, fc1_scales, + fc2_experts_weights, fc2_experts_bias, fc2_scales, + fc3_experts_weights, fc3_experts_bias, fc3_scales, + expert_weight_bits_ == 4 ? 2 : 1, + true)); + + if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); + } + + const auto& input_shape = input->Shape(); + const int64_t num_tokens = moe_params.num_rows; + const int64_t hidden_size = moe_params.hidden_size; + const int64_t inter_size = moe_params.inter_size; + const int64_t num_experts = moe_params.num_experts; + const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + + auto* output = context->Output(0, input_shape); + auto* tp = context->GetOperatorThreadPool(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + + const T* input_data = input->Data(); + const T* router_probs_data = router_probs->Data(); + + // --- 2. Routing Logic: Assign tokens to experts --- + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + const_cast(router_logits_float), + static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_probs_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + // Parallelize the routing logic to improve performance for large token batches. + // Minor performance regression for single-token decoding is an acceptable trade-off + int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. + std::vector> sorted_logits(static_cast(num_experts)); + std::vector top_k_exp(static_cast(k_)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { + sorted_logits[static_cast(j)] = {logits[j], j}; + } + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + float max_logit = -std::numeric_limits::infinity(); + for (int64_t j = 0; j < k_; ++j) { + if (sorted_logits[static_cast(j)].first > max_logit) { + max_logit = sorted_logits[static_cast(j)].first; + } + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_exp[static_cast(j)] = std::exp(sorted_logits[static_cast(j)].first - max_logit); + sum_exp += top_k_exp[static_cast(j)]; + } + + float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; + if (route_scale[route_idx] > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + }); + + // Merge the maps from each thread into a single global map. + std::vector> expert_token_map(static_cast(num_experts)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + } + } + } + + // --- 3. Parallel Expert Computation --- + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), + const_cast(input_float), + static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data); + } + + int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + + // Pre-calculate workspace size per thread to avoid allocations inside the loop + size_t max_tokens_per_expert = 0; + for (const auto& tokens : expert_token_map) { + if (tokens.size() > max_tokens_per_expert) { + max_tokens_per_expert = tokens.size(); + } + } + + const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); + const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); + const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); + const size_t B2_dequant_size = static_cast(hidden_size * inter_size); + const size_t bias1_size = static_cast(fc1_out_features); + const size_t bias2_size = static_cast(hidden_size); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); + float* workspace = workspace_ptr.get(); + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) { + continue; + } + + const int64_t num_expert_tokens = routes.size(); + + // Partition the workspace for the current expert + float* A1 = thread_workspace; + float* C1 = A1 + num_expert_tokens * hidden_size; + float* A2 = C1 + num_expert_tokens * fc1_out_features; + float* C2 = A2 + num_expert_tokens * inter_size; + float* B1_dequant = C2 + num_expert_tokens * hidden_size; + float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; + float* bias1_float = B2_dequant + hidden_size * inter_size; + float* bias2_float = bias1_float + fc1_out_features; + + // --- Gather input tokens for the current expert --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + memcpy(A1 + i * hidden_size, + input_float + token_idx * hidden_size, + static_cast(hidden_size) * sizeof(float)); + } + + // --- FC1 GEMM (X * W1^T) --- + DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), + fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), + block_size_, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), + 1.0f, A1, static_cast(hidden_size), + B1_dequant, static_cast(hidden_size), + 0.0f, C1, static_cast(fc1_out_features), + nullptr); + + const T* B1_bias = (fc1_experts_bias) ? fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; + if (B1_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + } else { + memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + for (int64_t i = 0; i < num_expert_tokens; ++i) { + for (int64_t j = 0; j < fc1_out_features; ++j) { + C1[i * fc1_out_features + j] += bias1_float[j]; + } + } + } + + // --- Activation --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + + // --- FC2 GEMM (A2 * W2^T) --- + DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), + fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), + block_size_, expert_weight_bits_, + hidden_size, inter_size, B2_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), + 1.0f, A2, static_cast(inter_size), + B2_dequant, static_cast(inter_size), + 0.0f, C2, static_cast(hidden_size), + nullptr); + + const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; + if (B2_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + } else { + memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + } + + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t route_idx = routes[static_cast(i)]; + const int64_t token_idx = route_idx / k_; + const float weight = route_scale[route_idx]; + + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { + // Skip this token to prevent buffer overflow + continue; + } + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; + const float* src = C2 + i * hidden_size; + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + } + } + } + }); + + // --- 4. Final Reduction (accumulate expert outputs to a float buffer) --- + auto accumulate = [&](float* buffer) { + memset(buffer, 0, output_buffer_size * sizeof(float)); + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + for (size_t j = 0; j < output_buffer_size; ++j) { + buffer[j] += thread_local_outputs[thread_offset + j]; + } + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + // --- 5. Convert final float buffer to output type T --- + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { // T is float + accumulate(output->MutableData()); + } + + return Status::OK(); +} + +// Explicit template instantiation +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; + +// Kernel Registration +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, float, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, MLFloat16, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h new file mode 100644 index 0000000000000..890580e051a8e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +/** + * @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator. + * + * This kernel supports both float and MLFloat16 data types for activations, scales, and outputs. + * It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory + * usage through on-the-fly block dequantization of weights. + * + * @tparam T The data type for the kernel (float or MLFloat16). + */ +template +class QMoECPU final : public OpKernel, public MoEBaseCPU { + public: + explicit QMoECPU(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + int64_t expert_weight_bits_; + int64_t block_size_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc new file mode 100644 index 0000000000000..2c59210bfabd4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_utils.h" +#include +#include +#include "core/common/common.h" + +namespace onnxruntime { +namespace contrib { + +float ApplyActivation(float x, ActivationType activation_type) { + switch (activation_type) { + case ActivationType::Relu: + return std::max(0.0f, x); + case ActivationType::Gelu: + return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x))); + case ActivationType::Silu: + return x * (1.0f / (1.0f + std::exp(-x))); + case ActivationType::Identity: + return x; + case ActivationType::SwiGLU: + // SwiGLU is a special case handled by ApplySwiGLUActivation, this is just a placeholder + return x; + default: + return x; + } +} + +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit) { + if (is_interleaved_format) { + for (int64_t i = 0; i < inter_size; ++i) { + float gate_val = input_data[2 * i]; + float linear_val = input_data[2 * i + 1]; + + gate_val = std::min(gate_val, clamp_limit); + linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); + + float sigmoid_arg = activation_alpha * gate_val; + float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); + float swish_out = gate_val * sigmoid_out; + + output_data[i] = swish_out * (linear_val + activation_beta); + } + } else { + ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation"); + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h new file mode 100644 index 0000000000000..de238e8d7ae66 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +float ApplyActivation(float x, ActivationType activation_type); + +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index e2bb3b508ca7c..36a6f70cc69d9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -10,6 +11,7 @@ #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include #include #include @@ -169,43 +171,53 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { // only pack Matrix B if (input_idx == GetBIdx()) { const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_is_asymmetric = false; + bool b_quantization_might_be_asymmetric = false; - // zero point tensor could be provided as a direct input to the kernel and not as a constant so this - // test is not sufficient const OrtValue* b_zp; if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { b_zp_constant_tensor = &b_zp->Get(); } - // MlasDynamicQgemm requires symmetric quantization for B, so no zero point should exist or it should - // have a zero value - if (b_zp_constant_tensor != nullptr) { // Covers the case where tensor is not a constant - const auto& shape = b_zp_constant_tensor->Shape(); - const auto* zp_data = static_cast(b_zp_constant_tensor->DataRaw()); - size_t zp_size = static_cast(shape.Size()); - // MlasDynamicQgemm requires symmetric quantization: zp must be scalar 0 or 1D all-zero - if ((shape.NumDimensions() == 0) && (zp_data[0] == 0)) { - b_quantization_is_asymmetric = false; - } else if (shape.NumDimensions() == 1) { - b_quantization_is_asymmetric = false; - for (size_t i = 0; i < zp_size; ++i) { - if (zp_data[i] != 0) { - b_quantization_is_asymmetric = true; - break; - } - } - } else { - // Unsupported higher-rank zp tensor - b_quantization_is_asymmetric = true; - } + // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros + // or not provided. + if (b_zp_constant_tensor != nullptr) { + // B zero point is constant. Check if it is all zeros. + assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); + const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } else { + // B zero point input is not constant. If it exists, we can't assume symmetric quantization. + const auto input_defs = Info().node().InputDefs(); + const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); + b_quantization_might_be_asymmetric = b_zp_input_exists; } // MlasDynamicQgemm requires scale data to be available at packing stage const Tensor* b_scale_tensor = nullptr; const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - can_use_dynamic_quant_mlas_ = (!b_quantization_is_asymmetric && b_scale_available); + can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + + // Kleidi dynamic path requires strictly positive, finite scales. + // Disable if any invalid scale is detected. + if (can_use_dynamic_quant_mlas_) { + const auto bs = b_scale_tensor->DataAsSpan(); + const bool has_invalid = + std::any_of(bs.begin(), bs.end(), + [](float s) { return !std::isfinite(s) || s <= 0.0f; }); + + if (has_invalid) { + can_use_dynamic_quant_mlas_ = false; + } + } + + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + // We check that here too before attempting to use them. + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + can_use_dynamic_quant_mlas_ = false; + } // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. @@ -380,7 +392,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { if (y->Shape().Size() == 0) return Status::OK(); - auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); + const float* a_data = ctx->Input(IN_A)->Data(); auto* y_data = y->MutableData(); // batch gemm @@ -394,7 +406,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { auto& params = gemm_data_vec[gemm_idx]; - params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); + params.A = a_data + helper.LeftOffsets()[gemm_idx]; params.lda = gemm_shape.K; params.PackedB = packed_b_.get(); params.C = y_data + helper.OutputOffsets()[gemm_idx]; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 1a4a38282fcc1..51252dc2b0467 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -197,19 +197,33 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All has_zp_input_, nullptr, nullptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { -#ifdef MLAS_TARGET_AMD64_IX86 - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); - is_packed = false; - } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { - auto zptr = tensor.Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, - has_zp_input_, zptr, nullptr); - is_packed = false; + // Packing scales and zero points + bool should_pack_scale_and_zp_inputs = [&]() { +#if defined(MLAS_TARGET_AMD64_IX86) + return true; +#else + return (nbits_ == 8); +#endif + }(); + + if (should_pack_scale_and_zp_inputs) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); + is_packed = false; + } + + // Packing zero_point + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, + has_zp_input_, zptr, nullptr); + is_packed = false; + } } -#elif defined(MLAS_TARGET_ARM64) + +#if defined(MLAS_TARGET_ARM64) if (input_idx == InputIndex::scales && packed_b_ != nullptr && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_)) { scales_are_packed_ = true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 1a4a63de38790..93d802ca05b42 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -71,15 +71,21 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params(tensor_shards_); - MoEQuantType quant_type = MoEQuantType::None; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, nullptr, + fc2_experts_weights, fc2_experts_bias_optional, nullptr, + fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, + 1, // no quantization so pack size is 1 + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index f3346d4513261..36d6fc378d45e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -15,6 +15,10 @@ using namespace onnxruntime::common; ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name) #define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \ ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name) +#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \ + ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name) +#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name) #define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \ ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name) @@ -92,7 +96,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE); -class CUDA_MS_OP_CLASS_NAME(1, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention); @@ -184,6 +190,25 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized); + #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif @@ -307,7 +332,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -404,6 +431,24 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu index 47e662b9a88ba..07c1fda362a07 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" namespace onnxruntime::llm { @@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu index 026852623513b..5473eeaf64678 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scaleonly.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" namespace onnxruntime::llm { @@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu index 9452aa0e1fbe6..f27879604cf2e 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" namespace onnxruntime::llm { @@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu index 0849f6d9da042..fe84b8e03251e 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scaleonly.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" namespace onnxruntime::llm { @@ -23,3 +23,4 @@ template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, cutlass::WeightO } // namespace cutlass_kernels } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu index 4a22e0f1b2aac..785d76ec4339e 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" namespace onnxruntime::llm { @@ -24,3 +24,4 @@ template class CutlassFpAIntBGemmRunner #include "core/providers/cuda/cuda_common.h" @@ -283,3 +283,4 @@ void transpose_uint8_matrix_and_convert_to_int8( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu index 4990f676cb5c4..2ecb7c11a6710 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/common/safeint.h" @@ -581,3 +581,4 @@ void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, } // namespace weight_only } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu index e2c008884c998..2423350a532b9 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -30,3 +30,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu index 8cd96c44421e5..5272fd3c908bd 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -29,3 +29,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu index 1eb5f51bdffdc..a2bf24b444388 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -26,3 +26,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu index f5872841e1acb..cb91d12bfa9b5 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -26,3 +26,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu index f6b76e67b20ba..2dd84acacd5e2 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -30,3 +30,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu index 2ca88285d4cfe..b717d1e928fc4 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -29,3 +29,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu index 7a00e1ba35f80..b59e7ab0ba677 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -26,3 +26,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu index 4a8506ca6bbde..277e1f7ad85cb 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" namespace onnxruntime::llm { @@ -26,3 +26,4 @@ INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu index 54ed44c0d68d5..78aa458a56d47 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include #include #include @@ -102,3 +102,4 @@ bool is_supported(int arch, KernelType kernel_type) { } // namespace fpA_intB_gemv } // namespace kernels } // namespace onnxruntime::llm +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc index 73902a0636fcb..a89f1aa5c1150 100644 --- a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/gemm_profiler.h" #include "contrib_ops/cuda/llm/common/logger.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" @@ -311,3 +311,4 @@ template class GemmPluginProfiler; } // namespace onnxruntime::llm::kernels::weight_only +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py index 50d2fe07d4a38..eeef49bfaa351 100644 --- a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py +++ b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py @@ -224,6 +224,7 @@ def get_file_content(launcher_inl_files, operations): instantiations = "\n".join(insts_list) file_content = f""" +#if USE_FPA_INTB_GEMM #ifndef EXCLUDE_SM_90 {includes} @@ -237,6 +238,7 @@ def get_file_content(launcher_inl_files, operations): }} // namespace kernels }} // namespace onnxruntime::llm #endif // EXCLUDE_SM_90 +#endif // USE_FPA_INTB_GEMM """ return file_content diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 36127054cfd5e..d5ad8161e100e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -52,6 +52,7 @@ enum class ActivationType { Gelu, GeGLU, ReGLU, SiGLU, + SwiGLU, Identity, InvalidType }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 0000000000000..5f0a71147b366 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 0000000000000..4a84581127156 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 0000000000000..6c23127955ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ef1f97b9e57a2..f855092670bc3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -53,6 +53,8 @@ #include "cutlass_heuristic.h" #include "moe_gemm_kernels.h" +#include + #include #include #include @@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float, bfloat16"); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || @@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, T>::type>::type; using ElementType = ElementType_; using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, WeightType>::type>::type; using CutlassWeightType = CutlassWeightType_; @@ -391,12 +392,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { + } else if (sm_ >= 80) { // Hopper and Blackwell will fallback to use Ampere kernels. dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else { - ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -478,6 +477,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) { + // Swiglu will use Identity to call this function so we not need to handle it here. switch (activation_type) { case ActivationType::Relu: run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index bfbe1d81b1c15..ce8c0270f5c32 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -38,12 +38,96 @@ #include "moe_kernel.h" +#include #include #include #include +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" + namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; + +// SwiGLU with interleaved is like the following python code using PyTorch: +// dim = x.shape[-1] +// x = x.view(-1, dim // 2, 2) +// x_glu, x_linear = x[..., 0], x[..., 1] +// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) +template +__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + float glu = static_cast(row_input[2 * i]); + float linear = static_cast(row_input[2 * i + 1]); + + if constexpr (HasLimit) { + glu = fminf(glu, limit); + linear = fminf(fmaxf(linear, -limit), limit); + } + + float sigmoid_arg = alpha * glu; + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = glu * sigmoid_out; + row_output[i] = static_cast(swish_out * (linear + 1.f)); + } +} + +// Non interleaved version of SwiGLU kernel, which splits each row into two chunks of same size. +template +__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + float glu = static_cast(row_input[i]); + float linear = static_cast(row_input[i + intermediate_size]); + + if constexpr (HasLimit) { + glu = fminf(glu, limit); + linear = fminf(fmaxf(linear, -limit), limit); + } + + float sigmoid_arg = alpha * glu; + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = glu * sigmoid_out; + row_output[i] = static_cast(swish_out * (linear + 1.f)); + } +} + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit, cudaStream_t stream) { + if (num_rows == 0) { + return; + } + dim3 block(std::min(intermediate_size, 1024)); + dim3 grid(num_rows); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size); + + if constexpr (IsInterLeaved) { + swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, alpha, limit); + } else { + swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, alpha, limit); + } + + DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size); +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -456,7 +540,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + float old_val = static_cast(output[idx - ki]); + output[idx - ki] = T(old_val / output_row_sum); } } } @@ -666,9 +751,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + : activation_type_(activation_type), + has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights), + use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -695,8 +785,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + + size_t bytes_for_fc1_result; + if (activation_type_ == ActivationType::SwiGLU) { + // Space for both fc1_result_ and act_result_. + bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T); + } else { + bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + } + + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)); sorter_.update_num_experts(static_cast(num_experts)); size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; @@ -705,7 +803,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro bytes_for_intermediate_and_sorting += remaining_bytes; } - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + total_ws_bytes += bytes_for_intermediate_and_sorting; return total_ws_bytes; } @@ -725,27 +823,49 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + char* current_ptr = reinterpret_cast(total_rows_before_expert_ + padded_experts); + + if (activation_type_ == ActivationType::SwiGLU) { + // fc1_result_ is used for GEMM1 output (2 * inter_size) + fc1_result_ = reinterpret_cast(current_ptr); + current_ptr += 2 * interbuf_size * sizeof(T); + + // act_result_ is used for SwiGLU output (inter_size) + act_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); + + ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3"); + } else { + fc1_result_ = reinterpret_cast(current_ptr); + act_result_ = nullptr; // No extra buffer for activation since it is done inplace. + current_ptr += interbuf_size * sizeof(T); + } + if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + fc3_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc3_result_ = nullptr; } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(current_ptr); } else { softmax_out_ = nullptr; } } namespace { - -struct __align__(8) Half4 { +typedef struct __CUDA_ALIGN__(8) { half2 x; half2 y; -}; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) { + __nv_bfloat162 x; + __nv_bfloat162 y; +} __nv_bfloat162_2; // TODO(wy): move to common header template @@ -756,7 +876,11 @@ struct T4 { }; template <> struct T4 { - using Type = Half4; + using Type = half2_2; +}; +template <> +struct T4<__nv_bfloat16> { + using Type = __nv_bfloat162_2; }; template @@ -769,6 +893,10 @@ template <> struct T2 { using Type = half2; }; +template <> +struct T2<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } @@ -785,15 +913,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha #endif // TODO(wy): use cuda common header and investigate pipeline build issue. -inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; + half2_2 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + __nv_bfloat162_2 result; result.x = a.x * b.x; result.y = a.y * b.y; return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } @@ -880,8 +1020,54 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); + if (fc1_activation_type == ActivationType::SwiGLU) { + T* gemm1_output_buffer = fc1_result_; + T* swiglu_output_buffer = act_result_; + + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, + fc1_scales, + fc1_expert_biases, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + 2 * inter_size, + hidden_size, + local_num_experts, + ActivationType::Identity, + stream); + + constexpr bool swiglu_interleaved = true; + constexpr bool swiglu_has_limit = true; + constexpr float swiglu_alpha = 1.702f; + constexpr float swiglu_limit = 7.0f; + invokeSwiGLU( + swiglu_output_buffer + total_past_rows_ * inter_size, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + inter_size, + static_cast(total_covered_rows_), + swiglu_alpha, + swiglu_limit, + stream); + + moe_gemm_runner_.moe_gemm( + swiglu_output_buffer + total_past_rows_ * inter_size, + fc2_expert_weights, + fc2_scales, + nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + hidden_size, + inter_size, + local_num_experts, + stream); + + // No fc3 for SwiGLU + return; + } + moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -1151,18 +1337,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, bool, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int, + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; +// For qMoE: template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; // ===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, cudaStream_t); template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -1177,5 +1371,10 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl const float*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t); + +template void invokeSwiGLU(float*, float const*, int, int, float, float, cudaStream_t); +template void invokeSwiGLU(half*, half const*, int, int, float, float, cudaStream_t); } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index c457b608decbf..de11d357a8c07 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) { template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, int* indices, int* source_row, int num_rows, int num_experts, int k, - cudaStream_t stream); + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream); + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream); class CubKeyValueSorter { public: @@ -109,7 +112,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -157,8 +160,10 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* act_result_; T* fc3_result_; + ActivationType activation_type_; bool has_fc3_; bool normalize_routing_weights_; bool use_sparse_mixer_; @@ -173,14 +178,4 @@ class CutlassMoeFCRunner { std::vector total_rows_before_expert_host_; }; -template -class CutlassMoeFCRunner::value>> { - public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); - - size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { - return 0; - } -}; - } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index c5352d931ce2c..a5b9d483d5ad1 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -3,6 +3,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "moe.h" using namespace onnxruntime::cuda; @@ -20,6 +21,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { @@ -37,19 +39,25 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params; - MoEQuantType quant_type = MoEQuantType::None; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); - - typedef typename ToCudaType::MappedType CudaT; + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, nullptr, + fc2_experts_weights, fc2_experts_bias_optional, nullptr, + fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, + 1, // no quantization so pack size is 1 + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + + using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 6b65557444a66..5f0c30b16a8f4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -7,206 +7,13 @@ #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" #include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" +#include "contrib_ops/cpu/moe/moe_helper.h" namespace onnxruntime { namespace contrib { namespace cuda { -enum class MoEParallelType { - None = 0, - EP = 1, - TP = 2, - EPAndTP = 3, -}; - -enum class MoEQuantType { - None = 0, - UINT4 = 1, - UINT8 = 2, -}; - -struct MoEParameters { - MoEParameters() {} - explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} - int64_t num_rows; - int64_t num_experts; - int64_t local_num_experts; - int64_t hidden_size; - int64_t inter_size; - - MoEParallelType parallel_type; - int64_t tensor_shards{1}; -}; - class MoEBase { - public: - Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input, - const Tensor* router_probs, const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional) const { - const auto& input_dims = input->Shape().GetDims(); - const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); - - int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; - int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; - int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = fc2_experts_weights_dims[1]; - - if (fc1_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", - fc1_experts_weights_dims.size()); - } - if (fc2_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", - fc2_experts_weights_dims.size()); - } - if (fc1_experts_weights_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", - fc1_experts_weights_dims[1], " and ", hidden_size); - } - if (fc2_experts_weights_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[1] must be equal to inter_size, got ", - fc2_experts_weights_dims[1], " and ", inter_size); - } - - const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; - if (fc1_experts_weights_dims[2] != inter_size / coe) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], " and ", inter_size); - } - if (fc2_experts_weights_dims[2] != hidden_size / coe) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); - } - - if (router_probs_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", - router_probs_dims.size()); - } - if (router_probs_dims[0] != num_rows) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", - router_probs_dims[0], " and ", num_rows); - } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { - const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); - const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); - if (fc1_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", - fc1_experts_bias_dims.size()); - } - if (fc2_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", - fc2_experts_bias_dims.size()); - } - if (fc1_experts_bias_dims[0] != local_num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", - fc1_experts_bias_dims[0], " and ", local_num_experts); - } - if (fc2_experts_bias_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], - " and ", num_experts); - } - if (fc1_experts_bias_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); - } - if (fc2_experts_bias_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], - " and ", hidden_size); - } - } - - if (fc3_experts_weights_optional != nullptr && - fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ", - fc3_experts_weights_optional->Shape(), " and ", TensorShape(fc1_experts_weights_dims)); - } - - if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr && - fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", - fc3_experts_bias_optional->Shape(), " and ", fc1_experts_bias_optional->Shape()); - } - - parameters.num_rows = num_rows; - parameters.num_experts = num_experts; - parameters.local_num_experts = local_num_experts; - parameters.hidden_size = hidden_size; - parameters.inter_size = inter_size; - if (num_experts == local_num_experts) { - if (parameters.tensor_shards == 1) { - parameters.parallel_type = MoEParallelType::None; - } else { - parameters.parallel_type = MoEParallelType::TP; - } - } else if (num_experts > local_num_experts) { - if (parameters.tensor_shards == 1) { - parameters.parallel_type = MoEParallelType::EP; - } else { - parameters.parallel_type = MoEParallelType::EPAndTP; - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_experts must be greater than or equal to local_num_experts, got ", num_experts, - " and ", local_num_experts); - } - - return Status::OK(); - } - - Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, - const Tensor* fc3_experts_scales, int64_t num_experts, int64_t hidden_size, - int64_t inter_size) const { - const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims(); - const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims(); - - if (fc1_experts_scales_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ", - fc1_experts_scales->Shape().GetDims().size()); - } - if (fc1_experts_scales_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", - fc1_experts_scales_dims[0], " and ", num_experts); - } - if (fc1_experts_scales_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", - fc1_experts_scales_dims[1], " and ", inter_size); - } - if (fc2_experts_scales_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", - fc2_experts_scales->Shape().GetDims().size()); - } - if (fc2_experts_scales_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ", - fc2_experts_scales_dims[0], " and ", num_experts); - } - if (fc2_experts_scales_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ", - fc2_experts_scales_dims[1], " and ", hidden_size); - } - if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc3_experts_scales must be equal to fc1_experts_scales, got ", - fc3_experts_scales->Shape(), " and ", TensorShape(fc1_experts_scales_dims)); - } - - return Status::OK(); - } - protected: MoEBase(const OpKernelInfo& op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); @@ -219,6 +26,8 @@ class MoEBase { activation_type_ = ort_fastertransformer::ActivationType::Gelu; } else if (activation_type_str == "silu") { activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "swiglu") { + activation_type_ = ort_fastertransformer::ActivationType::SwiGLU; } else if (activation_type_str == "identity") { activation_type_ = ort_fastertransformer::ActivationType::Identity; } else { diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc new file mode 100644 index 0000000000000..bad44b260b7b2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +#define REGISTER_GATHERBLOCKQUANTIZED(T1, T2, Tind) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_EX( \ + GatherBlockQuantized, \ + kMSDomain, 1, \ + T1, T2, Tind, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + GatherBlockQuantized); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int64_t); + +template +GatherBlockQuantized::GatherBlockQuantized(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(info.GetAttr("bits", &bits_).IsOK()); + + block_size_ = info.GetAttrOrDefault("block_size", 0); + gather_axis_ = info.GetAttrOrDefault("gather_axis", 0); + quantize_axis_ = info.GetAttrOrDefault("quantize_axis", 0); + + // If block size is set, it has to be no smaller than 16 and must be power of 2 + // block_size_ & (block_size_ - 1) == 0 checks if block_size_ only has 1 bit set + ORT_ENFORCE(block_size_ == 0 || (block_size_ >= 16 && ((block_size_ & (block_size_ - 1)) == 0))); +} + +template +Status GatherBlockQuantized::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* data = ctx->Input(0); + const Tensor* indices = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + auto data_shape = data->Shape().GetDims(); + auto data_rank = data->Shape().NumDimensions(); + + auto indices_shape = indices->Shape().GetDims(); + auto indices_rank = indices->Shape().NumDimensions(); + + ORT_ENFORCE(quantize_axis_ == data_rank - 1); + + TensorShapeVector output_shape; + output_shape.reserve(data_rank - 1 + indices_rank); + + // Dimension after gather axis + int64_t after_gather_dim = 1; + + // Dimension of indices + int64_t ind_dim = 1; + + // 1) dims before gather_axis + for (int64_t i = 0; i < gather_axis_; ++i) { + output_shape.push_back(data_shape[i]); + } + + // 2) all of indices.shape + for (auto dim : indices_shape) { + output_shape.push_back(dim); + ind_dim *= dim; + } + + // 3) dims after gather_axis + for (int64_t i = gather_axis_ + 1; i < data_rank; ++i) { + output_shape.push_back(data_shape[i]); + after_gather_dim *= data_shape[i]; + } + + // Special int4‐in‐uint8 packing tweak: expand the last dim by components + if constexpr (std::is_same_v) { + uint32_t components = 8 / static_cast(bits_); + if (components > 1) { + output_shape.back() *= components; + } + } + + Tensor* output = ctx->Output(0, TensorShape(output_shape)); + + int64_t N = 1; + for (auto dim : output_shape) { + N *= dim; + } + + const auto* data_ptr = data->Data(); + const auto* indices_ptr = indices->Data(); + const T1* zero_points_ptr = nullptr; + if (zero_points != nullptr) { + zero_points_ptr = zero_points->Data(); + } + + GatherBlockQuantizedParam param; + param.stream = Stream(ctx); + param.after_gather_dim = after_gather_dim; + param.gather_axis_dim = data_shape[gather_axis_]; + param.ind_dim = ind_dim; + param.bits = bits_; + param.block_size = block_size_; + param.gather_axis = gather_axis_; + param.N = N; + + const auto dequantized_type = scales->GetElementType(); + if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu new file mode 100644 index 0000000000000..39286c63e9a08 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "gather_block_quantized.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline int64_t get_val(const T1* data, int64_t idx, int64_t bits, bool sign) { + const uint32_t mask = (1U << bits) - 1; + const int64_t elems_per_byte = 8 / bits; + const int64_t byte_idx = idx / elems_per_byte; + const int64_t bit_offset = (idx % elems_per_byte) * bits; + const uint8_t byte = reinterpret_cast(data)[byte_idx]; + int64_t val = (byte >> bit_offset) & mask; + + // Sign-extend based on bit width + if (sign) { + if (val & (1 << (bits - 1))) { + val |= -1LL << bits; + } + } + + return val; +} + +template +__global__ void GatherBlockQuantizedKernel( + const T1* data, // packed 4-bit codes, one code per element + const Tind* indices, + const T2* scales, // one float scale per block + const T1* zero_points, // packed 4-bit zero-points, one per block + T2* output, + int64_t after_gather_dim, + int64_t gather_axis_dim, + int64_t ind_dim, + int64_t bits, + int64_t block_size, + int64_t gather_axis, + int64_t N, + bool sign) { + int64_t out_idx = blockDim.x * blockIdx.x + threadIdx.x; + if (out_idx >= N) return; + + // compute which input element this thread corresponds to: + int64_t idx_before = out_idx / (after_gather_dim * ind_dim); + int64_t idx_after = out_idx % after_gather_dim; + int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim; + int64_t idx_at_g = indices[idx]; + int64_t in_idx = idx_before * gather_axis_dim * after_gather_dim + idx_at_g * after_gather_dim + idx_after; + + int64_t block_id = in_idx / block_size; + + // unpack zero_point for this block: + int64_t offset = 0; + if (zero_points) { + offset = get_val(zero_points, block_id, bits, sign); + } + + // unpack the raw quantized code for this element: + int64_t weight = get_val(data, in_idx, bits, sign); + + // apply dequantization: + output[out_idx] = static_cast(weight - offset) * scales[block_id]; +} + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param) { + // Require quant_axis is last dim + int blocksPerGrid = (int)(ceil(static_cast(param.N) / GridDim::maxThreadsPerBlock)); + bool sign = std::is_same::value; + + GatherBlockQuantizedKernel<<>>(data, indices, scales, zero_points, output, + param.after_gather_dim, param.gather_axis_dim, param.ind_dim, param.bits, param.block_size, param.gather_axis, param.N, sign); +} + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh new file mode 100644 index 0000000000000..f5dea3b1f2d9d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +struct GatherBlockQuantizedParam { + cudaStream_t stream; + int64_t after_gather_dim; + int64_t gather_axis_dim; + int64_t ind_dim; + int64_t bits; + int64_t block_size; + int64_t gather_axis; + int64_t N; +}; + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h new file mode 100644 index 0000000000000..7718b6dd06765 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class GatherBlockQuantized final : public CudaKernel { + public: + GatherBlockQuantized(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t bits_; + int64_t block_size_; + int64_t gather_axis_; + int64_t quantize_axis_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 2e862ff816bef..39821ce0afae8 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -13,9 +13,11 @@ #include "contrib_ops/cpu/utils/dump_tensor.h" #include "contrib_ops/cuda/quantization/matmul_nbits.cuh" #include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" +#if USE_FPA_INTB_GEMM #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h" +#endif #include "contrib_ops/cuda/llm/common/logger.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" @@ -27,6 +29,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; + +#if USE_FPA_INTB_GEMM using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; using onnxruntime::llm::kernels::weight_only::WeightTypeId; @@ -246,6 +250,7 @@ Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, } return Status::OK(); } +#endif template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { @@ -257,11 +262,25 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { } const Tensor* a = ctx->Input(0); + const Tensor* reorder_idx = ctx->Input(4); + const Tensor* bias = ctx->Input(5); + +#if USE_FPA_INTB_GEMM const Tensor* b = is_prepacked_weight_ ? nullptr : ctx->Input(1); const Tensor* scales = is_prepacked_scale_ ? nullptr : ctx->Input(2); const Tensor* zero_points = is_prepacked_zero_point_ ? nullptr : ctx->Input(3); - const Tensor* reorder_idx = ctx->Input(4); - const Tensor* bias = ctx->Input(5); + const uint8_t* blob_data = is_prepacked_weight_ ? nullptr : b->Data(); + const auto* scales_data = is_prepacked_scale_ ? nullptr : scales->Data(); + const auto* zero_points_data = (is_prepacked_zero_point_ || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); +#else + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); +#endif if (bias != nullptr) { ORT_THROW("MatMulNBits does not support bias in CUDA kernel"); @@ -271,11 +290,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); const auto* a_data = a->Data(); - const uint8_t* blob_data = is_prepacked_weight_ ? nullptr : b->Data(); - const auto* scales_data = is_prepacked_scale_ ? nullptr : scales->Data(); - const auto* zero_points_data = (is_prepacked_zero_point_ || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); constexpr bool transa = false; constexpr bool transb = true; @@ -292,7 +307,6 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; - CudaT* out_data = reinterpret_cast(Y->MutableData()); int m = SafeInt(helper.M()); int n = SafeInt(helper.N()); @@ -300,6 +314,9 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { DUMP_TENSOR_INIT(); +#if USE_FPA_INTB_GEMM + CudaT* out_data = reinterpret_cast(Y->MutableData()); + if constexpr (std::is_same::value || std::is_same::value) { if (has_fpA_intB_gemm_) { // We expect weight/scale/zero_point(optional) inputs are initializers and have been prepacked. @@ -356,6 +373,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } } +#endif if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { if (TryMatMulNBits( diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 0d3558f91f03e..a7f0a9516584c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -17,6 +17,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; + +#if USE_FPA_INTB_GEMM using onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner; using onnxruntime::llm::kernels::weight_only::GemmDims; using onnxruntime::llm::kernels::weight_only::GemmIdCore; @@ -31,6 +33,7 @@ constexpr int kFpAIntBGemmOption_All = 0x01; constexpr int kFpAIntBGemmOption_Gemv = 0x02; constexpr int kFpAIntBGemmOption_Int4 = 0x04; constexpr int kFpAIntBGemmOption_Int8 = 0x08; +#endif template class MatMulNBits final : public CudaKernel { @@ -57,6 +60,7 @@ class MatMulNBits final : public CudaKernel { is_zero_points_scale_same_type_ = (zero_point_type == scale_type); } +#if USE_FPA_INTB_GEMM if constexpr (std::is_same::value || std::is_same::value) { int option = ParseEnvironmentVariableWithDefault(kFpAIntBGemmOption, 0); if ((option & (static_cast(nbits_) | kFpAIntBGemmOption_All)) != 0 && @@ -94,21 +98,25 @@ class MatMulNBits final : public CudaKernel { has_zero_points_ ? (is_zero_points_scale_same_type_ ? int(sizeof(T)) * 8 : int(nbits_)) : int(0), int(has_g_idx_ ? 1 : 0), int(has_bias_ ? 1 : 0), int(has_fpA_intB_gemv_), int(has_fpA_intB_gemm_)); +#endif #endif } Status ComputeInternal(OpKernelContext* context) const override; - +#if USE_FPA_INTB_GEMM Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) override; +#endif private: +#if USE_FPA_INTB_GEMM void InitGemmProfiler(int sm); void RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m); Status PrePack_B(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); Status PrePack_Scale(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); Status PrePack_ZeroPoint(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); +#endif int64_t K_; int64_t N_; @@ -121,6 +129,8 @@ class MatMulNBits final : public CudaKernel { bool has_bias_{false}; bool has_zero_points_{false}; bool is_zero_points_scale_same_type_{false}; + +#if USE_FPA_INTB_GEMM bool has_fpA_intB_gemv_{false}; bool has_fpA_intB_gemm_{false}; @@ -135,6 +145,7 @@ class MatMulNBits final : public CudaKernel { IAllocatorUniquePtr fpA_intB_weight_buffer_; IAllocatorUniquePtr fpA_intB_scale_buffer_; IAllocatorUniquePtr fpA_intB_zero_buffer_; +#endif }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 4dd5a079d1a29..dcf32bb3c5ae4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -5,6 +5,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/quantization/moe_quantization.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -14,16 +15,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL() - namespace { template struct ToCudaTypeWrapper : public ToCudaType {}; @@ -40,27 +31,29 @@ struct ToCudaTypeWrapper { } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +template +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); } +template template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional, - const cudaDeviceProp& device_prop) const { +Status QMoE::QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const { auto stream = context->GetComputeStream(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,10 +61,10 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using T = MLFloat16; - using CudaT = typename ToCudaType::MappedType; + using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, fc3_experts_weights_optional != nullptr, normalize_routing_weights_, use_sparse_mixer_); @@ -136,7 +129,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, return Status::OK(); } -Status QMoE::ComputeInternal(OpKernelContext* context) const { +template +Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); @@ -149,20 +143,21 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_scales_optional = context->Input(9); const Tensor* fc3_experts_bias_optional = context->Input(10); - MoEQuantType quant_type = expert_weight_bits_ == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8; MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); - ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, - moe_params.hidden_size, moe_params.inter_size)); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, + fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, + fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, + expert_weight_bits_ == 4 ? 2 : 1, + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); #if defined(__GNUC__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. #endif - if (quant_type == MoEQuantType::UINT4) { + if (expert_weight_bits_ == 4) { using CudaWeightT = typename ToCudaTypeWrapper::MappedType; return QuantizedMoEImpl(context, moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, @@ -183,6 +178,32 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { #endif } +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index c0164576d7c7f..c4698a1f277ef 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -14,6 +14,7 @@ namespace cuda { using namespace onnxruntime::cuda; +template class QMoE final : public CudaKernel, public MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 40d46cc3fba59..8b7b257dd2852 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -198,9 +198,12 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* present_value = context.Output(2, present_kv_shape); parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw(); + ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + // Use a sliding window if the total sequence exceeds the window's length. + bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); if (!do_rotary_ && head_sink == nullptr && !use_smooth_softmax_ && - local_window_size_ == -1 && + !use_sliding_window && CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index dccfdbda8971b..6c66047b4b36a 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/common/cpuid_info.h" + +#include +#include + #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" #include "core/platform/check_intel.h" @@ -51,6 +55,14 @@ #endif // _WIN32 +#if defined(__APPLE__) +#if defined(CPUIDINFO_ARCH_ARM) + +#include + +#endif // defined(CPUIDINFO_ARCH_ARM) +#endif // defined(__APPLE__) + #if defined(CPUINFO_SUPPORTED) #include #if defined(CPUIDINFO_ARCH_ARM) @@ -74,6 +86,14 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]); namespace onnxruntime { +void CPUIDInfo::LogEarlyWarning(std::string_view message) { + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(WARNING) << message; + } else { + std::cerr << "onnxruntime cpuid_info warning: " << message << "\n"; + } +} + #if defined(CPUIDINFO_ARCH_X86) static inline void GetCPUID(int function_id, int data[4]) { // NOLINT @@ -108,9 +128,6 @@ void CPUIDInfo::X86Init() { int data[4] = {-1}; GetCPUID(0, data); - vendor_ = GetX86Vendor(data); - vendor_id_ = GetVendorId(vendor_); - int num_IDs = data[0]; if (num_IDs >= 1) { GetCPUID(1, data); @@ -158,24 +175,8 @@ void CPUIDInfo::X86Init() { } } -std::string CPUIDInfo::GetX86Vendor(int32_t* data) { - char vendor[sizeof(int32_t) * 3 + 1]{}; - *reinterpret_cast(vendor + 0) = data[1]; - *reinterpret_cast(vendor + 4) = data[3]; - *reinterpret_cast(vendor + 8) = data[2]; - return vendor; -} - #endif // defined(CPUIDINFO_ARCH_X86) -uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { - if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "AuthenticAMD") return 0x1022; - if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); - if (vendor.find("NV") == 0) return 0x10DE; - return 0; -} - #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -228,10 +229,6 @@ void CPUIDInfo::ArmLinuxInit() { #elif defined(_WIN32) // ^ defined(__linux__) void CPUIDInfo::ArmWindowsInit() { - // Get the ARM vendor string from the registry - vendor_ = GetArmWindowsVendor(); - vendor_id_ = GetVendorId(vendor_); - // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU std::vector midr_values{}, id_aa64isar1_el1_values{}; @@ -323,15 +320,6 @@ void CPUIDInfo::ArmWindowsInit() { #endif // defined(CPUINFO_SUPPORTED) } -std::string CPUIDInfo::GetArmWindowsVendor() { - const int MAX_VALUE_NAME = 256; - const CHAR vendorKey[] = "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"; - CHAR vendorVal[MAX_VALUE_NAME] = ""; - unsigned long vendorSize = sizeof(char) * MAX_VALUE_NAME; - ::RegGetValueA(HKEY_LOCAL_MACHINE, vendorKey, "Vendor Identifier", RRF_RT_REG_SZ | RRF_ZEROONFAILURE, nullptr, &vendorVal, &vendorSize); - return vendorVal; -} - #elif defined(__APPLE__) // ^ defined(_WIN32) void CPUIDInfo::ArmAppleInit() { @@ -376,16 +364,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { } CPUIDInfo::CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) #if defined(CPUINFO_SUPPORTED) pytorch_cpuinfo_init_ = cpuinfo_initialize(); if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation " - "due to undetected CPU features."; + LogEarlyWarning( + "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation due to undetected CPU " + "features."); } #endif // defined(CPUINFO_SUPPORTED) + + // Note: This should be run after cpuinfo initialization if cpuinfo is enabled. + VendorInfoInit(); + +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) ArmLinuxInit(); #elif defined(_WIN32) diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 84571fa12e6ea..d49eca7e1d60c 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -103,7 +103,40 @@ class CPUIDInfo { } private: + // Log function that uses ORT logging if available or writes to stderr. + // This enables us to log even before ORT logging has been initialized. + static void LogEarlyWarning(std::string_view message); + CPUIDInfo(); + + void VendorInfoInit(); + +#if defined(CPUIDINFO_ARCH_X86) + + void X86Init(); + +#elif defined(CPUIDINFO_ARCH_ARM) + +#if defined(__linux__) + + void ArmLinuxInit(); + +#elif defined(_WIN32) + + void ArmWindowsInit(); + +#elif defined(__APPLE__) + + void ArmAppleInit(); + +#endif + +#endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUINFO_SUPPORTED) + bool pytorch_cpuinfo_init_{false}; +#endif // defined(CPUINFO_SUPPORTED) + bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -132,37 +165,6 @@ class CPUIDInfo { std::string vendor_; uint32_t vendor_id_; - - uint32_t GetVendorId(const std::string& vendor); - -#if defined(CPUIDINFO_ARCH_X86) - - void X86Init(); - std::string GetX86Vendor(int32_t* data); - -#elif defined(CPUIDINFO_ARCH_ARM) - -#if defined(CPUINFO_SUPPORTED) - // Now the following var is only used in ARM build, but later on we may expand the usage. - bool pytorch_cpuinfo_init_{false}; -#endif // defined(CPUINFO_SUPPORTED) - -#if defined(__linux__) - - void ArmLinuxInit(); - -#elif defined(_WIN32) - - void ArmWindowsInit(); - std::string GetArmWindowsVendor(); - -#elif defined(__APPLE__) - - void ArmAppleInit(); - -#endif - -#endif // defined(CPUIDINFO_ARCH_ARM) }; } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info_vendor.cc b/onnxruntime/core/common/cpuid_info_vendor.cc new file mode 100644 index 0000000000000..d4d940eedfe28 --- /dev/null +++ b/onnxruntime/core/common/cpuid_info_vendor.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/cpuid_info.h" + +#include +#include +#include + +#if defined(CPUINFO_SUPPORTED) +#include "cpuinfo.h" +#endif + +namespace { + +#if !defined(CPUINFO_SUPPORTED) + +// The `cpuinfo_vendor` enum is defined by the cpuinfo library. +// In case we don't build with cpuinfo, we define our own copy. +// The enum was copied from here: +// https://github.com/pytorch/cpuinfo/blob/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6/include/cpuinfo.h#L154-L307 + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. */ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendir if ARM64 processor + microarchitectures. */ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. */ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. + * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. + * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. + */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +#endif // !defined(CPUINFO_SUPPORTED) + +} // namespace + +namespace onnxruntime { + +namespace { + +struct CpuVendorInfo { + cpuinfo_vendor vendor; + std::string_view name; + uint32_t id; +}; + +constexpr auto kUnknownCpuVendorInfo = CpuVendorInfo{cpuinfo_vendor_unknown, "unknown", 0x0000}; + +constexpr std::array kCpuVendorInfos{ + CpuVendorInfo{cpuinfo_vendor_amd, "AMD", 0x1022}, + CpuVendorInfo{cpuinfo_vendor_intel, "Intel", 0x8086}, + CpuVendorInfo{cpuinfo_vendor_qualcomm, "Qualcomm", uint32_t{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}}, + CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, + CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, + CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, + + // TODO add more as needed +}; + +const CpuVendorInfo* FindCpuVendorInfo(cpuinfo_vendor vendor) { + const auto vendor_mapping_it = std::find_if(kCpuVendorInfos.begin(), kCpuVendorInfos.end(), + [vendor](const CpuVendorInfo& entry) { + return entry.vendor == vendor; + }); + + if (vendor_mapping_it != kCpuVendorInfos.end()) { + return &*vendor_mapping_it; + } + + return nullptr; +} + +} // namespace + +void CPUIDInfo::VendorInfoInit() { + const cpuinfo_vendor vendor = [&]() { + cpuinfo_vendor result = cpuinfo_vendor_unknown; +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + const auto* processor = cpuinfo_get_processor(0); + if (processor && processor->core) { + result = processor->core->vendor; + } + } +#endif // defined(CPUINFO_SUPPORTED) + return result; + }(); + + const auto* vendor_info = FindCpuVendorInfo(vendor); + if (vendor_info == nullptr) { + LogEarlyWarning(MakeString("Unknown CPU vendor. cpuinfo_vendor value: ", static_cast(vendor))); + vendor_info = &kUnknownCpuVendorInfo; + } + + vendor_ = vendor_info->name; + vendor_id_ = vendor_info->id; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/common/helper.cc b/onnxruntime/core/common/helper.cc index 6a52db73df106..07cd1672b27c1 100644 --- a/onnxruntime/core/common/helper.cc +++ b/onnxruntime/core/common/helper.cc @@ -18,7 +18,7 @@ namespace onnxruntime { #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); @@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) { return ret; } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h index 6cfb327cce08a..4ca326d76a37d 100644 --- a/onnxruntime/core/common/path_string.h +++ b/onnxruntime/core/common/path_string.h @@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::wstring!"); +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} inline PathString ToPathString(const std::string& s) { return ToWideString(s); } @@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::string!"); +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + inline PathChar ToLowerPathChar(PathChar c) { return std::tolower(c); } diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index c2e26f629330f..d8d943d6e9a41 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -61,10 +61,11 @@ inline void TrimStringFromRight(std::string& s) { * @param s The string to trim. * @return The trimmed string. */ -inline std::string TrimString(std::string s) { - TrimStringFromRight(s); - TrimStringFromLeft(s); - return s; +inline std::string TrimString(std::string_view s) { + std::string s_trimmed{s}; + TrimStringFromRight(s_trimmed); + TrimStringFromLeft(s_trimmed); + return s_trimmed; } /** diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index e1b9d1294fb9e..91b5b811a3529 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -6,6 +6,7 @@ #include "core/common/safeint.h" #include "core/common/status.h" #include "core/framework/allocator.h" +#include "core/framework/error_code_helper.h" #include "core/mlas/inc/mlas.h" #include "core/framework/utils.h" #include "core/session/ort_apis.h" @@ -185,22 +186,32 @@ std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return #endif ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) { + API_IMPL_BEGIN + + if (name1 == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null."); + } + + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null."); + } + auto device_id = static_cast(id1); if (strcmp(name1, onnxruntime::CPU) == 0) { *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::CUDA, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_GPU) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::OpenVINO_GPU, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::HIP) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::HIP, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || @@ -212,38 +223,39 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA } else if (strcmp(name1, onnxruntime::DML) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::DML, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::OpenVINO_RT_NPU, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::QNN_HTP_SHARED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) { *out = new OrtMemoryInfo( - name1, type, + onnxruntime::CPU_ALIGNED_4K, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id, onnxruntime::kAlloc4KAlignment), mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported. Try CreateMemoryInfo_V2."); } + API_IMPL_END return nullptr; } @@ -251,6 +263,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType type, _Outptr_ OrtMemoryInfo** out) { + API_IMPL_BEGIN + + if (name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null."); + } + + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null."); + } + // map the public enum values to internal OrtDevice values OrtDevice::MemoryType mt = mem_type == OrtDeviceMemoryType_DEFAULT ? OrtDevice::MemType::DEFAULT : OrtDevice::MemType::HOST_ACCESSIBLE; @@ -275,6 +297,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, narrow(device_id), alignment}, mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU); + API_IMPL_END return nullptr; } @@ -283,7 +306,7 @@ ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { de #pragma warning(pop) #endif ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out) { - *out = ptr->name; + *out = ptr->name.c_str(); return nullptr; } diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index e0b50cd04173e..3a5af42d03cdd 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -13,7 +13,7 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator, int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, int64_t max_power_of_two_extend_bytes) - : IAllocator(OrtMemoryInfo(resource_allocator->Info().name, + : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(), OrtAllocatorType::OrtArenaAllocator, resource_allocator->Info().device, resource_allocator->Info().mem_type)), diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc new file mode 100644 index 0000000000000..abfd3cf89cecf --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/ep_context_options.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { +namespace epctx { +// class ModelGenOptions + +ModelGenOptions::ModelGenOptions() = default; + +ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { + enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + + std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (!output_model_path.empty()) { + output_model_location = std::filesystem::path(output_model_path); + } else { + output_model_location = std::monostate{}; + } + + std::string external_initializers_file_path = config_options.GetConfigOrDefault( + kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); + if (!external_initializers_file_path.empty()) { + ExternalInitializerFileInfo ext_info = {}; + ext_info.file_path = external_initializers_file_path; + ext_info.size_threshold = 0; + initializers_location = std::move(ext_info); + } + + embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; +} + +bool ModelGenOptions::HasOutputModelLocation() const { + return !std::holds_alternative(output_model_location); +} + +const std::filesystem::path* ModelGenOptions::TryGetOutputModelPath() const { + return std::get_if(&output_model_location); +} + +const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { + return std::get_if(&output_model_location); +} + +const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const { + return std::get_if(&output_model_location); +} + +bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { + return std::holds_alternative(initializers_location); +} + +const ExternalInitializerFileInfo* ModelGenOptions::TryGetExternalInitializerFileInfo() const { + return std::get_if(&initializers_location); +} + +const InitializerHandler* ModelGenOptions::TryGetInitializerHandler() const { + return std::get_if(&initializers_location); +} + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h new file mode 100644 index 0000000000000..6643516bfb4c3 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" + +namespace onnxruntime { +namespace epctx { +/// +/// Holds the buffer that will store the output model and the allocator used to allocate the memory. +/// +struct BufferHolder { + void** buffer_ptr = nullptr; + size_t* buffer_size_ptr = nullptr; + AllocatorPtr buffer_allocator = nullptr; +}; + +/// +/// Holds the opaque stream state and the write function that ORT calls to write out the output model. +/// +struct BufferWriteFuncHolder { + OrtWriteBufferFunc write_func = nullptr; + void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. +}; + +/// +/// Holds path and size threshold used to write out initializers to an external file. +/// +struct ExternalInitializerFileInfo { + std::filesystem::path file_path; + size_t size_threshold = 0; +}; + +/// +/// Holds function and state provided by user to handle initializer data (i.e., write to stream or embed in model). +/// +struct InitializerHandler { + OrtGetInitializerLocationFunc handle_initializer_func = nullptr; + void* state = nullptr; +}; + +/// +/// Stores EPContext model generation options. Used in SessionOptions. +/// +struct ModelGenOptions { + // Action to take if the output model does not have compiled (EPContext) nodes. + enum class ActionIfNoCompiledNodes { + // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior + // to maintain compatibility. The explicit compile API does *not* use this action. + kDontGenerateModel = 0, + + // Generate an output model even if it doesn't have compiled nodes. + // The explicit Compile API defaults to this value. + kGenerateModel, + + // Return an error if the model does not have compiled nodes. + // The explicit Compile API can be configured to this value. + kReturnError, + }; + + ModelGenOptions(); + + // Initializes from string key/value pairs in session config options. + explicit ModelGenOptions(const ConfigOptions& config_options); + + bool enable = false; + bool error_if_output_file_exists = true; + bool error_if_no_compiled_nodes = false; + bool embed_ep_context_in_model = false; + ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; + + std::variant // Function to write the output model to a user's stream. + output_model_location = std::monostate{}; + + std::variant // Custom function called for every initializer to determine location. + initializers_location = std::monostate{}; + + bool HasOutputModelLocation() const; + const std::filesystem::path* TryGetOutputModelPath() const; + const BufferHolder* TryGetOutputModelBuffer() const; + const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + + bool AreInitializersEmbeddedInOutputModel() const; + const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; + const InitializerHandler* TryGetInitializerHandler() const; +}; + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_utils.cc b/onnxruntime/core/framework/ep_context_utils.cc new file mode 100644 index 0000000000000..3f02c54538526 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/framework/ep_context_utils.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/model_saving_options.h" + +namespace onnxruntime { +namespace epctx { + +// Serialize an EPContext model into a onnx::ModelProto. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { + // Handle case where initializers are stored inline within the ONNX model. + if (ep_context_gen_options.AreInitializersEmbeddedInOutputModel()) { + // if no external ini file specified, set force_embed_external_ini to true to avoid intermediate file creation + // and force all initializers embed into the ONNX file. + ModelSavingOptions model_saving_options{/*size_threshold*/ SIZE_MAX}; + model_saving_options.force_embed_external_ini = true; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(std::filesystem::path{}, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where initializers (with size > threshold) are stored in an external file. + if (const epctx::ExternalInitializerFileInfo* ext_info = ep_context_gen_options.TryGetExternalInitializerFileInfo(); + ext_info != nullptr) { + ModelSavingOptions model_saving_options{ext_info->size_threshold}; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(ext_info->file_path, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where user specified a custom handler function that determines how each initializer is saved. + if (const epctx::InitializerHandler* custom_handler = ep_context_gen_options.TryGetInitializerHandler(); + custom_handler != nullptr) { + ORT_RETURN_IF_ERROR(ep_context_model.ToGraphProtoWithCustomInitializerHandling( + custom_handler->handle_initializer_func, + custom_handler->state, + model_proto)); + return Status::OK(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected location for initializers while generating ", + validated_model_path); +} + +// +// OutStreamBuf class: +// + +OutStreamBuf::OutStreamBuf(BufferWriteFuncHolder write_func_holder) + : write_func_holder_(write_func_holder), buffer_(65536) { + setp(buffer_.data(), buffer_.data() + buffer_.size()); +} + +OutStreamBuf::~OutStreamBuf() { + sync(); +} + +// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. +std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (sync() == -1) { + return traits_type::eof(); + } + + if (ch != traits_type::eof()) { + *pptr() = static_cast(ch); + pbump(1); + } + + return ch; +} + +// Flushes the entire buffer_ to the user's write function. +int OutStreamBuf::sync() { + if (!last_status_.IsOK()) { + return -1; + } + + std::ptrdiff_t num_bytes = pptr() - pbase(); + if (num_bytes == 0) { + return 0; + } + + // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. + if (num_bytes > std::numeric_limits::max()) { + num_bytes = std::numeric_limits::max(); + } + + char* ptr = pbase(); + + Status status = Status::OK(); + + ORT_TRY { + status = ToStatusAndRelease(write_func_holder_.write_func(write_func_holder_.stream_state, + ptr, num_bytes)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; + } + + pbump(-static_cast(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ + return 0; +} + +} // namespace epctx +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/ep_context_utils.h b/onnxruntime/core/framework/ep_context_utils.h new file mode 100644 index 0000000000000..b3c76565982ff --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "core/common/status.h" +#include "core/framework/ep_context_options.h" +#include "core/graph/model.h" + +namespace onnxruntime { +namespace epctx { + +/// +/// Serialize an EPContext model into a onnx::ModelProto based on the provided options. +/// +/// The EP Context model to serialize. +/// The path into which to save the model. May be empty if serialized into a +/// buffer or output stream. +/// The model generation options. +/// Output parameter set to the serialized onnx::ModelProto. +/// A status indicating success or an error. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto); + +// Class that wraps the user's OrtBufferWriteFunc function to enable use with +// C++'s std::ostream. +// Example: +// BufferWriteFuncHolder write_func_holder{write_func, stream_state}; +// std::unique_ptr out_stream_buf = std::make_unique(write_func_holder); +// std::ostream out_stream(out_stream_buf.get()); +class OutStreamBuf : public std::streambuf { + public: + explicit OutStreamBuf(BufferWriteFuncHolder write_func_holder); + ~OutStreamBuf(); + + const Status& GetStatus() const { + return last_status_; + } + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + BufferWriteFuncHolder write_func_holder_{}; + std::vector buffer_; + Status last_status_{}; +}; + +} // namespace epctx +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index efc12ef8dd0e8..43caf4766d5c0 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -5,10 +5,12 @@ #include #include +#include #include "core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/ep_context_utils.h" #include "core/framework/execution_providers.h" #include "core/framework/func_kernel.h" #include "core/framework/kernel_lookup.h" @@ -20,8 +22,9 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "core/graph/model_saving_options.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/util/protobuf_parsing_utils.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -765,6 +768,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already +// TODO: Move function to ep_context_utils.h/cc static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, const std::filesystem::path& model_path, std::filesystem::path& context_cache_path, @@ -793,9 +797,10 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ return Status::OK(); } +// TODO: Move function to ep_context_utils.h/cc static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -806,11 +811,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers if (all_ep_context_nodes.size() < 1) { auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes; - ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError, + ORT_RETURN_IF(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError, "Unable to compile any nodes. Check that the session EPs support compilation and can execute " "at least one subgraph in the model."); - if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { + if (action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. " "Either the session EPs do not support compilation or the model is already compiled."; // Note: this path is only taken if a model is compiled with the original compilation approach that uses @@ -820,7 +825,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } // Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added) - assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel); + assert(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel); LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. " "Either the session EPs do not support compilation or the model is already compiled."; } @@ -834,15 +839,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr && - ep_context_gen_options.output_model_buffer_size_ptr != nullptr && - ep_context_gen_options.output_model_buffer_allocator != nullptr; + const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); - std::filesystem::path context_cache_path; - if (!saving_to_buffer || !graph.ModelPath().empty()) { - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, + std::filesystem::path valid_output_model_path; + if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr + : std::filesystem::path(""); + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, graph.ModelPath(), - context_cache_path, + valid_output_model_path, ep_context_gen_options.error_if_output_file_exists)); } @@ -909,39 +916,89 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } - size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; - std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; - bool force_embed_external_ini = false; - if (external_ini_path.empty()) { - // if no external ini file specified, set force_embed_external_ini to true to avoid intermedia file creation - // and force all initializers embed into the Onnx file - ini_size_threshold = SIZE_MAX; - force_embed_external_ini = true; + ORT_RETURN_IF_ERROR(ep_graph.Resolve()); + + // Generate EP compatibility strings for OrtEp types and add to model metadata + // At this point, the graph has been populated with all the EPContext nodes + { + const GraphViewer graph_viewer(ep_graph); + for (const auto& ep : execution_providers) { + try { + // Generate the compatibility string for this EP + std::string compatibility_string = ep->GetCompiledModelCompatibilityInfo(graph_viewer); + if (!compatibility_string.empty()) { + // Create a unique key for this EP's compatibility info + // Use format: "ep_compatibility_info." + // All EPs in a session must have a unique Type() value, so this will be unique for the generated model + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep->Type(); + auto& model_metadata = ep_context_model.MetaData(); + auto [it, was_inserted] = + model_metadata.insert_or_assign(metadata_key, compatibility_string); + if (!was_inserted) { + LOGS(logger, WARNING) << "Overwriting existing EP compatibility info for key: " << metadata_key << " (EP: " << ep->Type() << ")"; + } + LOGS(logger, VERBOSE) << "Added EP compatibility info for " << ep->Type() << " with key: " << metadata_key; + } + } catch (const std::exception& ex) { + LOGS(logger, WARNING) << "Failed to generate compatibility string for EP " << ep->Type() << ": " << ex.what(); + } + } } - ModelSavingOptions model_saving_options{ini_size_threshold}; - model_saving_options.force_embed_external_ini = force_embed_external_ini; + ONNX_NAMESPACE::ModelProto model_proto; + ORT_RETURN_IF_ERROR(EpContextModelToProto(ep_context_model, valid_output_model_path, ep_context_gen_options, + /*out*/ model_proto)); - if (saving_to_buffer) { - ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); - // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. - // May be able to use allocator to directly allocate the ModelProto to avoid a copy. - ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, - context_cache_path, - model_saving_options); + if (output_buffer_holder != nullptr) { + // Write output model into a buffer ORT allocates for the user. size_t buffer_size = model_proto.ByteSizeLong(); ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), "Cannot serialize ONNX ModelProto larger than 2GB"); - AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator; + AllocatorPtr allocator = output_buffer_holder->buffer_allocator; IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_size); model_proto.SerializeToArray(buffer.get(), static_cast(buffer_size)); - *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size; - *ep_context_gen_options.output_model_buffer_ptr = buffer.release(); + *output_buffer_holder->buffer_size_ptr = buffer_size; + *output_buffer_holder->buffer_ptr = buffer.release(); + } else if (output_write_func_holder != nullptr) { + // Write output model to user's output stream. + size_t buffer_size = model_proto.ByteSizeLong(); + ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), + "Cannot serialize ONNX ModelProto larger than 2GB"); + + auto out_stream_buf = std::make_unique(*output_write_func_holder); + std::ostream out_stream(out_stream_buf.get()); + + model_proto.SerializeToOstream(&out_stream); + out_stream.flush(); + ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { - ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path, - external_ini_path, model_saving_options)); + // Write output model to a file. + int fd = 0; + Status status = Env::Default().FileOpenWr(valid_output_model_path, fd); + ORT_RETURN_IF_ERROR(status); + + ORT_TRY { + google::protobuf::io::FileOutputStream output(fd); + bool serialize_result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); + if (!serialize_result) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_PROTOBUF, + "Protobuf serialization failed when generating EPContext model ", + valid_output_model_path); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); + }); + } + if (!status.IsOK()) { + GSL_SUPPRESS(es .84) + ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); + return status; + } + ORT_RETURN_IF_ERROR(Env::Default().FileClose(fd)); } return Status::OK(); @@ -1192,7 +1249,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const ConfigOptions& config_options, const logging::Logger& logger, Mode mode, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. @@ -1239,12 +1296,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) { - // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), - context_cache_path, - ep_context_gen_options.error_if_output_file_exists)); + if (ep_context_gen_options.enable) { + if (const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + output_model_path_ptr != nullptr) { + // Check before EP compile graphs + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(), + context_cache_path, + ep_context_gen_options.error_if_output_file_exists)); + } } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 6e36d79701fd7..abe46cea58ab2 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -15,7 +15,10 @@ class ExecutionProviders; class KernelRegistryManager; class Model; struct ConfigOptions; -struct EpContextModelGenerationOptions; + +namespace epctx { +struct ModelGenOptions; +} class GraphPartitioner { public: @@ -50,7 +53,7 @@ class GraphPartitioner { const ConfigOptions& config_options, const logging::Logger& logger, Mode mode = Mode::kNormal, - const EpContextModelGenerationOptions& ep_context_gen_options = {}, + const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; bool IsLoadCancellationFlagSet() const { diff --git a/onnxruntime/core/framework/plugin_data_transfer.cc b/onnxruntime/core/framework/plugin_data_transfer.cc index f753f00206c5d..d6b1680176815 100644 --- a/onnxruntime/core/framework/plugin_data_transfer.cc +++ b/onnxruntime/core/framework/plugin_data_transfer.cc @@ -41,7 +41,7 @@ Status DataTransfer::CopyTensors(const std::vector& src_dst_pairs) c for (size_t i = 0; i < src_dst_pairs.size(); ++i) { src_values.push_back(&values[i * 2]); dst_values.push_back(&values[i * 2 + 1]); - streams.push_back(nullptr); // static_cast(src_dst_pairs[i].src_stream)); + streams.push_back(reinterpret_cast(src_dst_pairs[i].src_stream)); } auto* status = impl_.CopyTensors(&impl_, src_values.data(), dst_values.data(), streams.data(), diff --git a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc index 231eb47603838..63f928d52d788 100644 --- a/onnxruntime/core/framework/session_options.cc +++ b/onnxruntime/core/framework/session_options.cc @@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) { - enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - output_external_initializers_file_path = config_options.GetConfigOrDefault( - kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); - output_external_initializer_size_threshold = 0; - embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; -} - -EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const { +epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const { if (this->has_explicit_ep_context_gen_options) { return this->ep_context_gen_options; } - return EpContextModelGenerationOptions(this->config_options); + return epctx::ModelGenOptions(this->config_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b75eeb217e7f0..b328fc916f885 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -13,6 +13,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "core/framework/config_options.h" +#include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" #include "core/optimizer/graph_transformer_level.h" @@ -70,53 +71,6 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; -/// -/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes). -/// There are two ways to compile a model: -/// 1. By specifying the correct session option configurations and creating an inference session. -/// The compiled model is generated as a side-effect of session creation. -/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h). -/// -/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain -/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to -/// provide its own defaults (see core/session/model_compilation_options.h/cc). -/// -struct EpContextModelGenerationOptions { - // Action to take if the output model does not have compiled (EPContext) nodes. - enum class ActionIfNoCompiledNodes { - // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior - // to maintain compatibility. The explicit compile API does *not* use this action. - kDontGenerateModel = 0, - - // Generate an output model even if it doesn't have compiled nodes. - // The explicit Compile API defaults to this value. - kGenerateModel, - - // Return an error if the model does not have compiled nodes. - // The explicit Compile API can be configured to this value. - kReturnError, - }; - - EpContextModelGenerationOptions() = default; - - // Initializes from string key/value pairs in session config options. - // This initializes this struct from options set via the older compiling approach #1 above. - explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); - - bool enable = false; - bool error_if_output_file_exists = true; - ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; - bool embed_ep_context_in_model = false; - - std::string output_model_file_path; - void** output_model_buffer_ptr = nullptr; - size_t* output_model_buffer_size_ptr = nullptr; - AllocatorPtr output_model_buffer_allocator = nullptr; - - std::string output_external_initializers_file_path; - size_t output_external_initializer_size_threshold = 0; -}; - struct EpSelectionPolicy { // flag to detect that a policy was set by the user. // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered @@ -270,8 +224,8 @@ struct SessionOptions { // The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new // struct type. bool has_explicit_ep_context_gen_options = false; - EpContextModelGenerationOptions ep_context_gen_options = {}; - EpContextModelGenerationOptions GetEpContextGenerationOptions() const; + epctx::ModelGenOptions ep_context_gen_options = {}; + epctx::ModelGenOptions GetEpContextGenerationOptions() const; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 17e337838b091..4f8c5607afce9 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -375,7 +375,7 @@ common::Status SaveInitializedTensors( // TODO: if the tensor need be copied, does it have enough room? ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc)); - // ??? Should we ignore this session option if the EP is explictly providing the read only allocator? + // ??? Should we ignore this session option if the EP is explicitly providing the read only allocator? // bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator; const bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault( @@ -402,9 +402,11 @@ common::Status SaveInitializedTensors( std::move(tensor), ort_value)); } } else { + // if in memory we were expecting to find it above. + ORT_ENFORCE(!utils::HasExternalDataInMemory(tensor_proto)); + // We need to deserialize the tensor proto into an OrtValue // using the preallocated buffer or allocator. - Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (memory_buffer.has_value()) ? &*memory_buffer : nullptr, alloc, default_cpu_alloc, ort_value, data_transfer_mgr, diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index 971851db62437..dfdb3ba962609 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -18,6 +18,13 @@ using ::google::protobuf::RepeatedPtrField; using ::ONNX_NAMESPACE::StringStringEntryProto; namespace onnxruntime { +ExternalDataInfo::ExternalDataInfo() = default; + +#if !defined(ORT_MINIMAL_BUILD) +ExternalDataInfo::ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length) + : rel_path_(rel_path), offset_(offset), length_(length) {} +#endif + Status ExternalDataInfo::Create(const RepeatedPtrField& input, std::unique_ptr& external_data_info_result) { auto external_data_info = std::make_unique(); @@ -107,7 +114,7 @@ void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& e std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, bool align, - int64_t align_threshold, int64_t allocation_granularity, + int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto) { size_t key_count = 0; for (const auto& key : blob_keys) { @@ -120,7 +127,7 @@ std::ostream& ExternalDataInfo::WritePrepackedToFileAndAddToProto( const auto size_in_bytes = prepacked_weights->buffer_sizes_[i]; if (align && static_cast(size_in_bytes) > align_threshold) { // return early on error - if (!AlignAndPad(os, allocation_granularity, external_offset)) { + if (!AlignAndPad(os, on_disk_alignment, external_offset)) { return os; } } diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 2de1e01f381ec..aa9bb32922bd7 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -25,6 +25,12 @@ class ExternalDataInfo { using OFFSET_TYPE = off_t; #endif + ExternalDataInfo(); + +#if !defined(ORT_MINIMAL_BUILD) + ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length); +#endif + const PathString& GetRelPath() const { return rel_path_; } OFFSET_TYPE GetOffset() const { return offset_; } @@ -41,15 +47,13 @@ class ExternalDataInfo { size_t tensor_bytes_size, ::ONNX_NAMESPACE::TensorProto& proto); - // Pads the output with zeros according to the specified allocation_granularity + // Pads the output with zeros according to the specified alignment_factor // It updates external_offset for alignment. // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| - static std::ostream& AlignAndPad(std::ostream& stream, int64_t allocation_granularity, int64_t& external_offset) { - // Align to the larger of the page size or the allocation granularity - int64_t alignment_factor = std::max(static_cast(4096), allocation_granularity); + static std::ostream& AlignAndPad(std::ostream& stream, int64_t alignment_factor, int64_t& external_offset) { // Align to the next page or alloc granularity boundary SafeInt safe_external_offset = external_offset; int64_t new_external_offset = ((safe_external_offset + alignment_factor - 1) / alignment_factor) * @@ -66,7 +70,7 @@ class ExternalDataInfo { static std::ostream& WritePrepackedToFileAndAddToProto( const PrepackedWeightsForGraph& prepacked_for_graph, const InlinedHashSet& blob_keys, - bool align, int64_t align_threshold, int64_t allocation_granularity, + bool align, int64_t align_threshold, int64_t on_disk_alignment, std::ostream& os, int64_t& external_offset, ::ONNX_NAMESPACE::TensorProto& proto); diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6383d29d7a2bc..c5d7d4cc4e68c 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -10,6 +10,7 @@ #include "core/framework/tensor_external_data_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" #define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \ external_type* ToExternal() { return static_cast(this); } \ @@ -30,8 +31,10 @@ enum class OrtGraphIrApi { kEpApi, }; -// Alias OrtExternalInitializerInfo to the internal type. -struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {}; +// Alias OrtExternalInitializerInfo to the internal onnxruntime::ExternalDataInfo type. +struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo { + using onnxruntime::ExternalDataInfo::ExternalDataInfo; // inherit constructors +}; /// /// Public type that represents an ONNX value info. @@ -291,6 +294,11 @@ struct OrtGraph { /// The graph's name. virtual const std::string& GetName() const = 0; + /// + /// Returns the model's metadata. + /// + /// The model metadata. + virtual std::unique_ptr GetModelMetadata() const = 0; /// /// Returns the model's path, which is empty if unknown. /// diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5511275239e45..b48fe8c1e1839 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1387,33 +1387,49 @@ constexpr const char* MoE_ver1_doc = R"DOC( Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). + + The SwiGLU (Swish-Gated Linear Unit) activation function is like: + g = xW + b + l = xV + c + G = clamp(g, max=limit) + L = clamp(l, min=-limit, max=limit) + swiglu = G * sigmoid(alpha * G) * (L + beta) + where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters. + When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs. + When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size. + When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row. )DOC"; -ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, - OpSchema() - .SetDoc(MoE_ver1_doc) - .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) - .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) - .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) - .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) - .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) - .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") - .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) - .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) - .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) - .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +ONNX_MS_OPERATOR_SET_SCHEMA( + MoE, 1, + OpSchema() + .SetDoc(MoE_ver1_doc) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.", AttributeProto::INT, static_cast(0)) + .Attr("swiglu_limit", "The limit used to clamp in SwiGLU. No clamp when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f) + .Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f) + .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) + .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu", "T") + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) + .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") + .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, inter_size, hidden_size)", "T", OpSchema::Optional) + .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA( QMoE, 1, OpSchema() .SetDoc("Quantized MoE") .Attr("activation_type", - "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", @@ -1429,6 +1445,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Number of bits used in quantized weights. Default is 4 bits", AttributeProto::INT, static_cast(4)) + .Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.", AttributeProto::INT, static_cast(0)) + .Attr("swiglu_limit", "The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f) + .Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " @@ -1437,19 +1457,21 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", - "3D input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "3D input tensor with shape (num_experts, inter_size, hidden_size), " + "or (num_experts, inter_size, hidden_size / 2) for 4 bits. " + "For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), " + "or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2") .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(5, "fc2_experts_weights", - "3D input tensor with shape (num_experts, inter_size, hidden_size) " - "or (num_experts, inter_size, hidden_size / 2)", + "3D input tensor with shape (num_experts, hidden_size, inter_size) " + "or (num_experts, hidden_size, inter_size / 2) for 4 bits", "T1") - .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2") .Input(7, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", @@ -1457,14 +1479,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(8, "fc3_experts_weights", - "3D optional input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "3D optional input tensor with shape (num_experts, inter_size, hidden_size) " + "or (num_experts, inter_size, hidden_size / 2)", "T1", OpSchema::Optional) .Input(9, "fc3_scales", "2D optional input tensor with shape (num_experts, inter_size)", - "T", + "T2", OpSchema::Optional) .Input(10, "fc3_experts_bias", @@ -1476,8 +1498,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " "(batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain scales type to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, @@ -3664,10 +3687,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h uint32_t components = (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) ? (8 / bits) : 1; for (int i = 0; i < r; ++i) { - if (!data_shape.dim(i).has_dim_value() || - !scales_shape.dim(i).has_dim_value() || - (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || - (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) { + if (data_shape.dim(i).has_dim_value() && + scales_shape.dim(i).has_dim_value() && + ((i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || + (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()))) { fail_shape_inference("data shape and scales shape do not match"); } } diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..92eb31f0ad385 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -20,6 +20,7 @@ #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/graph_viewer.h" #include "core/graph/graph.h" +#include "core/graph/model.h" namespace onnxruntime { @@ -87,6 +88,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, } } +#if !defined(ORT_MINIMAL_BUILD) +static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) { + const ONNX_NAMESPACE::OpSchema* op_schema = node.Op(); + if (op_schema == nullptr) { + return false; + } + + auto attr_schema_iter = op_schema->attributes().find(attr_name); + if (attr_schema_iter == op_schema->attributes().end()) { + return false; // Not an attribute for this operator type. + } + + const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second; + + return !attr_schema.required; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // // EpNode // @@ -268,13 +287,20 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } -const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { +const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const { auto iter = attributes_map_.find(name); - if (iter == attributes_map_.end()) { - return nullptr; - } else { + if (iter != attributes_map_.end()) { + is_unset_optional_attr = false; return reinterpret_cast(iter->second.get()); } + +#if !defined(ORT_MINIMAL_BUILD) + is_unset_optional_attr = IsOptionalAttribute(node_, name); +#else + // This is not properly set in a minimal build because it does not have access to the operator schema. + is_unset_optional_attr = false; +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } const std::string& EpNode::GetEpName() const { @@ -301,6 +327,9 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { + if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional + continue; + } if (input_value_infos[i]->GetName() == value_info_name) { indices.push_back(is_implicit ? -1 : static_cast(i)); found = true; @@ -718,6 +747,25 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } +std::unique_ptr EpGraph::GetModelMetadata() const { +#if !defined(ORT_MINIMAL_BUILD) + const auto& model = graph_viewer_.GetGraph().GetModel(); + auto model_metadata = std::make_unique(); + + model_metadata->producer_name = model.ProducerName(); + model_metadata->producer_version = model.ProducerVersion(); + model_metadata->description = model.DocString(); + model_metadata->graph_description = model.GraphDocString(); + model_metadata->domain = model.Domain(); + model_metadata->version = model.ModelVersion(); + model_metadata->custom_metadata_map = model.MetaData(); + model_metadata->graph_name = model.MainGraph().Name(); + return model_metadata; +#else + return nullptr; +#endif +} + const ORTCHAR_T* EpGraph::GetModelPath() const { return graph_viewer_.ModelPath().c_str(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..e003f02a79a2d 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -209,8 +209,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; - // Helper that gets the node's attributes by name. - const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the + // output parameter `is_unset_optional_attr` to true if this is an unset optional attribute. + const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const; // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; @@ -294,6 +295,9 @@ struct EpGraph : public OrtGraph { // Returns the graph's name. const std::string& GetName() const override; + // Returns the graph's metadata + std::unique_ptr GetModelMetadata() const override; + // Returns the model path. const ORTCHAR_T* GetModelPath() const override; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index de6776b0e0df1..9a97711996343 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -17,6 +17,7 @@ #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_external_data_info.h" @@ -666,12 +667,16 @@ void Node::ToProto(NodeProto& proto, bool update_subgraphs) const { // Set attributes. proto.clear_attribute(); - for (const auto& attribute : attributes_) { + for (const auto& [name, attribute] : attributes_) { const gsl::not_null attr{proto.add_attribute()}; - *attr = attribute.second; // copy - if (update_subgraphs && attr->has_g()) { + *attr = attribute; // copy + if (update_subgraphs && utils::HasGraph(*attr)) { + auto find_hit = attr_to_subgraph_map_.find(name); + // Force ToGraphProto() const to be called so + // that any in-memory TensorProto initializers go back to being inlined + const Graph& subgraph = *find_hit->second; attr->clear_g(); - *attr->mutable_g() = attr_to_subgraph_map_.find(attribute.first)->second->ToGraphProto(); + *attr->mutable_g() = subgraph.ToGraphProto(); } } @@ -1226,26 +1231,6 @@ Graph::Graph(const Model& owning_model, ArgNameToTypeMap name_to_type_map; const auto& model_path = ModelPath(); - // If the tensor proto data is large enough, externalize it and replace with a tensor_proto - // with external data reference pointing to an OrtValue, otherwise do nothing. - auto put_data_maybe_in_memory = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) { - size_t size_in_bytes = 0; - ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes)); - if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { - OrtValue ort_value; - ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, - CPUAllocator::DefaultInstance(), ort_value)); - constexpr const bool use_tensor_buffer_true = true; - auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), - use_tensor_buffer_true); - assert(ort_value.IsAllocated()); - auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value)); - ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(), - " in the initializer list."); - tensor_proto = std::move(tensor_proto_to_add); - } - }; - // Process 'Constant' nodes // Put the 'TensorProto' stored in the 'Constant' nodes attribute into the graphs initializer list for (auto& node : graph_proto_->node()) { @@ -1265,8 +1250,6 @@ Graph::Graph(const Model& owning_model, } } - put_data_maybe_in_memory(*tensor); - // Ensure initializers are also graph inputs. if (ir_version_ < 4) { TypeProto t{utils::TypeProtoFromTensorProto(*tensor)}; @@ -1343,22 +1326,7 @@ Graph::Graph(const Model& owning_model, } // Copy initial tensors to a map. - for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) { - auto& tensor = *graph_proto_->mutable_initializer(i); - // If data is on disk, it will be loaded either by optimizers - // or during session state finalization. - // If data is already in memory, do nothing. - if (!utils::HasExternalData(tensor)) { - const bool is_sparse = sparse_tensor_names_.count(tensor.name()); - if (is_sparse) { - sparse_tensor_names_.erase(tensor.name()); - } - put_data_maybe_in_memory(tensor); - if (is_sparse) { - sparse_tensor_names_.emplace(tensor.name()); - } - } - + for (auto& tensor : graph_proto_->initializer()) { auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor); if (!p.second) { LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name() @@ -3418,9 +3386,39 @@ Status Graph::Resolve(const ResolveOptions& options) { return Status::OK(); }; - ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func)); + return ForThisAndAllSubgraphs(all_subgraphs, finalize_func); +} + +Status Graph::ConvertInitializersIntoOrtValues() { + std::vector all_subgraphs; + FindAllSubgraphs(all_subgraphs); - return Status::OK(); + auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status { + // if we have any initializers that are not in memory, put them there. + const auto& model_path = graph.ModelPath(); + auto& graph_proto = *graph.graph_proto_; + for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) { + auto& tensor_proto = *graph_proto.mutable_initializer(i); + if (utils::HasExternalData(tensor_proto)) { + continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize + } + + size_t size_in_bytes = 0; + ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes)); + if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { + OrtValue ort_value; + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + CPUAllocator::DefaultInstance(), ort_value)); + constexpr const bool use_tensor_buffer_true = true; + auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), + use_tensor_buffer_true); + ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value)); + } + } + return Status::OK(); + }; + + return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func); } void Graph::SetName(const std::string& name) { @@ -3659,6 +3657,15 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(), "Replacement tensor's data type does not match."); + bool is_sparse = false; + { + auto sparse_tensor_it = sparse_tensor_names_.find(initializer_name); + if (sparse_tensor_it != sparse_tensor_names_.end()) { + sparse_tensor_names_.erase(sparse_tensor_it); + is_sparse = true; + } + } + auto& mutable_initializers = *(graph_proto_->mutable_initializer()); // use cheaper pointer comparison to find old entry auto existing_entry = std::find(mutable_initializers.pointer_begin(), mutable_initializers.pointer_end(), @@ -3675,6 +3682,9 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi } **existing_entry = std::move(new_initializer); + if (is_sparse) { + sparse_tensor_names_.insert((**existing_entry).name()); + } return Status::OK(); } @@ -3720,7 +3730,7 @@ Status Graph::InjectExternalInitializedTensors(const InlinedHashMap>& external_initializer_files) { for (const auto& [tensor_name, tensor_proto] : name_to_initial_tensor_) { - if (tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { + if (utils::HasExternalDataInFile(*tensor_proto)) { std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info)); @@ -3729,6 +3739,7 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const size_t external_data_length = external_data_info->GetLength(); SafeInt tensor_byte_size; ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size)); + ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, "TensorProto: ", tensor_name, " external data size mismatch. Computed size: ", *&tensor_byte_size, ", external_data.length: ", external_data_length); @@ -3736,18 +3747,19 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( SafeInt end_of_read(file_offset); end_of_read += tensor_byte_size; - auto external_file_pos = external_initializer_files.find(external_file); - ORT_RETURN_IF(external_file_pos == external_initializer_files.end(), + auto user_provided_entry = external_initializer_files.find(external_file); + ORT_RETURN_IF(user_provided_entry == external_initializer_files.end(), "External file: ", ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(external_file), " not found from the table user provided."); - auto external_file_length = external_file_pos->second.second; - ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow(external_file_length), + auto user_provided_length = user_provided_entry->second.second; + + ORT_RETURN_IF(file_offset < 0 || end_of_read > narrow(user_provided_length), "External initializer: ", tensor_name, " offset: ", file_offset, " size to read: ", external_data_length, - " given file_length: ", external_file_length, " are out of bounds or can not be read in full."); - char* external_file_buffer = static_cast(external_file_pos->second.first); - char* tensor_buffer = external_file_buffer + file_offset; + " given file_length: ", user_provided_length, " are out of bounds or can not be read in full."); + char* user_provided_file_buffer = static_cast(user_provided_entry->second.first); + char* user_provided_tensor_buffer = user_provided_file_buffer + file_offset; const auto& old_initializer = *(tensor_proto); auto& mutable_initializers = *(graph_proto_->mutable_initializer()); @@ -3762,19 +3774,11 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); - auto tensor = Tensor(type, tensor_shape, tensor_buffer, + auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - constexpr const bool use_tensor_buffer_true = true; - auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_true); - // Implied that external data is in memory - const bool has_external_data_in_memory = utils::HasExternalData(new_tensor_proto); - - OrtValue ort_value; - if (has_external_data_in_memory) { - Tensor::InitOrtValue(std::move(tensor), ort_value); - } - ortvalue_initializers_.insert_or_assign(tensor_name, std::move(ort_value)); + constexpr const bool use_tensor_buffer_false = false; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); **existing_entry = std::move(new_tensor_proto); } } @@ -4314,18 +4318,63 @@ Status InlineOrCopyInitializer(const Graph& src_graph, const ONNX_NAMESPACE::Ten } return Status::OK(); } - } // namespace -Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto, - bool process_main) const { - for (const auto& node : Nodes()) { +Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span iterators, + ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + auto& mutable_initializers = *output_graph_proto.mutable_initializer(); + +#if !defined(DISABLE_SPARSE_TENSORS) + output_graph_proto.clear_sparse_initializer(); + + const auto& model_path = ModelPath(); + const bool has_sparse_initializers = !sparse_tensor_names_.empty(); + const auto sparse_end = sparse_tensor_names_.end(); + + for (const auto& iter : iterators) { + const auto& [name, tensor_proto] = *iter; + const auto& initializer = *tensor_proto; + if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(name)) { + ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, + *mutable_initializers.Add())); + } else { + auto& sparse_initializer = *output_graph_proto.add_sparse_initializer(); + if (utils::HasExternalDataInMemory(initializer)) { + ONNX_NAMESPACE::TensorProto tensor_proto_inlined; + ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, + tensor_proto_inlined)); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto_inlined, model_path, sparse_initializer)); + } else { + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer)); + } + } + } +#else + for (const auto& iter : iterators) { + const auto& [name, tensor_proto] = *iter; + ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, *tensor_proto, *mutable_initializers.Add())); + } +#endif + return Status::OK(); +} + +// Struct that holds the return value from GetSubgraphsWithMatchingGraphProtos; +// Pairs a onnxruntime::Graph with its matching onnx::GraphProto. +struct SubgraphWithMutableProto { + ONNX_NAMESPACE::GraphProto* subgraph_proto = nullptr; + const Graph* subgraph = nullptr; +}; + +static Status GetSubgraphsWithMatchingGraphProtos(const GraphNodes& nodes, + ONNX_NAMESPACE::GraphProto& graph_proto, + std::vector& subgraphs) { + for (const auto& node : nodes) { if (node.ContainsSubgraph()) { // Let's find this node in the output_graph_proto // The node name is optional, so we may need to check by the output value name // given that they can only assigned once. - auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), + auto hit = std::find_if(graph_proto.mutable_node()->begin(), + graph_proto.mutable_node()->end(), [&node](const ONNX_NAMESPACE::NodeProto& proto) { const auto& node_name = node.Name(); if (!node_name.empty()) @@ -4333,7 +4382,7 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr return (proto.output_size() > 0 && proto.output(0) == node.OutputDefs()[0]->Name()); }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), + ORT_RETURN_IF_NOT(hit != graph_proto.mutable_node()->end(), "Node ", node.Name(), " not found in output_graph_proto"); auto& result_node = *hit; for (const auto& e : node.GetAttributeNameToSubgraphMap()) { @@ -4348,104 +4397,65 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", node.Name(), " while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph, process_main)); + SubgraphWithMutableProto subgraph_result{sub_hit->mutable_g(), subgraph}; + subgraphs.emplace_back(subgraph_result); } } } - // When graph_proto is copied from graph_proto, initializers already present in the main graph - if (parent_graph_ != nullptr || process_main) { -#if !defined(DISABLE_SPARSE_TENSORS) - auto* mutable_initializers = output_graph_proto.mutable_initializer(); - const auto& model_path = ModelPath(); - const bool has_sparse_initializers = !sparse_tensor_names_.empty(); - const auto sparse_end = sparse_tensor_names_.end(); + return Status::OK(); +} - // We want to make sure that sparse initializers do not appear - // as dense duplicates within the initializers list. - std::optional> initializer_to_remove; - if (has_sparse_initializers) { - // We need to remove the dense initializers that are sparse tensors - initializer_to_remove.emplace(); - } +Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // Process subgraphs recursively (bottom-up). + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); - for (auto first = mutable_initializers->begin(), end = mutable_initializers->end(); first != end; ++first) { - auto& initializer = *first; - if (utils::HasExternalDataInMemory(initializer)) { - // If the initializer has external data in memory, we need to inline it. - ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer)); - } - if (has_sparse_initializers && sparse_end != sparse_tensor_names_.find(initializer.name())) { - auto& sparse_initializer = *output_graph_proto.add_sparse_initializer(); - ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer)); - initializer_to_remove->insert(initializer.name()); - } + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(*subgraph_proto)); } + } - // erase/remove dense initializers that are sparse tensors so no duplicates are present - if (initializer_to_remove && !initializer_to_remove->empty()) { - mutable_initializers->erase(std::remove_if( - mutable_initializers->begin(), mutable_initializers->end(), - [&initializer_to_remove](const ONNX_NAMESPACE::TensorProto& initializer) { - return initializer_to_remove->count(initializer.name()) > 0; - }), - mutable_initializers->end()); - } -#else - for (auto& initializer : *output_graph_proto.mutable_initializer()) { - if (utils::HasExternalDataInMemory(initializer)) { - // If the initializer has external data in memory, we need to inline it. - ORT_RETURN_IF_ERROR(InlineOrCopyInitializer(*this, initializer, initializer)); - } + // Filter in iterators for weights that are present in the name_to_initial_tensor_ map + // and preserve the order. This is needed for tests. + InlinedVector initializers_to_process; + initializers_to_process.reserve(name_to_initial_tensor_.size()); + for (const auto& tensor_proto : output_graph_proto.initializer()) { + auto hit = name_to_initial_tensor_.find(tensor_proto.name()); + if (hit != name_to_initial_tensor_.end()) { + initializers_to_process.push_back(hit); } -#endif } - return Status::OK(); + + output_graph_proto.clear_initializer(); + return RegenerateInitializersAndReplaceInMemory(initializers_to_process, output_graph_proto); } ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { GraphProto result; if (!GraphProtoSyncNeeded()) { result = *graph_proto_; - ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ true)); + ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result)); } else { + // Recursion is handled via Node::ToProto() const -> Graph::ToGraphProto() const (this method) + // so below we handle this graph only. ToGraphProtoInternal(result); - ORT_THROW_IF_ERROR(ProcessSubgraphsInMemoryData(result, /*process_main*/ false)); - - // Add initializers to parent graph by copy converting them from graph_proto_ - // ToGraphProtoInternal() does not copy initializers for the main graph - auto* mutable_initializers = result.mutable_initializer(); - -#if !defined(DISABLE_SPARSE_TENSORS) - const auto& model_path = ModelPath(); - const bool has_sparse_initializers = !sparse_tensor_names_.empty(); - const auto sparse_end = sparse_tensor_names_.end(); - - for (const auto& initializer : graph_proto_->initializer()) { - if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) { - ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, - *mutable_initializers->Add())); - } else { - auto& sparse_initializer = *result.add_sparse_initializer(); - if (utils::HasExternalDataInMemory(initializer)) { - ONNX_NAMESPACE::TensorProto tensor_proto; - ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, - tensor_proto)); - ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(tensor_proto, model_path, sparse_initializer)); - } else { - ORT_THROW_IF_ERROR(utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer)); - } + InlinedVector initializers_to_process; + initializers_to_process.reserve(name_to_initial_tensor_.size()); + for (const auto& tensor_proto : graph_proto_->initializer()) { + auto hit = name_to_initial_tensor_.find(tensor_proto.name()); + if (hit != name_to_initial_tensor_.end()) { + initializers_to_process.push_back(hit); } } -#else - for (const auto& initializer : graph_proto_->initializer()) { - ORT_THROW_IF_ERROR(InlineOrCopyInitializer(*this, initializer, *mutable_initializers->Add())); - } -#endif - } + ORT_THROW_IF_ERROR(RegenerateInitializersAndReplaceInMemory(initializers_to_process, + result)); + } return result; } @@ -4460,44 +4470,19 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( // Process initializers in a subgraph, check their size and // write to an external file. This function also saves pre-packed // blobs for the initializer being saved to disk, if the initializer has any pre-packs. - // This function is invoked by ToGraphProtoWithExternalInitiallizers() and processes subgraphs + // This function is invoked by ToGraphProtoWithExternalInitializers() and processes subgraphs // bottom up. - for (const auto& node : Nodes()) { - if (node.ContainsSubgraph()) { - // Let's find this node in the output_graph_proto - // The node name is optional, so we may need to check by the output value name - // given that they can only assigned once. - auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), - [&node](const ONNX_NAMESPACE::NodeProto& proto) { - const auto& node_name = node.Name(); - if (!node_name.empty()) - return proto.name() == node_name; - return (proto.output_size() > 0 && - proto.output(0) == node.OutputDefs()[0]->Name()); - }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), - " not found in output_graph_proto"); - auto& result_node = *hit; - for (const auto& e : node.GetAttributeNameToSubgraphMap()) { - const auto& name = e.first; - const auto& subgraph = e.second; - // Lets find this subgraph in the result_node - auto sub_hit = std::find_if(result_node.mutable_attribute()->begin(), - result_node.mutable_attribute()->end(), - [&name](const ONNX_NAMESPACE::AttributeProto& proto) { - return proto.name() == name; - }); - ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), - "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", - node.Name(), " while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( - model_path, external_file_path, - model_external_file_path, model_saving_options, - result_subgraph, - external_stream, external_offset)); - } + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( + model_path, external_file_path, + model_external_file_path, model_saving_options, + *subgraph_proto, external_stream, external_offset)); } } @@ -4552,14 +4537,14 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( continue; } - // update external_offset for alignment + // update external_offset for alignment (if enabled) // need to do padding before write actual tensor data as we do offset alignment at the begin of - // large tensors (offset need to be page aligned and allocation granularity aligned) like below: + // large tensors (offset need to be page aligned) like below: // \242\2557\256\023.\031&0000000000000000\332)k+\253\246\342\246(&\006!\347\232\374\236\325\026\032+\36XXXX // |<---smaller tensor---->|<---padding--->|<------------------large tensor----------------------------->| if (model_saving_options.align_offset && static_cast(tensor_bytes_size) > model_saving_options.align_threshold) { - ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.allocation_granularity, + ORT_RETURN_IF_NOT(ExternalDataInfo::AlignAndPad(external_stream, model_saving_options.on_disk_alignment, external_offset), "Failed writing external data to: ", model_external_file_path); } @@ -4592,7 +4577,7 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( auto& os = ExternalDataInfo::WritePrepackedToFileAndAddToProto( *prepacked_weights_for_graph_, blob_keys_to_external_data, model_saving_options.align_offset, model_saving_options.align_threshold, - model_saving_options.allocation_granularity, + model_saving_options.on_disk_alignment, external_stream, external_offset, *output_proto); ORT_RETURN_IF_NOT(os.good(), "Failed to write pre-packed blobs to external file"); } @@ -4659,6 +4644,113 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers( return result; } +Status Graph::ToGraphProtoWithCustomInitializerHandlingImpl( + OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // This loop processes subgraphs bottom up. + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, + state, *subgraph_proto)); + } + } + + // Create a sorted std::vector of initializers so that we always process them in a deterministic order. + InlinedVector initializers; + initializers.reserve(GetAllInitializedTensors().size()); + + for (const auto& [name, initializer_tp] : GetAllInitializedTensors()) { + initializers.push_back(initializer_tp); + } + + std::sort(initializers.begin(), initializers.end(), + [](const ONNX_NAMESPACE::TensorProto* a, const ONNX_NAMESPACE::TensorProto* b) { + return a->name() < b->name(); + }); + + // Call user's handler function for each initializer. We store the initializer externally + // or within the model depending on the result returned by the handler function. + for (gsl::not_null initializer : initializers) { +#if !defined(DISABLE_SPARSE_TENSORS) + if (IsSparseInitializer(initializer->name())) { + // Sparse tensors are added to the ONNX file directly. + auto& sparse_initializer = *output_graph_proto.add_sparse_initializer(); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(*initializer, ModelPath(), sparse_initializer)); + } else { +#endif + TensorProto* output_proto = output_graph_proto.add_initializer(); + + output_proto->set_name(initializer->name()); + output_proto->set_data_type(initializer->data_type()); + for (int i = 0; i != initializer->dims_size(); ++i) { + output_proto->add_dims(initializer->dims(i)); + } + output_proto->set_doc_string(initializer->doc_string()); + + OrtValue ort_value; + std::unique_ptr original_ext_data_info = nullptr; + + if (utils::HasExternalDataInFile(*initializer)) { + // Initializer has data in an external file. Load it into OrtValue (potentially via memory mapping). + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), original_ext_data_info)); + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), *initializer, ort_value)); + } else { + // Initializer is either stored inline within the TensorProto or it is "external data in memory". + // Get an OrtValue (if already loaded by Graph) or copy into an OrtValue otherwise. + bool graph_has_ort_value = GetOrtValueInitializer(initializer->name(), ort_value, /*check_outer_scope*/ false); + if (!graph_has_ort_value) { + assert(!utils::HasExternalData(*initializer)); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), *initializer, + CPUAllocator::DefaultInstance(), ort_value)); + } + } + + // Call the user's initializer handler function. If the user wants to store the initializer externally, + // the handler function will use OrtApi::CreateExternalInitializerInfo() to create a new + // OrtExternalInitializerInfo instance that indicates the location of the data. + OrtExternalInitializerInfo* new_external_info = nullptr; + Status status = ToStatusAndRelease(handle_initializer_func(state, initializer->name().c_str(), + &ort_value, + static_cast(original_ext_data_info.get()), + &new_external_info)); + + ORT_RETURN_IF(new_external_info != nullptr && + new_external_info == static_cast(original_ext_data_info.get()), + "User's OrtGetInitializerLocationFunc must not return the external_info parameter.", + "Return a copy instead."); + std::unique_ptr new_external_info_holder(new_external_info); // Take ownership + ORT_RETURN_IF_ERROR(status); + + if (new_external_info != nullptr) { + ExternalDataInfo::SetExternalLocationToProto(new_external_info->GetRelPath(), new_external_info->GetOffset(), + new_external_info->GetLength(), *output_proto); + } else { + const Tensor& tensor = ort_value.Get(); + output_proto->clear_data_location(); + utils::SetRawDataInTensorProto(*output_proto, tensor.DataRaw(), tensor.SizeInBytes()); + } +#if !defined(DISABLE_SPARSE_TENSORS) + } +#endif + } + + return Status::OK(); +} + +Status Graph::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const { + ToGraphProtoInternal(graph_proto); + ORT_RETURN_IF_ERROR(ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, state, graph_proto)); + return Status::OK(); +} + void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const { graph_proto_->clear_node(); graph_proto_->clear_input(); @@ -5241,23 +5333,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod tensor_proto.set_name(std::string(new_name.value())); } - // In the constant node, we won't have symbolic dims. - const auto tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); - auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - const size_t size_in_bytes = Tensor::CalculateTensorStorageSize(ml_data, tensor_shape); - - if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { - OrtValue ort_value; - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), tensor_proto, - CPUAllocator::DefaultInstance(), ort_value)); - - constexpr const bool use_tensor_buffer_true = true; - auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), - use_tensor_buffer_true); - ORT_RETURN_IF_ERROR(AddInitializedOrtValue(tensor_proto_to_add, ort_value)); - } else { - AddInitializedTensor(tensor_proto); - } + AddInitializedTensor(tensor_proto); if (GetNodeArg(tensor_proto.name()) == nullptr) { TypeProto t{utils::TypeProtoFromTensorProto(tensor_proto)}; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 436af7115eb1a..0ffbced51ee35 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -361,6 +361,10 @@ const ModelMetaData& Model::MetaData() const noexcept { return model_metadata_; } +ModelMetaData& Model::MetaData() noexcept { + return model_metadata_; +} + Graph& Model::MainGraph() noexcept { return *graph_; } @@ -377,6 +381,15 @@ ModelProto Model::ToProto() const { // out dense duplicates of sparse initializers and leave the original // proto intact. ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProto(); return result; @@ -386,6 +399,15 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const { ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, @@ -393,6 +415,25 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa return result; } +common::Status Model::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const { + model_proto = model_proto_; + + // Sync current model_metadata_ back to protobuf metadata_props + model_proto.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{model_proto.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + + const auto& graph = *graph_; + ORT_RETURN_IF_ERROR(graph.ToGraphProtoWithCustomInitializerHandling(handle_initializer_func, + state, *model_proto.mutable_graph())); + return Status::OK(); +} + Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) { if (!model_istream.good()) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object."); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 70f82bcfb160b..c86aac44806bd 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -189,6 +189,8 @@ class Model { const ModelMetaData& MetaData() const noexcept; + ModelMetaData& MetaData() noexcept; + // Gets the path from which the model was loaded, if any. const std::filesystem::path& ModelPath() const noexcept { return model_path_; } @@ -208,6 +210,18 @@ class Model { const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Model to a onnx::ModelProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. + /// Output parameter set to the serialized onnx::ModelProto. + /// A status indicating success or an error. + common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const; + static common::Status Save(Model& model, const PathString& file_path); static common::Status Save(Model& model, int fd); diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5d84e48182bfe..2c0f6d6174303 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -13,6 +13,7 @@ #include "core/framework/ort_value.h" #include "core/graph/abi_graph_types.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" namespace onnxruntime { @@ -179,6 +180,9 @@ struct ModelEditorGraph : public OrtGraph { const std::string& GetName() const override { return name; } + std::unique_ptr GetModelMetadata() const override { + return std::make_unique(model_metadata); + } const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); } int64_t GetOnnxIRVersion() const override { @@ -236,6 +240,7 @@ struct ModelEditorGraph : public OrtGraph { std::vector> nodes; std::string name = "ModelEditorGraph"; std::filesystem::path model_path; + ModelMetadata model_metadata; }; } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 3627989609737..fc3c0b6016ced 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -48,6 +48,17 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + + /// + /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively + /// an individual block's scale, quantized values, and zero point for the input `B`. + /// When converting the activation input (A) to uint8, we first convert the values to int8 and then + /// add a "bias" of +128 to convert the range of values from [-128, +127] to [0, +255]. + /// This input helps to "de-bias" the output of the +128 bias added to the activation input. + /// This input is to be used only when A is quantized to uint8. + /// + const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; + const T* Bias = nullptr; ///< optional address of Bias, vector size N T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index caa445b71e2a5..c579ff1542eb9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -153,28 +153,23 @@ ArmKleidiAI::MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - if(TransA == CblasTrans) - { - return false; + if (M == 0 || N == 0) { + return true; } - if (TransA == CblasNoTrans && K == 0) { - if (Data->beta != 1.0f) { + + if (Data->alpha == 0.0f || K == 0) { + if (Data->beta == 0.0f) { + for (size_t i = 0; i < M; ++i) { + std::fill_n(Data->C + i * Data->ldc, N, 0.0f); + } + } else if (Data->beta != 1.0f) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { Data->C[i * Data->ldc + j] *= Data->beta; } } } - } - if (Data->beta == 0.0f){ - std::fill_n(Data->C, M * Data->ldc, 0.0f); - } - //Fallback in the case of unsupported cases - if (M == 0 || N == 0 || K == 0 || - TransA != CblasNoTrans || - (TransB != CblasNoTrans && !Data[0].BIsPacked)) - { - return false; + return true; } if (TransA == CblasNoTrans) { @@ -185,11 +180,9 @@ ArmKleidiAI::MlasGemmBatch( auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); - if (M < m_step || N < n_step) { - if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ - //Fallback to MLAS - return false; - } + if (M < m_step && N < n_step && !Data->BIsPacked) { + // Fallback to MLAS + return false; } std::vector KaiPackedData; @@ -316,7 +309,7 @@ ArmKleidiAI::MlasGemmBatch( float* dst_tile = reinterpret_cast(CTile); // quick copy of data in cases where we are not scaling or accumulating anything - // with bounds checking on tile sizing to ensure the data fits in the memory block + // with bounds checking on tile sizing to ensure the data fits in the memory block bool can_memcpy = ( Data[BIdx].alpha == 1.0f && Data[BIdx].beta == 0.0f && @@ -328,21 +321,37 @@ ArmKleidiAI::MlasGemmBatch( if (can_memcpy) { std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); - }else { - // apply alpha scaling and beta to output files - for (size_t i = 0; i < TileSizeM; ++i) { - for (size_t j = 0; j < TileSizeN; ++j) { - const size_t idx = i * TileSizeN + j; - const size_t dst_idx = i * Data[BIdx].ldc + j; - - float ab = temp_tile[idx]; - float c_orig = dst_tile[dst_idx]; + return; + } - dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; + float alpha = Data[BIdx].alpha; + float beta = Data[BIdx].beta; + size_t ldc = Data[BIdx].ldc; + + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t temp_idx = i * TileSizeN + j; + const size_t dst_idx = i * ldc + j; + + float ab = temp_tile[temp_idx]; + float c_orig = dst_tile[dst_idx]; + + if (alpha == 1.0f && beta == 0.0f) { + dst_tile[dst_idx] = ab; + } else if (alpha == 1.0f) { + dst_tile[dst_idx] = ab + beta * c_orig; + } else if (beta == 0.0f) { + dst_tile[dst_idx] = alpha * ab; + } else { + dst_tile[dst_idx] = alpha * ab + beta * c_orig; } } } + return; }); + return true; + } + else { + return false; } - return true; } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index a099bcf8438fe..2bbcfd51fe4ba 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1200,7 +1200,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH; const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ); extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; @@ -1297,6 +1298,8 @@ struct MLAS_PLATFORM { // TODO: move to cpuinfo bool Avx2Supported_ = false; bool Avx512Supported_ = false; + bool ArmNeonIsQuantActivationsUnsigned = false; + // Mlas overrides initialisation MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3256dadb856d3..c4b8d5e78a491 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -582,7 +582,6 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } - this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; @@ -593,16 +592,22 @@ Return Value: } #endif -#if defined(__linux__) // // Check if the processor supports ASIMD I8MM instructions. // - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + + const bool HasI8MMInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM(); + if (HasI8MMInstructions) { +#if defined(__linux__) + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; - } #endif + } + + this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true; + this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 19d11a60b7376..d806f4b08bfff 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -132,7 +132,7 @@ QNBitGemmPerGemmWorkspaceSize( } if (BlkBitWidth == 4 || BlkBitWidth == 8) { - return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType, BlkBitWidth); } return 0; @@ -266,7 +266,7 @@ MlasQNBitGemmPackQuantBData( if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, false); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -307,7 +307,8 @@ MlasQNBitGemmPackQuantBData( } else if (BlkBitWidth == 8) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, + BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( N, K, @@ -742,6 +743,8 @@ SQ8BitGemm_CompInt8( : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + const float* BlkUnsignedQuantAZeroPointCorrection = + DataParams->BlkUnsignedQuantAZeroPointCorrection ? DataParams->BlkUnsignedQuantAZeroPointCorrection + RangeStartN * k_blks : nullptr; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; @@ -759,6 +762,8 @@ SQ8BitGemm_CompInt8( if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; + const float* blk_unsigned_quant_A_zp_correction = BlkUnsignedQuantAZeroPointCorrection ? + BlkUnsignedQuantAZeroPointCorrection + n * k_blks : nullptr; GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, @@ -774,7 +779,8 @@ SQ8BitGemm_CompInt8( bias, ldc, ABlockSum, - b_blk_sum + b_blk_sum, + blk_unsigned_quant_A_zp_correction ); if (DataParams->PostProcessor != nullptr) { @@ -798,7 +804,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ); template <> @@ -812,7 +819,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); @@ -826,7 +834,7 @@ InitializeWorkspace_CompInt8( const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); // TODO: try parallel on BatchN * M threads because BatchN is usually 1. - if (UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; @@ -834,38 +842,63 @@ InitializeWorkspace_CompInt8( std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr); }); - } else if (QuantizeARow) { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); } else { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - - void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); - std::byte* QuantARowPtr = quant_a_data.QuantData; - float* QuantARowScalePtr = quant_a_data.QuantScale; - float* QuantARowBlkSum = quant_a_data.BlockSum; - for (size_t m = 0; m < M; ++m) { - QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); - ARowPtr += data.lda; - QuantARowPtr += BlockCountK * BlkLen; - QuantARowScalePtr += BlockCountK; - QuantARowBlkSum += BlockCountK; + // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms + if (BlkBitWidth == 4) { + if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); } - }); - } + } else if (BlkBitWidth == 8) { + if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } + } + } } template <> @@ -879,7 +912,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(M); MLAS_UNREFERENCED_PARAMETER(N); @@ -890,6 +924,7 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(Workspace); MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); } template @@ -902,7 +937,8 @@ using InitializeWorkspaceFn = std::function* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth )>; template @@ -1015,7 +1051,7 @@ MlasQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( - M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool, BlkBitWidth ); } @@ -1029,17 +1065,19 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -1107,7 +1145,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; @@ -1115,10 +1153,11 @@ MlasQNBitGemmBatch( PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 4c133103bee04..7ec80c6d67f15 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -48,7 +48,7 @@ MlasAlignAddress(void* addr, const size_t alignment) template struct PackedQuantBDataStruct { - PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen, bool QuantAUnsigned) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); @@ -56,16 +56,40 @@ struct PackedQuantBDataStruct { #if defined(MLAS_TARGET_AMD64_IX86) // avx512 requires alignment on a 64-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#elif defined (MLAS_TARGET_ARM64) + // Only for 8-bit Gemms is the `PackedQuantBData` is to be 32-byte aligned and + // there is enough memory allocated to support this alignment. + // See QNBitGemmPackQuantBDataSize(). + // When bit width is 4, there is no alignment guarantee. + // TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to + // simpify this logic and make code here cleaner ? + if constexpr (BlkBitWidth == 8) { + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + } + else { + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + } #else PackedQuantBData = (std::byte*)PackedQuantBWorkspace; #endif + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + + if (QuantAUnsigned) { + BlkUnsignedQuantAZeroPointCorrection = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + BlkUnsignedQuantAZeroPointCorrection = (T*)MlasAlignAddress(BlkUnsignedQuantAZeroPointCorrection, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)BlkUnsignedQuantAZeroPointCorrection + BlkSumSize); + } else { + BlkUnsignedQuantAZeroPointCorrection = nullptr; + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + } } + std::byte* PackedQuantBData; T* PackedQuantBScale; T* QuantBBlkSum; + T* BlkUnsignedQuantAZeroPointCorrection; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -178,7 +202,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { size_t K, size_t BlkLen, bool HasZeroPoint, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth ); QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; @@ -373,20 +398,22 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. * A and B are block quantized and B is column major. * - * @param BlkLen Number of values in a block. - * @param QuantA Supplies the quantized A matrix. - Binary data containing block quantized int8 data and scale values. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - * @param ldc Number of elements between adjacent rows of C.. - * @param ABlockSum Supplies the blksum of A. - * @param QuantBBlkSum Supplies the blksum of B. + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + * @param BlkUnsignedQuantAZeroPointCorrection Supplies the optional input to de-bias the Gemm output to account for the +128 bias + addition when the activation input A is quantized to uint8. */ typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( size_t BlkLen, @@ -403,7 +430,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection ); SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 0d06eb04e5245..ba2b68e4fbb07 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -21,6 +21,7 @@ Module Name: #include #include +#include #include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" @@ -42,8 +43,9 @@ namespace // Quantized B data packing function implementation. // +template size_t -Q4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, @@ -51,26 +53,49 @@ Q4BitGemmPackQuantBDataSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { + if constexpr (BlkBitWidth == 4) { #ifndef USE_KLEIDIAI - MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType #endif #ifdef USE_KLEIDIAI - if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { - const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); - } else + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + } else #endif - { - constexpr size_t BlkBitWidth = 4; - + { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + } else { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (ComputeType == SQNBIT_CompInt8) { + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // align on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + if constexpr (QuantAUnsigned) { + // 2 block sum + return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize; + } else { + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } + } else { + return PackedQuantBDataSize; + } } } @@ -199,6 +224,167 @@ SQ4BitGemmPackQuantBDataAndBlkSum( } } +void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + float* BlkUnsignedQuantAZeroPointCorrectionBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t K, + const size_t BlkLen) +{ + constexpr size_t SubBlkLen = 4; + const size_t BlkCountK = MlasDivRoundup(K, BlkLen); + const size_t SubBlkPerBlk = BlkLen / SubBlkLen; + const size_t StrideN = BlkCountK * BlkLen; + const size_t Iterations = N * BlkCountK; + + // 4 rows x 8 columns pack together, then 4 rows x 4 columns, then per column. + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / BlkCountK; + const size_t c8 = c & (~7), c8_res = c & 7; + const size_t c4 = c & (~3), c4_res = c & 3; + const size_t r_blk = tid % BlkCountK; + size_t r_subblk = r_blk * SubBlkPerBlk; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_blk * BlkLen; + const uint8_t* src8 = reinterpret_cast(src); + + for (size_t i = 0; i < SubBlkPerBlk; ++i, src += SubBlkLen, ++r_subblk) { + if (c8 + 8 <= N) { // full 8 cols + std::byte* dest = + PackedQuantBDataBegin + c8 * StrideN + r_subblk * SubBlkLen * 8 + c8_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else if (c4 + 4 <= N) { // full 4 cols + std::byte* dest = + PackedQuantBDataBegin + c4 * StrideN + r_subblk * SubBlkLen * 4 + c4_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } + } + + if (BlkUnsignedQuantAZeroPointCorrectionBegin) { + const int accu = std::accumulate(src8, src8 + std::min(BlkLen, K - r_blk * BlkLen), 0); + + // for sgemmc + const size_t dst_offset = ((c / 16) * BlkCountK + r_blk) * 16 + c % 16; + BlkUnsignedQuantAZeroPointCorrectionBegin[dst_offset] = static_cast(accu); + } + } + ); +} + +void +Q8ComputePackBlkSum( + const size_t BlkLen, + const size_t N, + const size_t K, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + float* BlockSum2Begin, + MLAS_THREADPOOL* ThreadPool) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n8 = n & (~7), n8_res = n & 7; + const size_t n4 = n & (~3), n4_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlockSum2Begin) { + BlockSum2Begin[dst_offset] = QuantBScale * (static_cast(zp) * std::min(BlkLen, K - k_blk * BlkLen) - BlockSum2Begin[dst_offset]); + } + + // re-arrange scale to the same order as packed data + if (n4 + 4 > N) { // remainder cols + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (n8 + 8 > N) { // full 4 cols + *(QuantBScaleBegin + n4 * BlockCountK + k_blk * 4 + n4_res) = QuantBScale; + } else { // full 8 cols + *(QuantBScaleBegin + n8 * BlockCountK + k_blk * 8 + n8_res) = QuantBScale; + } + }); +} + +/** + * 4 rows x 8 cols pack together, along all K. Then 4 rows x 4 cols, along all K. + * When rows < 4, keep original layout. + * + * dotprod: vdotq_laneq_u32. + * convert quant a from int8 to uint8. zp is 128. + * + * i8mm: vusdotq_laneq_s32. + */ +void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType */, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // Pack the quantized weights + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool, N, K, BlkLen); + } else { + // We ignore the scales and zero points if they are provided when pre-packing the weights as there is + // some "state" associated with 'BlkUnsignedQuantAZeroPointCorrection'. + + // We accumulate the block sum into 'BlkUnsignedQuantAZeroPointCorrection' while packing the weights + // in the previous step. If we were to use 'scales' while pre-packing the weights and if there were no + // zero points, then we would enter 'Q8ComputePackBlkSum' twice - once while pre-packing the weights + // and once while pre-packing the scales which would lead to erroneous 'BlkUnsignedQuantAZeroPointCorrection' + // computation as the buffer is "used" in-place for the "block sum" temporary values (obtained while pre-packing + // the weights) and the actual 'BlkUnsignedQuantAZeroPointCorrection' which will use the scales. + // Hence, to ensure that the piece of logic to calculate 'BlkUnsignedQuantAZeroPointCorrection' is only invoked + // once, we do it while we are pre-packing the scales and ignore any provided 'scales' and 'zero points' while + // pre-packing the weights. + // The flip side is that the user has to ensure that this function is called once each for 'weights', + // 'scales', and 'zero points'. This is a reasonable expectation and hence we go with that design. + + // Pack the block scales + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + // Pack the blksum (and BlkUnsignedQuantAZeroPointCorrection if applicable) + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, N, K, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool); + } + } +} + // // Workspace size calculation function implementation. // @@ -210,19 +396,21 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool HasZeroPoint, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); #ifndef USE_KLEIDIAI MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); #endif switch (ComputeType) { case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 #ifdef USE_KLEIDIAI - if (UseKleidiAI(K, BlkLen, HasZeroPoint)) { + if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = M == 1? GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); @@ -233,8 +421,10 @@ QNBitGemmPerGemmWorkspaceSize( } else #endif { + // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } } @@ -278,6 +468,77 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) #endif } +template +size_t +SQ8BitGemmKernel_BlkSum_CompInt8( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection +) +{ + MlasQ8Int8GemmKernelNeon( + BlkLen, + reinterpret_cast*>(QuantA), + QuantAScale, + reinterpret_cast(QuantBData), + QuantBScale, + C, + CountM, + CountN, + CountK, + Bias, + ldc + ); + + { + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + + if constexpr (QuantAUnsigned) { + { + assert(BlkUnsignedQuantAZeroPointCorrection != nullptr); + float* c_blk = C; + const float* b_blk_sum2 = BlkUnsignedQuantAZeroPointCorrection; + + size_t RowsRemaining = CountM; + const float* a_scale_row = QuantAScale; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_scale_row, b_blk_sum2, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 128.f); + + c_blk += ldc * RowsHandled; + a_scale_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + + return CountM; +} + } // namespace sqnbitgemm_neon // @@ -286,7 +547,8 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ) { // Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static @@ -295,9 +557,11 @@ GetMlasQNBitGemmDispatchNeon( static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<4, false>; + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, true>; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ8BitGemmPackQuantBDataAndBlkSum; d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; @@ -310,12 +574,21 @@ GetMlasQNBitGemmDispatchNeon( d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; d.UsePacked_CompInt8 = sqnbitgemm_neon::UsePacked_CompInt8; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + #ifdef USE_KLEIDIAI d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; #endif } + if (InitializeWithI8MMSupport) { + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, false>; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + } + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index a254ec9f92596..c8be42b01fbe2 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -123,6 +123,36 @@ QuantizeARow_CompInt8( std::byte* QuantA ); +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +using QuantAType = typename std::conditional::type; + +template +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const QuantAType* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +); + size_t SQ4BitGemmKernel_CompInt8( size_t BlkLen, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 384b04c807195..f160c9f541238 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -602,7 +602,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx2( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index c1bc00fbffa3e..122086d8ef05b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -264,7 +264,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index ea5eebd854655..e172308637af1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -316,7 +316,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index bb38f37fb0eb8..36c15cd5ac57f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -469,7 +469,8 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool /* HasZeroPoint */, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t /* BlkBitWidth */ ) { MLAS_UNREFERENCED_PARAMETER(N); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 8dbd339468930..b03b8121059f3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -187,6 +187,230 @@ QuantizeARow_CompInt8( } } +MLAS_FORCEINLINE +float32x4_t LoadFloat32x4(const float* src, size_t count) +{ + if (count == 4) { + return vld1q_f32(src); + } else if (count == 3) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + v = vld1q_lane_f32(src + 2, v, 2); + return v; + } else if (count == 2) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + return v; + } else { + assert(count == 1); + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + return v; + } +} + +template +using I16VecType = typename std::conditional::type; + +template +I16VecType MLAS_FORCEINLINE +PrepareZeroI16() +{ + if constexpr (IsQuantAUnsigned) { + return vdupq_n_u16(0); + } else { + return vdupq_n_s16(0); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +) +{ + // First use i8 to quantize A. range [-128, 127] + // If convert to u8, +128. Range [0, 255] + assert(BlkLen % 16 == 0); + assert(BlkLen <= 256); + MLAS_DECLSPEC_ALIGN(static const uint8_t MASK[16], 16) = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + const int16x8_t v128 = vdupq_n_s16(128); + QuantAType* blob = reinterpret_cast*>(QuantA); + float* scale_ptr = QuantAScale; + size_t k = 0; + for (; k + BlkLen <= CountK; k += BlkLen) { + float32x4_t absMax0 = vdupq_n_f32(0.0f); + float32x4_t absMax1 = vdupq_n_f32(0.0f); + float32x4_t absMax2 = vdupq_n_f32(0.0f); + float32x4_t absMax3 = vdupq_n_f32(0.0f); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4x4_t v0 = vld4q_f32(A + k + kk); + absMax0 = vmaxq_f32(absMax0, vabsq_f32(v0.val[0])); + absMax1 = vmaxq_f32(absMax1, vabsq_f32(v0.val[1])); + absMax2 = vmaxq_f32(absMax2, vabsq_f32(v0.val[2])); + absMax3 = vmaxq_f32(absMax3, vabsq_f32(v0.val[3])); + } + + const float32x4_t max01 = vmaxq_f32(absMax0, absMax1); + const float32x4_t max23 = vmaxq_f32(absMax2, absMax3); + const float32x4_t max0123 = vmaxq_f32(max01, max23); + const float maxScalar = vmaxvq_f32(max0123); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + scale_ptr++; + + const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16_0 = PrepareZeroI16(); + I16VecType sum_8_i16_1 = PrepareZeroI16(); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4_t vfp32_0 = LoadFloat32x4(A + k + kk, 4); + const float32x4_t vfp32_1 = LoadFloat32x4(A + k + kk + 4, 4); + const float32x4_t vfp32_2 = LoadFloat32x4(A + k + kk + 8, 4); + const float32x4_t vfp32_3 = LoadFloat32x4(A + k + kk + 12, 4); + + const float32x4_t v0 = vmulq_f32(vfp32_0, mul); + const float32x4_t v1 = vmulq_f32(vfp32_1, mul); + const float32x4_t v2 = vmulq_f32(vfp32_2, mul); + const float32x4_t v3 = vmulq_f32(vfp32_3, mul); + + const int32x4_t i0 = vcvtnq_s32_f32(v0); + const int32x4_t i1 = vcvtnq_s32_f32(v1); + const int32x4_t i2 = vcvtnq_s32_f32(v2); + const int32x4_t i3 = vcvtnq_s32_f32(v3); + + const int16x8_t v_8_i16_0 = vcombine_s16(vqmovn_s32(i0), vqmovn_s32(i1)); + const int16x8_t v_8_i16_1 = vcombine_s16(vqmovn_s32(i2), vqmovn_s32(i3)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16_0 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_0, v128)); + const uint16x8_t v_8_u16_1 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_1, v128)); + const uint8x16_t v_16_u8 = vcombine_u8(vqmovn_u16(v_8_u16_0), vqmovn_u16(v_8_u16_1)); + vst1q_u8(blob + k + kk, v_16_u8); + + // accumulate Sum(a_i) + const uint16x8_t i_8_u16_0 = vmovl_u8(vget_low_u8(v_16_u8)); + const uint16x8_t i_8_u16_1 = vmovl_high_u8(v_16_u8); + sum_8_i16_0 = vaddq_u16(sum_8_i16_0, i_8_u16_0); + sum_8_i16_1 = vaddq_u16(sum_8_i16_1, i_8_u16_1); + } else { + const int8x16_t v_16_i8 = vcombine_s8(vqmovn_s16(v_8_i16_0), vqmovn_s16(v_8_i16_1)); + vst1q_s8(blob + k + kk, v_16_i8); + + // accumulate Sum(a_i) + sum_8_i16_0 = vaddq_s16(sum_8_i16_0, v_8_i16_0); + sum_8_i16_1 = vaddq_s16(sum_8_i16_1, v_8_i16_1); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t sum_8_u16 = vaddq_u16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_u16(sum_8_u16)); + } else { + const int16x8_t sum_8_i16 = vaddq_s16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + AScaledBlkSum++; + } + + if (k < CountK) { + float32x4_t absMax = vdupq_n_f32(0.0f); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t v0 = LoadFloat32x4(A + kk, step); + absMax = vmaxq_f32(absMax, vabsq_f32(v0)); + } + + const float maxScalar = vmaxvq_f32(absMax); + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + + const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16 = PrepareZeroI16(); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t vfp32 = LoadFloat32x4(A + kk, step); + const float32x4_t v_f32 = vmulq_f32(vfp32, mul); + const int32x4_t v_i32 = vcvtnq_s32_f32(v_f32); + const int16x8_t v_8_i16 = vcombine_s16(vqmovn_s32(v_i32), vdup_n_s16(0)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16, v128)); + uint8x8_t v_8_u8 = vqmovn_u16(v_8_u16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_u8(v_8_u8), 0); + + // accumulate Sum(a_i) + v_8_u8 = vand_u8(v_8_u8, vld1_u8(MASK + 8 - step)); + const uint16x8_t i_8_u16 = vmovl_u8(v_8_u8); + sum_8_i16 = vaddq_u16(sum_8_i16, i_8_u16); + } else { + const int8x8_t v_8_i8 = vqmovn_s16(v_8_i16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_s8(v_8_i8), 0); + + // accumulate Sum(a_i) + sum_8_i16 = vaddq_s16(sum_8_i16, v_8_i16); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + qsum = static_cast(vaddvq_u16(sum_8_i16)); + } else { + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + + memset(blob + CountK, 0, BlkLen - (CountK % BlkLen)); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + namespace { @@ -1439,6 +1663,723 @@ SQ4BitGemmKernel_CompInt8( return CountM; } +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + uint32x4_t acc1_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_u32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % MRows2 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + uint32x4_t acc0 = vdupq_n_u32(0U); + uint32x4_t acc1 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + acc1 = vdotq_u32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_u32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + uint32x4_t acc0 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % MRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + #ifdef USE_KLEIDIAI void SQ4BitGemmKernel_Packed_CompInt8( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp new file mode 100644 index 0000000000000..db040dbb9a08c --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp @@ -0,0 +1,743 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8_i8mm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + input type T1 as float32 and + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8 + using i8mm instructions. + +--*/ + +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + int32x4_t acc1_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_s32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % NRows2 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + int32x4_t acc0 = vdupq_n_s32(0); + int32x4_t acc1 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + acc1 = vusdotq_s32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_s32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + int32x4_t acc0 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 616bc1257676f..3f9b58f71bd23 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow(element_count) * sizeof(MLFloat16)); } - return graph_utils::AddInitializerWithExternalData(graph, initializer); + return graph_utils::AddInitializer(graph, initializer); } static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, diff --git a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc index a98d0ea6f978b..86a7a4d6afbf8 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc @@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph, "total_count: ", total_count, " values.size(): ", values.size()); utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t)); - return &graph_utils::AddInitializerWithExternalData(graph, const_tensor); + return &graph_utils::AddInitializer(graph, const_tensor); } NodeArg* InsertNodesForValidIndices(Graph& graph, diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 3d838d8aacfbb..16e8955cb4486 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { ONNX_NAMESPACE::TensorShapeProto result_shape; result_shape.add_dim()->set_dim_value(clamped_slice_length); constant_arg_out->SetShape(result_shape); - graph_utils::AddInitializerWithExternalData(graph, shape_constant); + graph_utils::AddInitializer(graph, shape_constant); } return is_concrete_shape; // convert to constant if this is true @@ -317,11 +317,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; const Tensor& out_tensor = ort_value.Get(); - constexpr const bool use_tensor_buffer_true = true; + constexpr const bool use_tensor_buffer_false = false; ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto( out_tensor, constant_arg_out->Name(), - use_tensor_buffer_true); + use_tensor_buffer_false); ONNX_NAMESPACE::TensorShapeProto result_shape; for (auto& dim : out_tensor.Shape().GetDims()) { @@ -329,12 +329,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } constant_arg_out->SetShape(result_shape); - // The data is too small and has been inlined. - if (!utils::HasExternalData(out_tensorproto)) { - ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue())); - } else { - ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value)); - } + graph.AddInitializedTensor(out_tensorproto); } } } diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index c349adfccce53..6478fa7d29d5b 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -79,7 +79,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); } else { @@ -94,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); + NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); graph_utils::AddNodeInput(node, 2, new_add_B_node_arg); } diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index 8bf5420baddde..a14639631d7a1 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff new_conv_W_tensor_proto.set_name(new_W_name); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg); - auto& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); + auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); if (conv_inputs.size() == 3) { graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index dc50a150537f7..e91a00729e9db 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -90,7 +90,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef new_conv_W_tensor_proto.set_name(new_W_name); // Replace initializers of conv node - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 1, new_conv_W_node_arg); if (is_3d) { @@ -100,7 +100,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef auto new_B_name = graph.GenerateNodeArgName("ConvMulFusion_Mul_B_" + mul_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 2, new_conv_B_node_arg); } diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 7f214e656e0ab..96f75f07e32e1 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -53,7 +53,7 @@ static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); new_input_tensor.set_name(new_name); new_input_tensor.add_dims(1); - NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor); + NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); graph_utils::ReplaceNodeInput(node, index, new_input); } diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index ad25f95ac1186..f8fd807084d38 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -474,7 +474,7 @@ static NodeArg* ExtractEmbedding(Graph& graph, utils::SetRawDataInTensorProto(initializer, data, gsl::narrow(element_count) * sizeof(MLFloat16)); } - NodeArg& node_arg = graph_utils::AddInitializerWithExternalData(graph, initializer); + NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer); modified = true; return &node_arg; } diff --git a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc index 388ab14dd51fe..e604c688ee033 100644 --- a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc +++ b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc @@ -137,12 +137,8 @@ static void FuseInitializerWithNode(Graph& graph, graph.RemoveEdge(node.Index(), next_node.Index(), 0, static_cast(next_node_arg_index)); // Add the new converted Tensor in next node as initializer potentially with external data - ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get(), new_arg_name, true); - if (!utils::HasExternalData(dst_tensor)) { - new_data = OrtValue(); // Data is inline - } - - auto& new_arg = graph_utils::AddInitializerWithExternalData(graph, dst_tensor, std::move(new_data)); + ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get(), new_arg_name, false); + auto& new_arg = graph_utils::AddInitializer(graph, dst_tensor); graph_utils::ReplaceNodeInput(next_node, static_cast(next_node_arg_index), new_arg); } diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 3cd06350df95d..bd730683a4c91 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -256,7 +256,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra axes_initializer_proto.add_dims(static_cast(1)); axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); axes_initializer_proto.add_int64_data(axis); - NodeArg* axes_arg = &graph_utils::AddInitializerWithExternalData(graph, axes_initializer_proto); + NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", {split_output_arg, axes_arg}, {original_output_arg}); @@ -272,7 +272,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); split_initializer_proto.add_dims(static_cast(split_values.size())); split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); - NodeArg* split_initializer_arg = &graph_utils::AddInitializerWithExternalData(graph, split_initializer_proto); + NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); const auto split_node_name = graph.GenerateNodeName(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion"); Node& split_node = graph.AddNode(split_node_name, "Split", "Split for Fused Gather nodes", {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); @@ -359,7 +359,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le unsqueeze_axes_initializer_proto.add_dims(static_cast(1)); unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); unsqueeze_axes_initializer_proto.add_int64_data(static_cast(0)); - NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, unsqueeze_axes_initializer_proto); + NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto); for (size_t i = 0; i < range_input_defs.size(); ++i) { Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", @@ -386,7 +386,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le } else { slice_axes_initializer_proto.add_int32_data(static_cast(axis)); } - NodeArg* slice_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, slice_axes_initializer_proto); + NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto); Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes", {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1], slice_axes_arg, unsqueeze_outputs[2]}, diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index 761fe1854274e..fed72db71332a 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -194,7 +194,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, shape_initializer_proto.add_dims(static_cast(shape.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape.size() * sizeof(int64_t)); - NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); ONNX_NAMESPACE::TypeProto new_arg_type; const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type()); diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 725cb3fc33f04..367fb42d7928d 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -212,14 +212,14 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& matmul_b.ToProto(new_gemm_b_tensor); const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); new_gemm_b_tensor.set_name(new_gemm_b_name); - NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_b_tensor); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); // create bias tensorProto for new Gemm node from initializer. ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor; bias.ToProto(new_gemm_bias_tensor); const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); new_gemm_bias_tensor.set_name(new_gemm_bias_name); - NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_bias_tensor); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); Node& gemm_node = graph.AddNode( graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 335209dbfadaf..f094a48e10c33 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -437,7 +437,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]); } - nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); + nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); filters_map->emplace(input_defs[1], nchwc_conv_W_arg); } @@ -464,7 +464,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_B_tensor_proto.add_dims(nchwc_output_channels); - nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); + nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); aligned_biases_.emplace(input_defs[2], nchwc_conv_B_arg); } } @@ -580,7 +580,7 @@ Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg, } shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims); - shape_arg = &graph_utils::AddInitializerWithExternalData(graph_, shape_tensor_proto); + shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto); } Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"), @@ -892,7 +892,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { nchwc_conv_W_tensor_proto.add_dims(1); nchwc_conv_W_tensor_proto.add_dims(1); - auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); + auto* nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); std::copy_n(bn_B.data(), channels, padded_buffer.data()); @@ -903,7 +903,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { gsl::narrow(nchwc_channels) * sizeof(float)); nchwc_conv_B_tensor_proto.add_dims(nchwc_channels); - auto* nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); + auto* nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); // Create the replacement node. std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_bn_nchwc"); diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 42cd31b5bd7b4..42d27de632b91 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -130,22 +130,22 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8"); weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims()); utils::SetRawDataInTensorProto(weights_proto_u8, w_temp.data(), static_cast(w_temp.size())); - input_defs[w_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); + input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; QDQ::Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[w_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); + input_defs[w_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); ONNX_NAMESPACE::TensorProto r_proto_u8; r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8"); r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims()); utils::SetRawDataInTensorProto(r_proto_u8, r_temp.data(), static_cast(r_temp.size())); - input_defs[r_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_proto_u8); + input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8); ONNX_NAMESPACE::TensorProto r_zp_proto_u8; QDQ::Int8TensorProto2Uint8(r_zp_tensor_proto, r_zp_proto_u8, graph, true); - input_defs[r_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_zp_proto_u8); + input_defs[r_zp_idx] = &graph_utils::AddInitializer(graph, r_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc index 98c818b0c761b..828165e99d840 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc @@ -61,7 +61,7 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { zp_tensor_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); zp_tensor_proto_u8.set_name(graph.GenerateNodeArgName("qdq_s8_to_u8_zp_conversion")); utils::SetRawDataInTensorProto(zp_tensor_proto_u8, &q_zp_value, sizeof(uint8_t)); - NodeArg* zp_u8_arg = &graph_utils::AddInitializerWithExternalData(graph, zp_tensor_proto_u8); + NodeArg* zp_u8_arg = &graph_utils::AddInitializer(graph, zp_tensor_proto_u8); auto q_output_node_arg_name = graph.GenerateNodeArgName("qdq_s8_to_u8_quant"); NodeArg* q_output_arg = &graph.GetOrCreateNodeArg(q_output_node_arg_name, nullptr); diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc index 616144c0ccde0..f094f3c199f2a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc @@ -43,12 +43,12 @@ bool ConvertS8WeightToU8(Graph& graph, Node& op_node, // The weights fits into S7, overflow is not a problem, no need to convert to U8 return false; } - input_defs[weights_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); + input_defs[weights_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); // Convert weight zero point to uint8 ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[weight_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); + input_defs[weight_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index dce69e2913582..34d7ba3c79775 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -439,23 +439,23 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); + auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, false); + auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, false); std::optional zp_T_tp; if (zp_dst) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, false)); } auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, weight_T_tp, std::move(weight_dst))); + input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); - input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, scale_T_tp, std::move(scale_dst))); + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); replacement_node.MutableInputArgsCount().push_back(1); if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, zp_T_tp.value(), std::move(*zp_dst))); + input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); replacement_node.MutableInputArgsCount().push_back(1); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index aa6f9c5409de7..4efaec325292a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -13,6 +13,39 @@ namespace onnxruntime { +/** + * Checks whether or not the output path from a given node leads to a QuantizeLinear op, optionally, with no + * branching ReLU or Clip op in between. See also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc. + * + * @param node The starting node to check the output path from. + * @param graph The graph containing the nodes. + * + * @return true if the path exist, false otherwise. + */ +static bool IsNoBranchPathToQuantizeLinear(const Node& node, const Graph& graph) { + const Node* current = &node; + while (true) { + // Conv / ConvTranspose / Gemm produces single output + if (current->OutputDefs().size() != 1) { + return false; + } + const std::vector& consumers = graph.GetConsumerNodes(current->OutputDefs()[0]->Name()); + // Branching or no consumer: not eligible + if (consumers.size() != 1) { + return false; + } + const Node* consumer = consumers[0]; + if (consumer->OpType() == QDQ::QOpName) { + return true; + } + // Allow ReLU or Clip, see also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc. + if (consumer->OpType() != "Relu" && consumer->OpType() != "Clip") { + return false; + } + current = consumer; + } +} + Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { const GraphViewer graph_viewer{graph}; @@ -43,11 +76,8 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph continue; } - // Require that the node's output is consumed by a single QuantizeLinear node. - // Otherwise, if only the inputs are quantized, but not the output, then this node group would not - // be considered a QDQ node unit anyway. - std::vector children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name()); - if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) { + // Check if the output path leads to QuantizeLinear with optionally ReLU or Clip op in between. + if (!IsNoBranchPathToQuantizeLinear(node, graph)) { continue; } @@ -131,14 +161,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale")); weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); weight_scale_proto.mutable_float_data()->Add(scale); - weight_scale_arg = &graph_utils::AddInitializerWithExternalData(graph, weight_scale_proto); + weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto); // Weight zero point initializer. ONNX_NAMESPACE::TensorProto weight_zp_proto; weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp")); weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); weight_zp_proto.mutable_int32_data()->Add(static_cast(zp)); - NodeArg& weight_zp_arg = graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto); + NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto); // Q from float32 to int8. ONNX_NAMESPACE::TypeProto weight_q_type_proto; diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index efd7022ab764b..07902fde04930 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -97,7 +97,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff mutable_next_node->AddAttribute("min", 0.f); } else { // Add the initialized tensor to the graph - auto* replacement_min_nodearg = &graph_utils::AddInitializerWithExternalData(graph, replacement_min); + auto* replacement_min_nodearg = &graph_utils::AddInitializer(graph, replacement_min); // Replace the input def at the appropriate index of the Clip node auto& mutable_input_defs = mutable_next_node->MutableInputDefs(); diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 36213609f6b61..324905f953eec 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -438,7 +438,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); + auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto); // Safely remove concat parent nodes which have only one output for (int i = 0; i < concat_input_count; ++i) { @@ -492,7 +492,7 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg}, {contiguous_reshapes.back().get().MutableOutputDefs()[0]}); diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index 74121508132dc..5c09e5225ab9c 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -46,7 +46,7 @@ NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[T proto.add_dims(shape[i]); } utils::SetRawDataInTensorProto(proto, begin, element_count * sizeof(TDataType)); - return &graph_utils::AddInitializerWithExternalData(graph, proto); + return &graph_utils::AddInitializer(graph, proto); } template diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index a320de2ee7a13..cc7682b2b418d 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -383,21 +383,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker TensorProto new_tensor_proto = *tensor_proto; *(new_tensor_proto.mutable_name()) = new_def_name; - // Query any OrtValue existing for the original initializer - // We are checking outer scope because GetInitializer is called with true, therefore, we potentially - // have references to parent graphs. - // We are doing this so the same OrtValue is re-used in subgraphs and no copies made for big items. - constexpr const bool check_outer_scope_true = true; - OrtValue ort_value; - // The initializer can be in memory with OrtValue or it can be a flatbuffer mapped. - if (utils::HasExternalDataInMemory(new_tensor_proto) && - graph_.GetOrtValueInitializer(name, ort_value, check_outer_scope_true)) { - // Re-use the same ort_value and proto that points to the same buffer - ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializerWithExternalData(graph_, new_tensor_proto, - std::move(ort_value))); - } else { - ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto)); - } + ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto)); replacements.insert(std::make_pair(provider_def, &new_def)); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 48ea54434b805..3a95d2a53e8f5 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -586,10 +586,10 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector& shape) { @@ -622,7 +622,7 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vector()->Reshape(new_shape); - } - - auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_tensor_proto, ort_value); + auto& new_node_arg = graph_utils::AddInitializer(graph, new_tensor_proto); graph_utils::ReplaceNodeWithInitializer(graph, node, new_node_arg); // Remove the Unsqueeze node and replace it with the initializer. diff --git a/onnxruntime/core/platform/apple/device_discovery.cc b/onnxruntime/core/platform/apple/device_discovery.cc new file mode 100644 index 0000000000000..767b834e38756 --- /dev/null +++ b/onnxruntime/core/platform/apple/device_discovery.cc @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/device_discovery.h" + +#include +#include + +#include "core/common/logging/logging.h" + +namespace onnxruntime { + +namespace { + +constexpr auto kApplePciVendorId = 0x106B; +constexpr auto kAppleVendorName = "Apple"; + +std::vector GetGpuDevices() { + std::vector result{}; + + // For now, we assume the existence of one GPU if it is a Mac with Apple Silicon. + // TODO support iOS + // TODO support Intel Macs which may have more than one GPU +#if TARGET_OS_OSX && TARGET_CPU_ARM64 + { + OrtHardwareDevice gpu_device{}; + gpu_device.type = OrtHardwareDeviceType_GPU; + gpu_device.vendor_id = kApplePciVendorId; + gpu_device.vendor = kAppleVendorName; + + result.emplace_back(std::move(gpu_device)); + } +#endif // TARGET_OS_OSX && TARGET_CPU_ARM64 + + return result; +} + +bool HasAppleNeuralEngine() { + // Copied from onnxruntime/core/providers/coreml/builders/helper.cc:HasNeuralEngine(). + bool has_apple_neural_engine = false; + + struct utsname system_info; + uname(&system_info); + LOGS_DEFAULT(VERBOSE) << "Current Apple hardware info: " << system_info.machine; + +#if TARGET_OS_IPHONE + // utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2". + // Since Neural Engine is only available for use on A12 and later, major device version in the + // identifier is checked for these models: + // A12: iPhone XS (11,2), iPad Mini - 5th Gen (11,1) + // A12X: iPad Pro - 3rd Gen (8,1) + // For more information, see https://www.theiphonewiki.com/wiki/Models + size_t str_len = strnlen(system_info.machine, onnxruntime::kMaxStrLen); + if (str_len > 4 && strncmp("iPad", system_info.machine, 4) == 0) { + const int major_version = atoi(system_info.machine + 4); + has_apple_neural_engine = major_version >= 8; // There are no device between iPad 8 and 11. + } else if (str_len > 6 && strncmp("iPhone", system_info.machine, 6) == 0) { + const int major_version = atoi(system_info.machine + 6); + has_apple_neural_engine = major_version >= 11; + } +#elif TARGET_OS_OSX && TARGET_CPU_ARM64 + // Only Mac with arm64 CPU (Apple Silicon) has ANE. + has_apple_neural_engine = true; +#endif // #if TARGET_OS_IPHONE + + return has_apple_neural_engine; +} + +std::vector GetNpuDevices() { + std::vector result{}; + + if (HasAppleNeuralEngine()) { + OrtHardwareDevice npu_device{}; + npu_device.type = OrtHardwareDeviceType_NPU; + npu_device.vendor_id = kApplePciVendorId; + npu_device.vendor = kAppleVendorName; + + result.emplace_back(std::move(npu_device)); + } + + return result; +} + +} // namespace + +std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { + std::unordered_set devices; + + // get CPU devices + devices.insert(GetCpuDeviceFromCPUIDInfo()); + + // get GPU devices + { + auto gpu_devices = GetGpuDevices(); + devices.insert(gpu_devices.begin(), gpu_devices.end()); + } + + // get NPU devices + { + auto npu_devices = GetNpuDevices(); + devices.insert(npu_devices.begin(), npu_devices.end()); + } + + return devices; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/device_discovery.h b/onnxruntime/core/platform/device_discovery.h index 70be10bf09e4e..b49e63b90236a 100644 --- a/onnxruntime/core/platform/device_discovery.h +++ b/onnxruntime/core/platform/device_discovery.h @@ -3,25 +3,24 @@ #pragma once -#include #include #include "core/session/abi_devices.h" + namespace onnxruntime { class DeviceDiscovery { public: - static std::unordered_set& GetDevices() { - // assumption: devices don't change. we assume the machine must be shutdown to change cpu/gpu/npu devices. - // technically someone could disable/enable a device in a running OS. we choose not to add complexity to support - // that scenario. - static std::unordered_set devices(DiscoverDevicesForPlatform()); - return devices; - } + static const std::unordered_set& GetDevices(); private: DeviceDiscovery() = default; + // platform specific code implements this method static std::unordered_set DiscoverDevicesForPlatform(); + + // Gets a CPU device by querying `CPUIDInfo`. + static OrtHardwareDevice GetCpuDeviceFromCPUIDInfo(); }; + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/device_discovery_common.cc b/onnxruntime/core/platform/device_discovery_common.cc new file mode 100644 index 0000000000000..dcba31aed6fec --- /dev/null +++ b/onnxruntime/core/platform/device_discovery_common.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file contains platform-agnostic device discovery implementation. + +#include "core/platform/device_discovery.h" + +#include + +#include "core/common/cpuid_info.h" +#include "core/common/logging/logging.h" + +namespace onnxruntime { + +const std::unordered_set& DeviceDiscovery::GetDevices() { + // assumption: devices don't change. we assume the machine must be shutdown to change cpu/gpu/npu devices. + // technically someone could disable/enable a device in a running OS. we choose not to add complexity to support + // that scenario. + static std::unordered_set devices = []() { + auto discovered_devices = DiscoverDevicesForPlatform(); + + // log discovered devices + for (const auto& ortdevice : discovered_devices) { + std::ostringstream oss; + oss << "Discovered OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id + << ", device_id:0x" << ortdevice.device_id + << ", vendor:" << ortdevice.vendor + << ", type:" << std::dec << static_cast(ortdevice.type) + << ", metadata: ["; + for (auto& [key, value] : ortdevice.metadata.Entries()) { + oss << key << "=" << value << ", "; + } + oss << "]}"; + LOGS_DEFAULT(INFO) << oss.str(); + } + + return discovered_devices; + }(); + + return devices; +} + +OrtHardwareDevice DeviceDiscovery::GetCpuDeviceFromCPUIDInfo() { + const auto& cpuid_info = CPUIDInfo::GetCPUIDInfo(); + + OrtHardwareDevice cpu_device{}; + cpu_device.vendor = cpuid_info.GetCPUVendor(); + cpu_device.vendor_id = cpuid_info.GetCPUVendorId(); + cpu_device.device_id = 0; + cpu_device.type = OrtHardwareDeviceType_CPU; + + return cpu_device; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/posix/device_discovery.cc b/onnxruntime/core/platform/device_discovery_default.cc similarity index 57% rename from onnxruntime/core/platform/posix/device_discovery.cc rename to onnxruntime/core/platform/device_discovery_default.cc index 82564539ab5d4..73ddf516034ab 100644 --- a/onnxruntime/core/platform/posix/device_discovery.cc +++ b/onnxruntime/core/platform/device_discovery_default.cc @@ -4,14 +4,16 @@ #include "core/platform/device_discovery.h" namespace onnxruntime { + std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { - std::unordered_set devices; - // get CPU devices + // This is a default implementation. + // We assume that there is a CPU device and do not attempt to discover anything else. - // get GPU devices + std::unordered_set devices{}; - // get NPU devices + devices.emplace(GetCpuDeviceFromCPUIDInfo()); return devices; } + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc new file mode 100644 index 0000000000000..6a02a1b46028f --- /dev/null +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/device_discovery.h" + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/logging/logging.h" +#include "core/common/parse_string.h" +#include "core/common/string_utils.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +namespace { + +Status ErrorCodeToStatus(const std::error_code& ec) { + if (!ec) { + return Status::OK(); + } + + return Status{common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL, + MakeString("Error: std::error_code with category name: ", ec.category().name(), + ", value: ", ec.value(), ", message: ", ec.message())}; +} + +struct GpuSysfsPathInfo { + size_t card_idx; + fs::path path; +}; + +Status DetectGpuSysfsPaths(std::vector& gpu_sysfs_paths_out) { + std::error_code error_code{}; + const fs::path sysfs_class_drm_path = "/sys/class/drm"; + const bool sysfs_class_drm_path_exists = fs::exists(sysfs_class_drm_path, error_code); + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + if (!sysfs_class_drm_path_exists) { + gpu_sysfs_paths_out = std::vector{}; + return Status::OK(); + } + + const auto detect_card_path = [](const fs::path& sysfs_path, size_t& card_idx) -> bool { + const auto filename = sysfs_path.filename(); + const auto filename_str = std::string_view{filename.native()}; + + // Look for a filename matching "cardN". N is a number. + constexpr std::string_view prefix = "card"; + if (filename_str.find(prefix) != 0) { + return false; + } + + size_t parsed_card_idx{}; + if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_card_idx)) { + return false; + } + + card_idx = parsed_card_idx; + return true; + }; + + std::vector gpu_sysfs_paths{}; + + auto dir_iterator = fs::directory_iterator{sysfs_class_drm_path, error_code}; + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + for (const auto& dir_item : dir_iterator) { + const auto& dir_item_path = dir_item.path(); + + if (size_t card_idx{}; detect_card_path(dir_item_path, card_idx)) { + GpuSysfsPathInfo path_info{}; + path_info.card_idx = card_idx; + path_info.path = dir_item_path; + gpu_sysfs_paths.emplace_back(std::move(path_info)); + } + } + + gpu_sysfs_paths_out = std::move(gpu_sysfs_paths); + return Status::OK(); +} + +Status ReadFileContents(const fs::path& file_path, std::string& contents) { + std::ifstream file{file_path}; + ORT_RETURN_IF_NOT(file, "Failed to open file: ", file_path); + std::istreambuf_iterator file_begin{file}, file_end{}; + contents.assign(file_begin, file_end); + return Status::OK(); +} + +template +Status ReadValueFromFile(const fs::path& file_path, ValueType& value) { + std::string file_text{}; + ORT_RETURN_IF_ERROR(ReadFileContents(file_path, file_text)); + file_text = utils::TrimString(file_text); + return ParseStringWithClassicLocale(file_text, value); +} + +Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevice& gpu_device_out) { + OrtHardwareDevice gpu_device{}; + const auto& sysfs_path = path_info.path; + + // vendor id + { + const auto vendor_id_path = sysfs_path / "device" / "vendor"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, gpu_device.vendor_id)); + } + + // TODO vendor name + + // device id + { + const auto device_id_path = sysfs_path / "device" / "device"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, gpu_device.device_id)); + } + + // metadata + gpu_device.metadata.Add("card_idx", MakeString(path_info.card_idx)); + // TODO is card discrete? + + gpu_device.type = OrtHardwareDeviceType_GPU; + + gpu_device_out = std::move(gpu_device); + return Status::OK(); +} + +Status GetGpuDevices(std::vector& gpu_devices_out) { + std::vector gpu_sysfs_path_infos{}; + ORT_RETURN_IF_ERROR(DetectGpuSysfsPaths(gpu_sysfs_path_infos)); + + std::vector gpu_devices{}; + gpu_devices.reserve(gpu_sysfs_path_infos.size()); + + for (const auto& gpu_sysfs_path_info : gpu_sysfs_path_infos) { + OrtHardwareDevice gpu_device{}; + ORT_RETURN_IF_ERROR(GetGpuDeviceFromSysfs(gpu_sysfs_path_info, gpu_device)); + gpu_devices.emplace_back(std::move(gpu_device)); + } + + gpu_devices_out = std::move(gpu_devices); + return Status::OK(); +} + +} // namespace + +std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { + std::unordered_set devices; + + // get CPU devices + devices.emplace(GetCpuDeviceFromCPUIDInfo()); + + // get GPU devices + { + std::vector gpu_devices{}; + Status gpu_device_discovery_status = GetGpuDevices(gpu_devices); + if (gpu_device_discovery_status.IsOK()) { + devices.insert(std::make_move_iterator(gpu_devices.begin()), + std::make_move_iterator(gpu_devices.end())); + } else { + LOGS_DEFAULT(WARNING) << "GPU device discovery failed: " << gpu_device_discovery_status.ErrorMessage(); + } + } + + // get NPU devices + // TODO figure out how to discover these + + return devices; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index ff904ddb3e7e0..cf761f587ad0b 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -635,19 +635,6 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - std::ostringstream oss; - oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id - << ", device_id:0x" << ortdevice.device_id - << ", vendor:" << ortdevice.vendor - << ", type:" << std::dec << static_cast(ortdevice.type) - << ", metadata: ["; - for (auto& [key, value] : ortdevice.metadata.Entries()) { - oss << key << "=" << value << ", "; - } - - oss << "]}" << std::endl; - LOGS_DEFAULT(INFO) << oss.str(); - return ortdevice; }; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 36c6b54a1fce0..aa237fc6441b2 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include "core/common/logging/logging.h" #include "core/common/narrow.h" +#include "core/common/safeint.h" #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" @@ -439,30 +440,28 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); - static const DWORD page_size = sysinfo.dwPageSize; static const DWORD allocation_granularity = sysinfo.dwAllocationGranularity; - const FileOffsetType offset_to_page = offset % static_cast(page_size); - const size_t mapped_length = length + static_cast(offset_to_page); - const FileOffsetType mapped_offset = offset - offset_to_page; - if (mapped_offset % allocation_granularity != 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "mapped offset must be a multiple of the allocation granularity", - " , mapped_offset = ", mapped_offset, - " , allocation_granularity = ", allocation_granularity, - " , errcode = ", error_code, - " - ", std::system_category().message(error_code)); - } + const FileOffsetType offset_to_granularity = offset % static_cast(allocation_granularity); + const SIZE_T mapped_length = SafeInt(offset_to_granularity) + length; + const FileOffsetType mapped_offset = offset - offset_to_granularity; + assert((mapped_offset % allocation_granularity) == 0); void* const mapped_base = MapViewOfFile(file_mapping_handle.get(), FILE_MAP_READ, static_cast((mapped_offset >> 32) & 0xFFFFFFFF), static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); - GSL_SUPPRESS(r.11) + + if (mapped_base == nullptr) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "MapViewOfFile ", ToUTF8String(Basename(file_path)), + " fail, errcode = ", error_code, + " - ", std::system_category().message(error_code)); + } mapped_memory = - MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, + MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_granularity, [mapped_base](void*) { UnmapFile(mapped_base); }}; diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 8353c654681fc..88e58aec70550 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -72,7 +72,7 @@ void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " @@ -81,7 +81,9 @@ Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + if (sync_status_flag) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 064b526e604bc..6b61a66671de4 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -38,7 +38,7 @@ struct CUDAGraphManager { void SetStream(cudaStream_t stream); void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); - Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag = true); void Reset(); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index e8d133779f33c..3b361f155831b 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -586,7 +586,7 @@ struct CudaSyncNotificationImpl : OrtSyncNotificationImpl { Release = ReleaseImpl; } - cudaStream_t& stream_; + cudaStream_t stream_; cudaEvent_t event_; const OrtApi& ort_api; @@ -632,9 +632,9 @@ struct CudaSyncStreamImpl : OrtSyncStreamImpl { *notification_impl = nullptr; std::unique_ptr notification; - cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + cudaStream_t cuda_stream = static_cast(impl.stream_.GetHandle()); - RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(cuda_stream, impl.ort_api, notification)); *notification_impl = notification.release(); return nullptr; @@ -734,6 +734,10 @@ struct CudaEpFactory : OrtEpFactory { } */ + // guard against bad device discovery. max devices we expect to add is num_cuda_devices. if we're attempting + // to add more than that we have duplicates in the `devices` array. + max_ep_devices = std::min(max_ep_devices, static_cast(num_cuda_devices)); + int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9611cb82d5a62..6d8d5453b9fc0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -927,7 +927,7 @@ namespace Dml bool IsGpuTensor(const onnxruntime::Tensor& tensor) { - return strcmp(tensor.Location().name, onnxruntime::CPU) && + return strcmp(tensor.Location().name.c_str(), onnxruntime::CPU) && !(tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index c601ee3c1d5e6..fe52f27b35bb8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -98,7 +98,7 @@ namespace Windows::AI::MachineLearning::Adapter bool IsAllocationInterface(const ::OrtMemoryInfo& info) { - return strcmp(info.name, onnxruntime::CPU) && !(info.mem_type == ::OrtMemType::OrtMemTypeCPUOutput || info.mem_type == ::OrtMemType::OrtMemTypeCPUInput); + return strcmp(info.name.c_str(), onnxruntime::CPU) && !(info.mem_type == ::OrtMemType::OrtMemTypeCPUOutput || info.mem_type == ::OrtMemType::OrtMemTypeCPUInput); } // Translate the data object stored in a tensor to the type which will be returned through @@ -1774,7 +1774,9 @@ namespace Windows::AI::MachineLearning::Adapter } // tells caller whether this tensor is in CPU memory - return !strcmp(m_impl->Location().name, onnxruntime::CPU) || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput; + return !strcmp(m_impl->Location().name.c_str(), onnxruntime::CPU) + || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput + || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput; } bool STDMETHODCALLTYPE TensorWrapper::IsDataInterface() const noexcept diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h index 5918716b3e77f..a4eb8efd2afea 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h @@ -3,7 +3,7 @@ #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/framework/data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..911a1a7fd18b9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -2,12 +2,11 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" -#include "migraphx_call.h" -#include "migraphx_allocator.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_allocator.h" #include "core/common/status.h" #include "core/framework/float16.h" -#include "core/common/status.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/gpu_data_transfer.h" namespace onnxruntime { @@ -18,7 +17,7 @@ void MIGraphXAllocator::CheckDevice() const { int current_device; auto hip_err = hipGetDevice(¤t_device); if (hip_err == hipSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } #endif } @@ -55,7 +54,9 @@ void MIGraphXExternalAllocator::Free(void* p) { auto it = reserved_.find(p); if (it != reserved_.end()) { reserved_.erase(it); - if (empty_cache_) empty_cache_(); + if (empty_cache_ != nullptr) { + empty_cache_(); + } } } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index f6b7788e0604c..10e06ab2f35ad 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -3,9 +3,9 @@ #pragma once +#include #include #include "core/framework/allocator.h" -#include namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 9807cd646e51c..79dfb5512d3b5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,13 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #ifdef _WIN32 #include #else #include #endif -#include #include "core/common/common.h" #include "core/common/status.h" #include "core/providers/shared_library/provider_api.h" @@ -15,10 +17,9 @@ namespace onnxruntime { -using namespace common; - +namespace { template -const char* RocmErrString(ERRTYPE x) { +std::string_view RocmErrString(ERRTYPE x) { ORT_NOT_IMPLEMENTED(); } @@ -27,14 +28,16 @@ const char* RocmErrString(ERRTYPE x) { return #x template <> -const char* RocmErrString(hipError_t x) { +std::string_view RocmErrString(hipError_t x) { (void)hipDeviceSynchronize(); - return hipGetErrorString(x); + return std::string_view{hipGetErrorString(x)}; } +} // namespace + template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, const int line) { if (retCode != successCode) { try { #ifdef _WIN32 @@ -47,17 +50,16 @@ std::conditional_t RocmCall( int currentHipDevice; (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error - static char str[1024]; - snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", - libName, (int)retCode, RocmErrString(retCode), currentHipDevice, - hostname.c_str(), - file, line, exprString, msg); + std::stringstream ss; + ss << libName << " failure " << static_cast(retCode) << ": " << RocmErrString(retCode) + << "; GPU=" << currentHipDevice << "; hostname=" << hostname << "; file=" << file << "; line=" << line + << "; expr=" << exprString << "; " << msg; if constexpr (THRW) { // throw an exception with the error info - ORT_THROW(str); + ORT_THROW(ss.str()); } else { - LOGS_DEFAULT(ERROR) << str; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + LOGS_DEFAULT(ERROR) << ss.str(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ss.str()); } } catch (const std::exception& e) { // catch, log, and rethrow since HIP code sometimes hangs in destruction, so we'd never get to see the error if constexpr (THRW) { @@ -73,7 +75,7 @@ std::conditional_t RocmCall( } } -template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); +template Status RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); +template void RocmCall(hipError_t retCode, std::string_view exprString, std::string_view libName, hipError_t successCode, std::string_view msg, std::string_view file, int line); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 6d514e01aea96..9c3b5c79a947b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "migraphx_inc.h" +#include "core/providers/migraphx/migraphx_inc.h" #include "core/common/common.h" namespace onnxruntime { @@ -13,7 +13,7 @@ namespace onnxruntime { template std::conditional_t RocmCall( - ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line); + ERRTYPE retCode, std::string_view exprString, std::string_view libName, ERRTYPE successCode, std::string_view msg, std::string_view file, int line); #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 41b55e3baf508..a59347841be95 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,26 +1,34 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include + +#include + #include +#include +#include +#include #include -#include +#include +#include #include -#include +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" #include "core/common/logging/severity.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_execution_provider_utils.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" -#include -#include "migraphx_call.h" - -#include "migraphx_stream_handle.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_utils.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_call.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -105,240 +113,144 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c return s_kernel_registry; } -MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, - info_(info) { - InitProviderOrtApi(); - get_flags_from_session_info(info); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - get_flags_from_env(); +static std::string_view GetArenaExtendStrategyName(ArenaExtendStrategy strategy) { + switch (strategy) { + case ArenaExtendStrategy::kNextPowerOfTwo: + return "kNextPowerOfTwo"; + case ArenaExtendStrategy::kSameAsRequested: + return "kSameAsRequested"; + default: + return "Unknown"; + } } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} +#define GET_ENV(variable, value, ...) \ + const auto value##env{GetEnvironmentVar(variable)}; \ + if (!value##env.empty()) { \ + __VA_ARGS__; \ + LOGS_DEFAULT(INFO) << "\n " << variable << ": " << value##env; \ + } -void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { - // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(info_.device_id)); - t_ = migraphx::target(info.target_device.c_str()); +#define GET_ENV_BOOL(variable, value) \ + GET_ENV(variable, value, value = std::stoi(value##env) != 0) - // Quantization - fp16_enable_ = info.fp16_enable; +#define GET_ENV_STRING(variable, value) \ + GET_ENV(variable, value, value = value##env) +MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, + device_id_{info.device_id}, + fp16_enable_{info.fp16_enable}, +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + bf16_enable_{info.bf16_enable}, +#endif #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = info.fp8_enable; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable_ = false; + fp8_enable_{info.fp8_enable}, #endif - int8_enable_ = info.int8_enable; - - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; - } - - if (int8_enable_ xor fp8_enable_) { - int8_calibration_cache_name_ = info.int8_calibration_table_name; - int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); - } + int8_enable_{info.int8_enable}, + model_cache_path_{info.model_cache_dir}, + t_{info.target_device.c_str()}, + exhaustive_tune_{info.exhaustive_tune}, + metadef_id_generator_{ModelMetadefIdGenerator::Create()}, + external_alloc_{info.external_alloc}, + external_free_{info.external_free}, + external_empty_cache_{info.external_empty_cache} { + InitProviderOrtApi(); - // Load INT8 calibration table - std::unordered_map dynamic_range_map; - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); - } - } + // Set GPU device to be used and read device properties for feature usage. - // Save/load migraphx compiled models - save_compiled_model_ = info.save_compiled_model; - save_compiled_path_ = info.save_model_file; - load_compiled_model_ = info.load_compiled_model; - load_compiled_path_ = info.load_model_file; + HIP_CALL_THROW(hipSetDevice(device_id_)); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, device_id_)); - exhaustive_tune_ = info.exhaustive_tune; + // Overwrite initialized values with values from environment variables. - LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; - print_migraphx_ep_flags(); -} + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; + GET_ENV_BOOL(migraphx_env_vars::kFP16Enable, fp16_enable_); +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR > 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH >= 2))) + GET_ENV_BOOL(migraphx_env_vars::kBF16Enable, bf16_enable_); +#endif +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) + GET_ENV_BOOL(migraphx_env_vars::kFP8Enable, fp8_enable_); +#endif + GET_ENV_BOOL(migraphx_env_vars::kINT8Enable, int8_enable_); + GET_ENV(migraphx_env_vars::kINT8CalibrationTableName, int8_calibration_cache_name_); + GET_ENV(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable, int8_use_native_migraphx_calibration_table_); + GET_ENV_STRING(migraphx_env_vars::kCachePath, calibration_cache_path_); + GET_ENV_STRING(migraphx_env_vars::kModelCachePath, model_cache_path_); + GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); + GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); + + // Verify configuration correctness and adjust accordingly. + +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && (HIP_VERSION_MINOR < 4 || (HIP_VERSION_MINOR == 4 && HIP_VERSION_PATCH < 2))) + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; +#endif -void MIGraphXExecutionProvider::get_flags_from_env() { - LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); - if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + if (bf16_enable_ && fp16_enable_) { + bf16_enable_ = false; + fp16_enable_ = false; + LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); - if (!fp8_enable_env.empty()) { -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; -#else - LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; - fp8_enable = false; +#if HIP_VERSION_MAJOR < 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR < 4) + LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; + fp8_enable_ = false; #endif - } - // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); - if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; + if (int8_enable_ && fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } - if (int8_enable_ and fp8_enable_) { - LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + if (int8_enable_ ^ fp8_enable_) { + int8_calibration_table_name_ = + int8_calibration_cache_name_env.empty() ? info.int8_calibration_table_name : int8_calibration_cache_name_env; + int8_use_native_calibration_table_ = + int8_use_native_migraphx_calibration_table_env.empty() ? info.int8_use_native_calibration_table : std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; } if (int8_enable_ || fp8_enable_) { - const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); - if (!int8_calibration_cache_name_env.empty()) { - int8_calibration_cache_name_ = int8_calibration_cache_name_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; - } - - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); - if (!cache_path.empty()) { - calibration_cache_path_ = cache_path; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; - } - - const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); - if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " - << int8_use_native_migraphx_calibration_table_; - } - } - - if (int8_enable_ or fp8_enable_) { - int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); } // Load INT8 calibration table - std::unordered_map dynamic_range_map; if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { - const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); - if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); + std::unordered_map dynamic_range_map; + auto calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_table_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path.string()); } } - // Save/load migraphx compiled models - const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); - if (!save_comp_model_env.empty()) { - save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; - } - - const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { - save_compiled_path_ = save_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; - } - - const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); - if (!load_comp_model_env.empty()) { - load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; - } - - const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); - if (load_compiled_model_ && !load_model_path_env.empty()) { - load_compiled_path_ = load_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; - } - - // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); - if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; - } + // Print configured options for the session. - // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); - if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; - } -} - -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id - << "\n migraphx_fp16_enable: " << fp16_enable_ - << "\n migraphx_fp8_enable: " << fp8_enable_ - << "\n migraphx_int8_enable: " << int8_enable_ + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider Session Options:" + << "\n " << migraphx_provider_option::kDeviceId << ": " << device_id_ + << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ + << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ + << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ + << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ + << "\n " << migraphx_provider_option::kMemLimit << ": " << mem_limit_ + << "\n " << migraphx_provider_option::kArenaExtendStrategy << ": " << GetArenaExtendStrategyName(arena_extend_strategy_) << "\n dump_model_ops: " << dump_model_ops_ - << "\n exhaustive_tune: " << exhaustive_tune_ - << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_table_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << "\n migraphx_save_compiled_model: " << save_compiled_model_ - << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ - << "\n migraphx_load_compiled_model: " << load_compiled_model_ - << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; -} - -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, - ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo - external_allocator_info, - const OrtArenaCfg* default_memory_arena_cfg) { - if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( - [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, - external_allocator_info.alloc, - external_allocator_info.free, - external_allocator_info.empty_cache); - }, - device_id, - false); - - return CreateAllocator(default_memory_info); - } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); - } + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_calibration_table_ + << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, - info_.device_id); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, onnxruntime::CUDA); + }, + device_id_); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + return std::make_unique(device_id, CUDA_PINNED); }, - info_.device_id); + device_id_); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -354,6 +266,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -384,6 +297,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + mgx_type = migraphx_shape_bf16_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; @@ -457,7 +373,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); - if (optype == "ArgMax" or optype == "ArgMin") { + if (optype == "ArgMax" || optype == "ArgMin") { const auto& attributes = node->GetAttributes(); // we do not support select_last_index = 1 for now auto sli_attr = attributes.find("select_last_index"); @@ -475,7 +391,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -503,7 +419,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // storage order 1 (column major format) is not supported auto storage_order_attr = attributes.find("storage_order"); - if (storage_order_attr != attributes.end() and (*storage_order_attr).second.i() != 0) { + if (storage_order_attr != attributes.end() && (*storage_order_attr).second.i() != 0) { return true; } @@ -513,7 +429,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } auto data_type = input_type->tensor_type().elem_type(); - if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 or + if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) { return true; } @@ -524,7 +440,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } - if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and + if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) && (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } @@ -580,7 +496,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } return true; } - } else if (optype == "Resize" or optype == "Upsample") { + } else if (optype == "Resize" || optype == "Upsample") { const auto& attributes = node->GetAttributes(); auto ct_attr = attributes.find("coordinate_transformation_mode"); if (ct_attr != attributes.end()) { @@ -618,7 +534,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } const auto& attributes = node->GetAttributes(); - if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { + if (attributes.count("starts") > 0 && attributes.count("ends") > 0) { auto starts = toVector((*attributes.find("starts")).second.ints()); auto ends = toVector((*attributes.find("ends")).second.ints()); for (std::size_t i = 0; i < starts.size(); ++i) { @@ -656,7 +572,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } - } else if (optype == "Unsqueeze" or optype == "Squeeze") { + } else if (optype == "Unsqueeze" || optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { @@ -685,9 +601,9 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (args.size() == 2) { std::vector node_inputs; if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -857,12 +773,14 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st erased.insert(output); } // Only when output is neither in input list nor erased list, add the output to output list - else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), - graph_output_names.end(), output->Name()) != graph_output_names.end()) { - graph_outputs_to_add[output] = output_order; + else { + if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; } - fused_outputs[output] = output_order++; } } } @@ -944,6 +862,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Atan", "Atanh", "ATen", + "Attention", "AveragePool", "BatchNormalization", "BiasGelu", @@ -986,6 +905,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Greater", "GreaterOrEqual", "GroupNormalization", + "GroupNorm", "GroupQueryAttention", "HardSigmoid", "HardSwish", @@ -1017,6 +937,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "MultiHeadAttention", "Neg", "NegativeLogLikelihoodLoss", + "NhwcConv", "NonMaxSuppression", "NonZero", "Not", @@ -1053,6 +974,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "ReverseSequence", "RNN", "Roialign", + "RotaryEmbedding", "Round", "Scatter", "ScatterElements", @@ -1243,29 +1165,25 @@ bool get_input_output_names(const GraphViewer& graph, // Attempt to load a model and catch any exceptions on load fail. // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. -bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { - try { - if (load_enable) { - LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; - prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(WARNING) << "load model : Success"; - return true; - } else { - return false; - } - } catch (...) { - return false; +bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { + if (!path.empty() && exists(path)) { + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); + prog = migraphx::load(path.string().c_str()); + LOGS_DEFAULT(VERBOSE) << "load model : Success"; + return true; } return false; +} catch (...) { + return false; } -void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { - if (save_enable) { - LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; +void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { + if (!path.empty()) { + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path.string() << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); - migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + save(prog, path.string().c_str(), fo); + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } @@ -1275,12 +1193,13 @@ void calibrate_and_quantize(migraphx::program& prog, const migraphx::target& t, const migraphx::program_parameters quant_params, bool fp16_enable, + bool bf16_enable, bool int8_enable, bool fp8_enable, bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; auto param_shapes = prog.get_parameter_shapes(); @@ -1317,6 +1236,14 @@ void calibrate_and_quantize(migraphx::program& prog, migraphx::quantize_fp16(prog); LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } + +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + if (bf16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to bf16"; + migraphx::quantize_bf16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing bf16: Complete"; + } +#endif } void compile_program(migraphx::program& prog, @@ -1330,6 +1257,27 @@ void compile_program(migraphx::program& prog, LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } +std::string to_hex(const uint64_t v) { + std::array s{}; + auto [ptr, _] = std::to_chars(s.data(), s.data() + s.size(), v, 16); + return std::string{s.data(), ptr}; +} + +template +std::string make_hash(T v) { + std::array temp{}; + MurmurHash3::x86_128(v.data(), gsl::narrow_cast(v.size()), temp[0], temp.data()); + return to_hex(temp[0] | static_cast(temp[1]) << 32); +} + +template <> +std::string make_hash(const char* v) { + return make_hash(std::string_view{v}); +} + +constexpr std::uint64_t MIGraphX_Version = + ((MIGRAPHX_VERSION_MAJOR << 16) | (MIGRAPHX_VERSION_MINOR << 8) | MIGRAPHX_VERSION_PATCH); + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1337,6 +1285,33 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; + + std::filesystem::path model_cache_file; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view(device_prop_.gcnArchName)) + "-"; + + // Get model input names (only first layer) + const Graph* cur_graph = &graph_body_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; + const auto& input_tensor = main_graph.GetInputs(); + for (auto i : input_tensor) { + session_input_names.insert(i->Name()); + } + + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + std::vector input_shapes; + for (std::size_t i = 0; i < session_input_names.size(); ++i) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + input_shapes.push_back(tensor_shape->dim(j).dim_value()); + } + } + model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + // map parameter input name to index std::unordered_map input_name_index; const auto& input_defs = fused_node.InputDefs(); @@ -1367,15 +1342,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program prog; if (!no_input_shape) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; - calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); - save_compiled_model(prog, save_compiled_model_, save_compiled_path_); + save_compiled_model(prog, model_cache_file); } auto prog_output_shapes = prog.get_output_shapes(); @@ -1396,10 +1376,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, - save_compiled_model_, save_compiled_path_, - load_compiled_model_, load_compiled_path_, dump_model_ops_}; + model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1409,7 +1388,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1421,6 +1400,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool bf16_enable = mgx_state->bf16_enable; bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1429,8 +1409,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // from input data bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; + std::vector input_shapes; + if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1442,7 +1424,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1459,8 +1441,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto mgx_s = param_shapes[name]; auto mgx_lens = mgx_s.lengths(); auto mgx_strides = mgx_s.strides(); - if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and - mgx_strides.size() == 1 and mgx_strides[0] == 0) { + if (mgx_lens.size() == 1 && mgx_lens[0] == 1 && + mgx_strides.size() == 1 && mgx_strides[0] == 0) { mgx_lens.clear(); } @@ -1468,6 +1450,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; } + input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); } } } @@ -1476,20 +1459,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // input shapes are different, needs to re-parse onnx and // re-compile the program if (!input_shape_match) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; + std::filesystem::path model_cache_file; + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling"; #ifndef ENABLE_TRAINING_CORE -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) - cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + cmp_options.set_external_data_path(model_path_.parent_path().string()); #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); migraphx::program_parameters quant_params; - if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { - auto param_shapes = prog.get_parameter_shapes(); + if ((int8_enable ^ fp8_enable) && int8_calibration_cache_available) { + auto local_param_shapes = prog.get_parameter_shapes(); // Add input parameter data and the values they're set to - for (auto&& name : param_shapes.names()) { + for (auto&& name : local_param_shapes.names()) { if (map_input_name_index.count(name) > 0) { auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); @@ -1498,19 +1486,19 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx_shape_datatype_t mgx_type; getMIGraphXType(tensor_type, mgx_type); - auto mgx_s = param_shapes[name]; + auto mgx_s = local_param_shapes[name]; if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); + quant_params.add(name, migraphx::argument(local_param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); - save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); + save_compiled_model(prog, model_cache_file); } mgx_state->prog = prog; @@ -1524,7 +1512,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1538,21 +1526,21 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { - auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = name.find(out_name_prefix); - if (pos == std::string::npos) { + auto compute_output_index = [](const std::string_view sv) -> int { + constexpr std::string_view out_name_prefix = "#output_"; + const auto pos = sv.find(out_name_prefix); + if (pos == std::string_view::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); - return std::stoi(index_str); + const auto index_str = sv.substr(pos + out_name_prefix.length()); + return ToInteger(Trim(index_str, std::isdigit)); }; int output_index = compute_output_index(name); @@ -1599,7 +1587,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index aecccdd54d697..99f790b9f9f7a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,33 +3,37 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include +#include "core/framework/provider_options_utils.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_call.h" -#include -#include -#include +using namespace std::literals::string_view_literals; namespace onnxruntime { namespace migraphx_env_vars { -static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; -static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; -static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; -static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; -static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; -static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; -static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; -static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; -static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - -}; // namespace migraphx_env_vars +constexpr auto kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"sv; +constexpr auto kBF16Enable = "ORT_MIGRAPHX_BF16_ENABLE"sv; +constexpr auto kFP8Enable = "ORT_MIGRAPHX_FP8_ENABLE"sv; +constexpr auto kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE"sv; +constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"sv; +constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"sv; +constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; +constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; +constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; +constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; +} // namespace migraphx_env_vars // Information to construct kernel function state. struct MIGraphXFuncState { @@ -44,14 +48,12 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool bf16_enable = false; bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; - bool save_compiled_mode = false; - std::string save_compiled_path; - bool load_compiled_mode = false; - std::string load_compiled_path; + std::filesystem::path model_cache_dir; bool dump_model_ops = false; bool exhaustive_tune = false; }; @@ -60,11 +62,7 @@ struct MIGraphXFuncState { class MIGraphXExecutionProvider : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); - - void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); - void get_flags_from_env(); - void print_migraphx_ep_flags(); + ~MIGraphXExecutionProvider() override = default; Status Sync() const override; @@ -81,42 +79,55 @@ class MIGraphXExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, - MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); - std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; - int GetDeviceId() const override { return info_.device_id; } + int GetDeviceId() const override { return device_id_; } ProviderOptions GetProviderOptions() const override { - return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id_)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable_)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable_)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable_)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable_)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free_)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache_)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_path_)}}; } private: - MIGraphXExecutionProviderInfo info_; + OrtDevice::DeviceId device_id_{0}; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; - std::string int8_calibration_cache_name_; + std::string int8_calibration_table_name_; bool int8_calibration_cache_available_ = false; - bool int8_use_native_migraphx_calibration_table_ = false; - std::string calibration_cache_path_; + bool int8_use_native_calibration_table_ = false; + std::filesystem::path calibration_cache_path_{}; std::unordered_map dynamic_range_map_; - bool save_compiled_model_ = false; - std::string save_compiled_path_; - bool load_compiled_model_ = false; - std::string load_compiled_path_; + std::filesystem::path model_cache_path_{}; + std::set session_input_names; bool dump_model_ops_ = false; migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; - mutable std::filesystem::path model_path_; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; @@ -125,6 +136,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; std::unique_ptr metadef_id_generator_; + void* external_alloc_{nullptr}; + void* external_free_{nullptr}; + void* external_empty_cache_{nullptr}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index cf21d791cfe6b..33ef366eb18e5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/common/make_string.h" #include "core/common/parse_string.h" -#include "core/framework/provider_options_utils.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { @@ -17,118 +18,90 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -namespace migraphx { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kFp8Enable = "migx_fp8_enable"; -constexpr const char* kInt8Enable = "migx_int8_enable"; -constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; -constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; -constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; -constexpr const char* kSaveModelPath = "migx_save_model_name"; -constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; -constexpr const char* kLoadModelPath = "migx_load_model_name"; -constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; -constexpr const char* kMemLimit = "migx_mem_limit"; -constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy"; -constexpr const char* kGpuExternalAlloc = "migx_external_alloc"; -constexpr const char* kGpuExternalFree = "migx_external_free"; -constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache"; - -} // namespace provider_option_names -} // namespace migraphx - -MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { - MIGraphXExecutionProviderInfo info{}; - void* alloc = nullptr; - void* free = nullptr; - void* empty_cache = nullptr; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptions& options) { ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( - migraphx::provider_option_names::kDeviceId, - [&info](const std::string& value_str) -> Status { - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + migraphx_provider_option::kDeviceId, + [this](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, device_id)); int num_devices{}; ORT_RETURN_IF_ERROR(HIP_CALL(hipGetDeviceCount(&num_devices))); ORT_RETURN_IF_NOT( - 0 <= info.device_id && info.device_id < num_devices, - "Invalid device ID: ", info.device_id, + 0 <= device_id && device_id < num_devices, + "Invalid device ID: ", device_id, ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalAlloc, - [&alloc](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalAlloc, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - alloc = reinterpret_cast(address); + external_alloc = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalFree, - [&free](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalFree, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - free = reinterpret_cast(address); + external_free = reinterpret_cast(address); return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalEmptyCache, - [&empty_cache](const std::string& value_str) -> Status { - size_t address; + migraphx_provider_option::kGpuExternalEmptyCache, + [this](const std::string& value_str) -> Status { + std::uintptr_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - empty_cache = reinterpret_cast(address); + external_empty_cache = reinterpret_cast(address); + return Status::OK(); + }) + .AddValueParser( + migraphx_provider_option::kModelCacheDir, + [this](const std::string& value_str) -> Status { + model_cache_dir = ToPathString(value_str); return Status::OK(); }) - .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8UseNativeCalibTable, int8_use_native_calibration_table) + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) .Parse(options)); - - MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; - info.external_allocator_info = alloc_info; - - return info; } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, - {migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {migraphx::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, - }; - return options; +MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept + : device_id{static_cast(options.device_id)}, + fp16_enable{options.migraphx_fp16_enable != 0}, + fp8_enable{options.migraphx_fp8_enable != 0}, + int8_enable{options.migraphx_int8_enable != 0}, + exhaustive_tune{options.migraphx_exhaustive_tune != 0}, + mem_limit{options.migraphx_mem_limit}, + arena_extend_strategy{options.migraphx_arena_extend_strategy} { } -ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { - const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, +ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + return { + {std::string{migraphx_provider_option::kDeviceId}, MakeStringWithClassicLocale(device_id)}, + {std::string{migraphx_provider_option::kFp16Enable}, MakeStringWithClassicLocale(fp16_enable)}, + {std::string{migraphx_provider_option::kBf16Enable}, MakeStringWithClassicLocale(bf16_enable)}, + {std::string{migraphx_provider_option::kFp8Enable}, MakeStringWithClassicLocale(fp8_enable)}, + {std::string{migraphx_provider_option::kInt8Enable}, MakeStringWithClassicLocale(int8_enable)}, + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, + {std::string{migraphx_provider_option::kModelCacheDir}, MakeStringWithClassicLocale(model_cache_dir)}, }; - return options; } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index a598052c5f025..414254aaa2629 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -3,70 +3,79 @@ #pragma once +#include #include #include +#include #include "core/framework/ortdevice.h" #include "core/common/hash_combine.h" #include "core/framework/arena_extend_strategy.h" #include "core/framework/provider_options.h" +#include "core/framework/provider_options_utils.h" #include "core/session/onnxruntime_c_api.h" -namespace onnxruntime { - -// Information needed to construct MIGraphX execution providers. -struct MIGraphXExecutionProviderExternalAllocatorInfo { - void* alloc{nullptr}; - void* free{nullptr}; - void* empty_cache{nullptr}; - - MIGraphXExecutionProviderExternalAllocatorInfo() { - alloc = nullptr; - free = nullptr; - empty_cache = nullptr; - } +using namespace std::literals::string_view_literals; - MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) { - alloc = a; - free = f; - empty_cache = e; - } +namespace onnxruntime { - bool UseExternalAllocator() const { - return (alloc != nullptr) && (free != nullptr); - } -}; +namespace migraphx_provider_option { +constexpr auto kDeviceId = "device_id"sv; +constexpr auto kFp16Enable = "migraphx_fp16_enable"sv; +constexpr auto kBf16Enable = "migraphx_bf16_enable"sv; +constexpr auto kFp8Enable = "migraphx_fp8_enable"sv; +constexpr auto kInt8Enable = "migraphx_int8_enable"sv; +constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"sv; +constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"sv; +constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"sv; +constexpr auto kMemLimit = "migraphx_mem_limit"sv; +constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"sv; +constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; +constexpr auto kGpuExternalFree = "migraphx_external_free"sv; +constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; +constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; +} // namespace migraphx_provider_option + +extern const EnumNameMapping arena_extend_strategy_mapping; // Information needed to construct trt execution providers. struct MIGraphXExecutionProviderInfo { - std::string target_device; + std::string target_device{"gpu"}; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; - std::string int8_calibration_table_name{""}; + std::string int8_calibration_table_name{}; bool int8_use_native_calibration_table{false}; - bool save_compiled_model{true}; - std::string save_model_file{"./compiled_model.mxr"}; - bool load_compiled_model{true}; - std::string load_model_file{"./compiled_model.mxr"}; + std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; - size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) - ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) + size_t mem_limit{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; OrtArenaCfg* default_memory_arena_cfg{nullptr}; - MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{}; - static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); - static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); - static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); + void* external_alloc{nullptr}; + void* external_free{nullptr}; + void* external_empty_cache{nullptr}; + + bool UseExternalAlloc() const { + return external_alloc != nullptr && external_free != nullptr; + } + + MIGraphXExecutionProviderInfo() = default; + + explicit MIGraphXExecutionProviderInfo(const ProviderOptions& options); + explicit MIGraphXExecutionProviderInfo(const OrtMIGraphXProviderOptions& options) noexcept; + ProviderOptions ToProviderOptions() const; }; + } // namespace onnxruntime template <> struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { - size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const { + size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const noexcept { size_t value{0xbc9f1d34}; // seed // Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each) @@ -75,17 +84,21 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.fp16_enable) << 18) ^ (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ - (static_cast(info.save_compiled_model) << 21) ^ - (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23); + (static_cast(info.exhaustive_tune) << 21) ^ + (static_cast(info.bf16_enable) << 22); + onnxruntime::HashCombine(data, value); + onnxruntime::HashCombine(info.target_device, value); + onnxruntime::HashCombine(info.default_memory_arena_cfg, value); + onnxruntime::HashCombine(info.int8_calibration_table_name, value); + onnxruntime::HashCombine(info.model_cache_dir, value); onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); - onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_empty_cache), value); // The default memory arena cfg is not used in hashing right now. return value; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9274b5696185c..cce90f3ef82be 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -3,23 +3,27 @@ #pragma once +#include +#include +#include #include -#include -#include #include -#include #include +#include +#include +#include #include "flatbuffers/idl.h" #include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" #include "core/common/path_string.h" +#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; namespace onnxruntime { -bool IsGraphInput(const GraphViewer& graph, const std::string& name) { +inline bool IsGraphInput(const GraphViewer& graph, const std::string& name) { const auto& graph_inputs = graph.GetInputs(); std::vector input_names(graph_inputs.size()); std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { @@ -28,12 +32,12 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { +inline bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } -const Node* GetInputNode(const Node& node, int arg_index) { +inline const Node* GetInputNode(const Node& node, int arg_index) { int index = 0; for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) { if (index == arg_index) { @@ -44,7 +48,7 @@ const Node* GetInputNode(const Node& node, int arg_index) { return nullptr; } -std::size_t getNodeInputNum(const Node& node) { +inline std::size_t getNodeInputNum(const Node& node) { std::size_t node_num = 0; for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) { node_num++; @@ -53,14 +57,14 @@ std::size_t getNodeInputNum(const Node& node) { return node_num; } -bool isInputNode(const Node* node, const std::string& name) { +inline bool isInputNode(const Node* node, const std::string& name) { auto outputs = node->OutputDefs(); return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { return (out->Name() == name); }); } -bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { +inline bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { if (node == nullptr) { return false; } @@ -113,10 +117,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, - const Node* node, - std::vector indices, - std::vector& input_nodes) { +inline bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -152,7 +156,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, return true; } -float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { +inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { int s = (input >> 31) & 0x01; int e = ((input & 0x7f800000) >> 23) - 127; int p = -1; @@ -184,12 +188,12 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -bool ReadDynamicRange(const std::string file_name, - const bool is_calibration_table, - std::unordered_map& dynamic_range_map) { - std::ifstream infile(file_name, std::ios::binary | std::ios::in); - if (!infile) { +inline bool ReadDynamicRange(const std::filesystem::path& filename, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { + std::ifstream infile{filename, std::ios::binary | std::ios::in}; + if (!infile.good()) { return false; } @@ -215,7 +219,7 @@ bool ReadDynamicRange(const std::string file_name, dynamic_range_map[tensor_name] = dynamic_range; } } else { - throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + throw std::runtime_error("This is not a TensorRT generated calibration table " + filename.string()); } } } else { @@ -240,14 +244,111 @@ bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -std::string GetCachePath(const std::string& root, const std::string& name) { - if (root.empty()) { - return name; +inline std::filesystem::path GetCachePath(const std::filesystem::path& root, std::string_view name) { + return root.empty() ? std::filesystem::path{ToPathString(name)} : root / ToPathString(name); +} + +inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { + HashValue model_hash; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + const Graph& main_graph = *cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Use the model's file name instead of the entire path to avoid cache regeneration if a path changes + const fs::path path{main_graph.ModelPath()}; + + if (path.has_filename()) { + const auto model_name = path.filename().string(); + + LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.length(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); } else { - fs::path path = root; - path.append(name); - return path.string(); + LOGS_DEFAULT(INFO) << "Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // hashing outputs, inputs and inputs shapes of each node + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& index : nodes_vector) { + const auto& node = graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : node->OutputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + for (const auto* node_arg : node->InputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + if (node_arg->Shape() == nullptr) { + continue; + } + int dim_size = node_arg->Shape()->dim_size(); + for (int i = 0; i < dim_size; i++) { + hash_str(std::to_string(node_arg->Shape()->dim(i).dim_value())); + } + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + + model_hash = hash[0] | static_cast(hash[1]) << 32; + + std::array s{}; + auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); + return std::string{s.data(), ptr}; +} + +inline std::string_view TrimLeft(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(0, sv.end() - std::find_if(sv.begin(), sv.end(), [fn](int ch) { + return fn(ch); + })); +} + +inline std::string_view TrimRight(std::string_view sv, int (*fn)(int) = std::isspace) { + return sv.substr(sv.end() - std::find_if(sv.rbegin(), sv.rend(), [fn](int ch) { + return fn(ch); + }).base()); +} + +inline std::string_view Trim(std::string_view sv, int (*fn)(int) = std::isspace) { + return TrimRight(TrimLeft(sv, fn), fn); +} + +inline int ToInteger(const std::string_view sv) { + int result = 0; + if (auto [_, ec] = std::from_chars(sv.data(), sv.data() + sv.length(), result); ec == std::errc()) { + return result; } + ORT_THROW("invalid input for conversion to integer"); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 2b035b20f619f..49e838747892f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -6,3 +6,4 @@ #include #include #include +#include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 84c7fe3e4d4ab..626758bce36d7 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,28 +1,36 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License + #include +#include +#include +#include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; @@ -35,11 +43,11 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } @@ -61,14 +69,39 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, + void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) override { + if (alloc_fn != nullptr && free_fn != nullptr) { + AllocatorCreationInfo default_memory_info{ + [alloc_fn, free_fn, empty_cache_fn](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP, alloc_fn, free_fn, empty_cache_fn); + }, + device_id, false}; + + return CreateAllocator(default_memory_info); + } + AllocatorCreationInfo default_memory_info{ + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? *default_memory_arena_cfg + : OrtArenaCfg(mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true}; + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } + virtual ~MIGraphX_Provider() = default; + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); @@ -76,72 +109,49 @@ struct MIGraphX_Provider : Provider { return std::make_shared(info); } + // Method uses ProviderOptions, and not OrtMIGraphXProviderOptions (obsolete) std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - MIGraphXExecutionProviderInfo info; - info.device_id = static_cast(options.device_id); - info.target_device = "gpu"; - info.fp16_enable = options.migraphx_fp16_enable; - info.fp8_enable = options.migraphx_fp8_enable; - info.exhaustive_tune = options.migraphx_exhaustive_tune; - info.int8_enable = options.migraphx_int8_enable; - info.int8_calibration_table_name = ""; - if (options.migraphx_int8_calibration_table_name != nullptr) { - info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; - } - info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.save_compiled_model = options.migraphx_save_compiled_model; - info.save_model_file = ""; - if (options.migraphx_save_model_path != nullptr) { - info.save_model_file = options.migraphx_save_model_path; + if (provider_options != nullptr) { + return std::make_shared( + MIGraphXExecutionProviderInfo{*static_cast(provider_options)}); } - info.load_compiled_model = options.migraphx_load_compiled_model; - info.load_model_file = ""; - if (options.migraphx_load_model_path != nullptr) { - info.load_model_file = options.migraphx_load_model_path; - } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); - info.mem_limit = options.migraphx_mem_limit; - return std::make_shared(info); + return nullptr; } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); - migx_options.device_id = internal_options.device_id; - migx_options.migraphx_fp16_enable = internal_options.fp16_enable; - migx_options.migraphx_fp8_enable = internal_options.fp8_enable; - migx_options.migraphx_int8_enable = internal_options.int8_enable; - migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; - - char* dest = nullptr; - auto str_size = internal_options.int8_calibration_table_name.size(); - if (str_size == 0) { - migx_options.migraphx_int8_calibration_table_name = nullptr; + MIGraphXExecutionProviderInfo internal_options{options}; + const auto migx_options = static_cast(provider_options); + migx_options->device_id = internal_options.device_id; + migx_options->migraphx_fp16_enable = internal_options.fp16_enable; + migx_options->migraphx_fp8_enable = internal_options.fp8_enable; + migx_options->migraphx_int8_enable = internal_options.int8_enable; + migx_options->migraphx_exhaustive_tune = internal_options.exhaustive_tune; + + if (internal_options.int8_calibration_table_name.empty()) { + migx_options->migraphx_int8_calibration_table_name = nullptr; } else { - dest = new char[str_size + 1]; + auto str_size = internal_options.int8_calibration_table_name.size(); + auto dest = new char[str_size + 1]; #ifdef _MSC_VER strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); #else strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options->migraphx_int8_calibration_table_name = static_cast(dest); } - migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; + migx_options->migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; - migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); - migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; - migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); - migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); - migx_options.migraphx_mem_limit = internal_options.mem_limit; + migx_options->migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + migx_options->migraphx_mem_limit = internal_options.mem_limit; } ProviderOptions GetProviderOptions(const void* provider_options) override { - auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + return provider_options != nullptr ? MIGraphXExecutionProviderInfo{ + *static_cast(provider_options)} + .ToProviderOptions() + : ProviderOptions{}; } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -151,17 +161,30 @@ struct MIGraphX_Provider : Provider { const OrtSessionOptions& session_options, const OrtLogger& logger, std::unique_ptr& ep) override { - const ConfigOptions* config_options = &session_options.GetConfigOptions(); - - std::array configs_array = {&provider_options, config_options}; - const void* arg = reinterpret_cast(&configs_array); - auto ep_factory = CreateExecutionProviderFactory(&provider_options); + ORT_UNUSED_PARAMETER(num_devices); + const auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); - return Status::OK(); } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + std::vector pathBuf; + for (;;) { + pathBuf.resize(pathBuf.size() + MAX_PATH); + if (const auto writen = GetModuleFileNameW(module, pathBuf.data(), static_cast(pathBuf.size())); writen < pathBuf.size()) { + break; + } + } + std::filesystem::path path(pathBuf.begin(), pathBuf.end()); + SetDllDirectoryW(path.parent_path().native().c_str()); + } +#endif InitializeRegistry(); } @@ -181,26 +204,47 @@ struct MigraphXEpFactory : OrtEpFactory { const char* ep_name, OrtHardwareDeviceType hw_type, const OrtLogger& default_logger_in) - : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} { + : ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. // Ex: a factory that supports NPU should have a different than a factory that supports GPU. - static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->version.c_str(); + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among @@ -212,14 +256,14 @@ struct MigraphXEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type) { - // factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + factory->ort_api.HardwareDevice_VendorId(&device) == 0x1002) { OrtKeyValuePairs* ep_options = nullptr; factory->ort_api.CreateKeyValuePairs(&ep_options); ORT_API_RETURN_IF_ERROR( @@ -237,20 +281,59 @@ struct MigraphXEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto* factory = static_cast(this_ptr); + + *allocator = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, + "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { + // should never be called as we don't implement CreateAllocator + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // not implemented + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + auto* factory = static_cast(this_ptr); + + *stream = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + } + const OrtApi& ort_api; const OrtLogger& default_logger; const std::string ep_name; const std::string vendor{"AMD"}; + const std::string version{"1.0.0"}; // MigraphX EP version - const uint32_t vendor_id{0x1002}; + // Not using AMD vendor id 0x1002 so that OrderDevices in provider_policy_context.cc will default dml ep + const uint32_t vendor_id{0x9999}; const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..c23c9947c8d9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,22 +1,23 @@ -// Copyright 2019 AMD AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include + +#include "core/framework/arena_extend_strategy.h" +#include "core/framework/ortdevice.h" namespace onnxruntime { class IAllocator; -class IDataTransfer; -struct IExecutionProviderFactory; -struct MIGraphXExecutionProviderInfo; -enum class ArenaExtendStrategy : int32_t; -struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, + ArenaExtendStrategy arena_extend_strategy, void* alloc_fn, void* free_fn, void* empty_cache_fn, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h index 02d30ad0f6fbb..db169b9e2f5a9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory_creator.h @@ -6,6 +6,7 @@ #include #include "core/providers/providers.h" +#include "core/framework/provider_options.h" struct OrtMIGraphXProviderOptions; @@ -14,5 +15,6 @@ namespace onnxruntime { struct MIGraphXProviderFactoryCreator { static std::shared_ptr Create(int device_id); static std::shared_ptr Create(const OrtMIGraphXProviderOptions* options); + static std::shared_ptr Create(const ProviderOptions&); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 6e492327a73a3..0baa8a1c67c67 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -1,17 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include "migraphx_stream_handle.h" +#include +#include + +#include "core/providers/resource.h" +#include "core/providers/migraphx/migraphx_stream_handle.h" + +#define MIGRAPHX_RESOURCE_VERSION 1 namespace onnxruntime { -struct MIGraphXNotification : public synchronize::Notification { - MIGraphXNotification(Stream& s) : Notification(s) { +enum MIGraphXResource { + hip_stream_t = rocm_resource_offset +}; + +struct MIGraphXNotification : synchronize::Notification { + explicit MIGraphXNotification(Stream& s) : Notification(s) { HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); } - ~MIGraphXNotification() { + ~MIGraphXNotification() override { if (event_) HIP_CALL_THROW(hipEventDestroy(event_)); } @@ -21,19 +30,19 @@ struct MIGraphXNotification : public synchronize::Notification { HIP_CALL_THROW(hipEventRecord(event_, static_cast(GetStream().GetHandle()))); } - void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", - device_stream.GetDevice().ToString()); - // launch a wait command to the migraphx stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); - }; + void wait_on_device(Stream* device_stream) const { + if (device_stream != nullptr) { + ORT_ENFORCE(device_stream->GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream->GetDevice().ToString()); + // launch a wait command to the migraphx stream + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream->GetHandle()), event_, 0)); + } + } - void wait_on_host() { - // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); + void wait_on_host() const { HIP_CALL_THROW(hipEventSynchronize(event_)); } - hipEvent_t event_; + hipEvent_t event_{}; }; MIGraphXStream::MIGraphXStream(hipStream_t stream, @@ -41,15 +50,14 @@ MIGraphXStream::MIGraphXStream(hipStream_t stream, AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream) : Stream(stream, device), - cpu_allocator_(cpu_allocator), + cpu_allocator_(std::move(cpu_allocator)), release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) { } MIGraphXStream::~MIGraphXStream() { - ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + ORT_IGNORE_RETURN_VALUE(MIGraphXStream::CleanUpOnRunEnd()); if (own_stream_) { - auto* handle = GetHandle(); - if (handle) + if (auto* handle = GetHandle()) HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); } } @@ -87,12 +95,12 @@ struct CpuBuffersInfo { std::unique_ptr buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". - size_t n_buffers; + size_t n_buffers{}; }; static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr info = std::make_unique(); - info.reset(reinterpret_cast(raw_info)); + info.reset(static_cast(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { info->allocator->Free(info->buffers[i]); } @@ -124,29 +132,25 @@ Status MIGraphXStream::CleanUpOnRunEnd() { } void* MIGraphXStream::GetResource(int version, int id) const { - ORT_ENFORCE(version <= ORT_ROCM_RESOURCE_VERSION, "resource version unsupported!"); - void* resource{}; - switch (id) { - case RocmResource::hip_stream_t: - return reinterpret_cast(GetHandle()); - default: - break; + ORT_ENFORCE(version <= MIGRAPHX_RESOURCE_VERSION, "resource version unsupported!"); + if (id == hip_stream_t) { + return GetHandle(); } - return resource; + return nullptr; } // CPU Stream command handles void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(*stream); + dynamic_cast(¬ification)->wait_on_device(stream); } void WaitMIGraphXNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_host(); + dynamic_cast(¬ification)->wait_on_host(); } void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream) { @@ -154,19 +158,20 @@ void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_r stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice); // wait migraphx notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost); - if (!use_existing_stream) + if (!use_existing_stream) { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) { HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); - else + } else { stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream, external_stream](const OrtDevice& device) { return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); }); + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 886103690c661..132ae5fc09d13 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -2,12 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include +#include + #include "core/framework/stream_handles.h" -#include "migraphx_inc.h" -#include "migraphx_call.h" +#include "core/providers/migraphx/migraphx_inc.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { -void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct MIGraphXStream : Stream { MIGraphXStream(hipStream_t stream, @@ -15,7 +18,7 @@ struct MIGraphXStream : Stream { AllocatorPtr cpu_allocator, bool release_cpu_buffer_on_migraphx_stream); - ~MIGraphXStream(); + ~MIGraphXStream() override; std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; @@ -27,7 +30,7 @@ struct MIGraphXStream : Stream { bool own_stream_{true}; - virtual void* GetResource(int version, int id) const; + void* GetResource(int version, int id) const override; private: std::vector deferred_cpu_buffers_; @@ -36,8 +39,8 @@ struct MIGraphXStream : Stream { }; void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, - const OrtDevice::DeviceType device_type, - AllocatorPtr cpu_allocator, + OrtDevice::DeviceType device_type, + const AllocatorPtr& cpu_allocator, bool release_cpu_buffer_on_migraphx_stream, hipStream_t external_stream, bool use_existing_stream); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc index adc79576272ab..843f70dd27461 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc @@ -7,45 +7,6 @@ namespace onnxruntime { namespace nnapi { -namespace { -bool HasExternalInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit) { - const auto is_ext_initializer = - [&](const NodeArg& node_arg) { - const auto& input_name(node_arg.Name()); - const auto initializer = initializers.find(input_name); - if (initializer == initializers.end()) - return false; - - const auto& tensor = *initializer->second; - if (tensor.has_data_location() && - tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name - << "] with external data location are not currently supported"; - return true; - } - - return false; - }; - - const auto& inputs = node_unit.Inputs(); - for (const auto& input : inputs) { - if (is_ext_initializer(input.node_arg)) - return true; - - if (!input.quant_param) - continue; - - if (is_ext_initializer(input.quant_param->scale)) - return true; - - if (input.quant_param->zero_point && is_ext_initializer(*input.quant_param->zero_point)) - return true; - } - - return false; -} -} // namespace - // Add operator related Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const { @@ -86,10 +47,6 @@ bool BaseOpBuilder::IsOpSupported(const GraphViewer& graph_viewer, const NodeUni if (!HasSupportedInputOutputs(graph_viewer, node_unit, params)) return false; - // We do not support external initializers for now - if (HasExternalInitializer(graph_viewer.GetAllInitializedTensors(), node_unit)) - return false; - if (!HasSupportedOpSet(node_unit)) return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index e4bee6f959a01..50734be17056e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -283,7 +283,7 @@ Status ModelBuilder::RegisterInitializers() { auto [index, size, padded_size] = initializers[i++]; const uint8_t* src = nullptr; // TensorProto_DataType_UINT8 or TensorProto_DataType_FLOAT: - Initializer unpacked_tensor(tensor, graph_viewer_.ModelPath()); + Initializer unpacked_tensor(graph_viewer_.GetGraph(), tensor, graph_viewer_.ModelPath()); size_t size_in_bytes = unpacked_tensor.DataAsByteSpan().size(); ORT_RETURN_IF_NOT(size == size_in_bytes, "initializer tensor: ", tensor.name(), "'s size: ", diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 286db9070766d..0ee18cc6799fc 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. #include #include +#include #include #include "core/providers/shared_library/provider_api.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" @@ -11,6 +12,7 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" +#include "core/framework/ort_value.h" #include "nv_execution_provider.h" #include "nv_execution_provider_utils.h" #include "nv_execution_provider_custom_ops.h" @@ -18,7 +20,7 @@ #include "nv_data_transfer.h" #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" +#include "core/providers/cuda/cuda_graph.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include "core/common/parse_string.h" @@ -83,39 +85,107 @@ struct ShutdownProtobuf { namespace onnxruntime { -namespace cuda { -template <> -void Impl_Cast( - cudaStream_t stream, - const int64_t* input_data, int32_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +// Helper function to check if a data type is supported by NvTensorRTRTX EP +static bool IsSupportedDataType(ONNXTensorElementDataType data_type) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point + return true; + default: + return false; + } } -template <> -void Impl_Cast( - cudaStream_t stream, - const int32_t* input_data, int64_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); +// Helper function to get data type name as string +static std::string GetDataTypeName(ONNXTensorElementDataType data_type) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return "FLOAT"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return "FLOAT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return "BFLOAT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return "BOOL"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return "INT4"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return "INT8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return "UINT8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return "INT32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return "INT64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + return "FLOAT8E4M3FN"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return "DOUBLE"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return "STRING"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return "UINT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return "UINT32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return "UINT64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return "INT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + return "COMPLEX64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + return "COMPLEX128"; + default: + return "UNKNOWN(" + std::to_string(static_cast(data_type)) + ")"; + } } -template <> -void Impl_Cast( - cudaStream_t stream, - const double* input_data, float* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} +// Helper function to check if a node has supported data types +static bool CheckNodeDataTypes(const Node* node) { + // Check input data types + for (const auto* input_def : node->InputDefs()) { + if (input_def->Exists()) { + const auto* type_proto = input_def->TypeAsProto(); + if (type_proto && type_proto->has_tensor_type()) { + auto data_type = static_cast(type_proto->tensor_type().elem_type()); + if (!IsSupportedDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") has unsupported input data type: " << GetDataTypeName(data_type) + << " for input '" << input_def->Name() << "'"; + return false; + } + } + } + } + + // Check output data types + for (const auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto* type_proto = output_def->TypeAsProto(); + if (type_proto && type_proto->has_tensor_type()) { + auto data_type = static_cast(type_proto->tensor_type().elem_type()); + if (!IsSupportedDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") has unsupported output data type: " << GetDataTypeName(data_type) + << " for output '" << output_def->Name() << "'"; + return false; + } + } + } + } -template <> -void Impl_Cast( - cudaStream_t stream, - const float* input_data, double* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); + return true; } -} // namespace cuda void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { @@ -123,10 +193,11 @@ void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* / // even for empty tensors, so allocate a dummy byte. size = std::max(size, static_cast(1)); if (size > allocated_size) { - cudaFree(outputPtr); + alloc_->Free(alloc_, outputPtr); outputPtr = nullptr; allocated_size = 0; - if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + outputPtr = alloc_->Alloc(alloc_, size); + if (outputPtr) { allocated_size = size; } } @@ -253,7 +324,8 @@ bool ApplyProfileShapesFromProviderOptions(std::vector>>& profile_min_shapes, std::unordered_map>>& profile_max_shapes, std::unordered_map>>& profile_opt_shapes, - ShapeRangesMap& input_explicit_shape_ranges) { + ShapeRangesMap& input_explicit_shape_ranges, + bool& cuda_graph_flag) { if (trt_profiles.size() == 0) { LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Number of optimization profiles should be greater than 0, but it's 0."; return false; @@ -280,6 +352,10 @@ bool ApplyProfileShapesFromProviderOptions(std::vectorisShapeTensor()) { + if (cuda_graph_flag) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Shape tensor detected on input '" << input->getName() << "'. Disabling CUDA Graph."; + cuda_graph_flag = false; + } int shape_size = nb_dims == 0 ? 1 : static_cast(profile_min_shapes[input_name][i].size()); std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); @@ -352,193 +428,6 @@ bool ApplyProfileShapesFromProviderOptions(std::vector shape values" for the INT32 shape tensor input across this inference run - * @param shape_tensor_values_int64 holds "shape tensor -> shape values" for the INT64 shape tensor input across this inference run - */ -Status ApplyProfileShapesFromInputTensorValue(std::vector& trt_profiles, - Ort::KernelContext ctx, - nvinfer1::ITensor* input, - ShapeRangesMap& shape_ranges, - const std::unordered_map& input_indexes, - std::unordered_map>& shape_tensor_values, - std::unordered_map>& shape_tensor_values_int64, - cudaStream_t stream, - bool* engine_update) { - for (size_t i = 0; i < trt_profiles.size(); i++) { - const std::string& input_name = input->getName(); - nvinfer1::Dims dims = input->getDimensions(); - int nb_dims = dims.nbDims; - - size_t input_index = 0; - const auto& iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } - - auto input_tensor = ctx.GetInput(input_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shapes = tensor_info.GetShape(); - auto& shape_ranges_per_input = shape_ranges[input_name]; - - auto trt_profile = trt_profiles[i]; - - // If there are multiple profiles, for second and rest of profiles, simply copy the min/max/opt profile values from the first profile. - // Following "if statement" won't be executed since TRT EP currently only allows single profile for non-explicit profiles case. - if (i > 0) { - if (input->isShapeTensor()) { - // shape tensor - int shape_size = nb_dims == 0 ? 1 : static_cast(tensor_shapes[0]); - std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); - for (int j = 0; j < shape_size; j++) { - shapes_min[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); - shapes_max[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX)); - shapes_opt[j] = *(trt_profiles[0]->getShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT)); - } - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); - } else { - // execution tensor - nvinfer1::Dims dims_min, dims_opt, dims_max; - dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN); - dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX); - dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT); - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - } - continue; - } - - // Create shape profile - if (input->isShapeTensor()) { - // Get shape values for shape tensor input - const auto tensor_type = tensor_info.GetElementType(); - // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension - int shape_size = dims.nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); - // For setting TRT optimization profile. (Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10) - std::vector values(shape_size); - - switch (tensor_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto buffer = std::make_unique(shape_size); - auto status = GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } - shape_tensor_values[input_name].resize(shape_size); - for (int j = 0; j < shape_size; ++j) { - shape_tensor_values[input_name][j] = buffer[j]; - values[j] = buffer[j]; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - auto buffer = std::make_unique(shape_size); - auto status = GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } - shape_tensor_values_int64[input_name].resize(shape_size); - for (int j = 0; j < shape_size; ++j) { - shape_tensor_values_int64[input_name][j] = buffer[j]; - values[j] = static_cast(buffer[j]); - } - break; - } - default: { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported."); - } - } - - // Update shape ranges - std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); - int shape_range_size = static_cast(shape_ranges_per_input.size()); - if (shape_size == shape_range_size) { - // If shape size matches, check/update shape range - for (int j = 0; j < shape_size; ++j) { - auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile - shapes_min[j] = static_cast(shape_range[0]); - shapes_max[j] = static_cast(shape_range[1]); - shapes_opt[j] = static_cast(shape_range[2]); - - const auto& tensor_shape_value = values[j]; - // Update shape range lower bound - if (tensor_shape_value < shape_range[0]) { - shape_range[0] = tensor_shape_value; - shapes_min[j] = tensor_shape_value; - *engine_update = true; - } - // Update shape range upper bound - if (tensor_shape_value > shape_range[1]) { - shape_range[1] = tensor_shape_value; - shape_range[2] = tensor_shape_value; - shapes_max[j] = tensor_shape_value; - shapes_opt[j] = tensor_shape_value; - *engine_update = true; - } - } - } else { - // If shape size doesn't match, initialize shape_range with the new shape value - shape_ranges_per_input.clear(); - for (int j = 0; j < shape_size; ++j) { - const auto& tensor_shape_value = values[j]; - std::vector> profile_vector; - std::vector shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value}; - profile_vector.push_back(shape_vector); // only one profile needed - shape_ranges_per_input[j] = profile_vector; - shapes_min[j] = tensor_shape_value; - shapes_opt[j] = tensor_shape_value; - shapes_max[j] = tensor_shape_value; - } - *engine_update = true; - } - - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); - trt_profile->setShapeValuesV2(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); - } else { // Execution tensor - nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); - for (int j = 0, end = nb_dims; j < end; ++j) { - const auto& tensor_shape = tensor_shapes[j]; - if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) { - auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile - dims_min.d[j] = static_cast(shape_range[0]); - dims_max.d[j] = static_cast(shape_range[1]); - dims_opt.d[j] = static_cast(shape_range[2]); - - // Update minimum dimension - if (tensor_shape < shape_range[0]) { - shape_range[0] = tensor_shape; - dims_min.d[j] = static_cast(tensor_shape); - *engine_update = true; - } - // Update maximum dimension - if (tensor_shape > shape_range[1]) { - shape_range[1] = tensor_shape; - shape_range[2] = tensor_shape; - dims_max.d[j] = static_cast(tensor_shape); - dims_opt.d[j] = static_cast(tensor_shape); - *engine_update = true; - } - } - } - - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); - trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); - } - } - return Status::OK(); -} - #define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto input_tensor_ptr = input_tensor.GetTensorData(); \ @@ -551,47 +440,19 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \ - if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - data = scratch_buffers.back().get(); \ - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - data = scratch_buffers.back().get(); \ - } \ - break; \ - } - #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + data_ptr = output_tensor_ptr; \ if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - buffers[output_name] = output_tensor_ptr; \ + buffer = output_tensor_ptr; \ } else { \ scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ + buffer = scratch_buffers.back().get(); \ } \ break; \ } -#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ - case DATA_TYPE: { \ - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = static_cast(elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = 1; \ - } \ - break; \ - } - #define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ @@ -601,15 +462,6 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ - } \ - break; \ - } - /* * Set Nv executio context input. * @@ -628,7 +480,8 @@ Status BindContextInput(Ort::KernelContext& ctx, std::unordered_map>& shape_tensor_values_int64, std::vector>& scratch_buffers, OrtAllocator* alloc, - cudaStream_t stream) { + cudaStream_t stream, + bool& skip_input_binding_allowed) { auto input_tensor = ctx.GetInput(input_index); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shapes = tensor_info.GetShape(); @@ -647,7 +500,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (trt_engine->isShapeInferenceIO(input_name)) { // Bind "shape tensor" input buffer - + skip_input_binding_allowed = false; // Shape tensor input binding cannot be skipped // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); switch (tensor_type) { @@ -668,7 +521,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { std::string error_input_name = input_name; std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -691,7 +544,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { std::string error_input_name = input_name; std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -713,7 +566,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setInputShape(input_name, dims)) { std::string error_input_name = input_name; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } // Bind "execution tensor" input buffer @@ -727,14 +580,15 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); + "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); } } trt_context->setTensorAddress(input_name, data); @@ -756,8 +610,6 @@ Status BindContextInput(Ort::KernelContext& ctx, * param output_type - Data type of the output * param i - Output iteration index * param output_tensors - Output iteration index to output's ORT value - * param output_dim_sizes - Output iteration index to the multiplocation of its shape's dimensions - * param dds_output_set - DDS output set * param dds_output_allocator_map - DDS output to its allocator * param scratch_buffer - The allocation buffer created by TRT EP * param allocator - ORT allocator @@ -769,25 +621,21 @@ Status BindContextOutput(Ort::KernelContext& ctx, const char* output_name, size_t output_index, size_t output_type, - size_t i, - std::unordered_map& output_tensors, - std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, - std::unordered_map& buffers) { + nvinfer1::Dims& dims, + void*& data_ptr) { // Get output shape - nvinfer1::Dims dims = trt_context->getTensorShape(output_name); + dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; bool is_DDS = false; - std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { // data-dependent shape if (dims.d[j] == -1) { is_DDS = true; break; } - output_shapes[j] = dims.d[j]; } auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); @@ -800,31 +648,36 @@ Status BindContextOutput(Ort::KernelContext& ctx, // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. if (is_DDS || known_DDS) { if (!known_DDS) { - auto allocatorPtr = std::make_unique(); + auto allocatorPtr = std::make_unique(alloc); trt_context->setOutputAllocator(output_name, allocatorPtr.get()); dds_output_allocator_map[output_name] = std::move(allocatorPtr); + dims.nbDims = -1; // Set to -1 to indicate that the shape is not known at this point. + data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding. } } else { - output_tensors[i] = ctx.GetOutput(output_index, output_shapes); - auto& output_tensor = output_tensors[i]; + auto output_tensor = ctx.GetOutput(output_index, dims.d, nb_dims); const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + void* buffer = nullptr; + switch (output_type) { + // below macros set data_ptr and skip_output_binding_allowed variables CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - trt_context->setTensorAddress(output_name, buffers[output_name]); + trt_context->setTensorAddress(output_name, buffer); } return Status::OK(); @@ -840,7 +693,6 @@ Status BindContextOutput(Ort::KernelContext& ctx, * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. */ Status BindKernelOutput(Ort::KernelContext& ctx, - OrtMemoryInfo* /*mem_info*/, DDSOutputAllocatorMap& allocator_map, char const* output_name, size_t output_index, @@ -878,24 +730,25 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } return Status::OK(); } NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) { - // TODO: figure out if PerThreadContext is used at all. If not, just clean it up. + // Only set device if user hasn't provided a compute stream if (has_user_compute_stream) { CUDA_CALL_THROW(cudaSetDevice(device_id)); - (void)(stream); + (void)stream; } } @@ -903,31 +756,6 @@ NvExecutionProvider::PerThreadContext::~PerThreadContext() { trt_context_map_.clear(); } -/* - * Returns true if the shape ranges maintained by the PerThreadContext is different from the shape ragnes maintained by TRT EP, meaning the - * engine is being updated and the execution context maintained by the PerThreadContext should be updated as well. Otherwise, returns false. - * - */ -bool NvExecutionProvider::PerThreadContext::CompareProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges) { - if (shape_ranges.size() > 0) { - if (input_shape_ranges_[fused_node] != shape_ranges) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] The shape ranges maintained by the PerThreadContext is different from the shape ranges maintained by TRT EP. \ - This means the engine is updated and will need to update the execution context as well."; - return true; - } - } - return false; -} - -/* - * Updates the shape ranges maintained by the PerThreadContext. - * As long as the execution context maintained by the PerThreadContext is updated, the associated shape ranges should be updated as well. - * - */ -void NvExecutionProvider::PerThreadContext::UpdateProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges) { - input_shape_ranges_[fused_node] = shape_ranges; -} - void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { @@ -935,9 +763,9 @@ void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fus } } -bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr context) { +bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context) { if (!context) { - context = std::make_unique(); + context = tensorrt_ptr::unique_pointer_exec_ctx(); } trt_context_map_[fused_node] = std::move(context); @@ -947,6 +775,86 @@ bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fu return false; } +void NvExecutionProvider::PerThreadContext::DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_.erase(cuda_graph_annotation_id); + cuda_graph_.Reset(); +} + +void NvExecutionProvider::PerThreadContext::ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id) { + if (graph_id_to_run_count_.find(cuda_graph_annotation_id) == graph_id_to_run_count_.end()) { + return; + } + graph_id_to_run_count_[cuda_graph_annotation_id] = 0; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)) { + return false; + } + + // Safe access to map - return false if key doesn't exist yet + auto it = graph_id_to_run_count_.find(cuda_graph_annotation_id); + if (it == graph_id_to_run_count_.end()) { + return false; // Entry doesn't exist yet, not ready for capture + } + + bool allowed = it->second >= min_num_runs_before_cuda_graph_capture_; + if (allowed) { + LOGS_DEFAULT(VERBOSE) << "NvTensorRTRTX EP Graph capture allowed for ID: " << cuda_graph_annotation_id + << ", run count: " << it->second; + } + return allowed; +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); +} + +CudaGraphAnnotation_t NvExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const { + // Actual implementation + auto graph_annotation_str = run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + CudaGraphAnnotation_t cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + + // Kind of debugging head implementation, can be cleaned and made robust like CUDA EP + if (graph_annotation_str.has_value() && !graph_annotation_str->empty()) { + if (!TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to parse cuda graph annotation id: " + << *graph_annotation_str << ", using default: " << kCudaGraphAnnotationDefault; + cuda_graph_annotation_id = kCudaGraphAnnotationDefault; + } + } + return cuda_graph_annotation_id; +} + +void NvExecutionProvider::PerThreadContext::SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id) { + current_graph_annotation_id_ = cuda_graph_annotation_id; +} + +CudaGraphAnnotation_t NvExecutionProvider::PerThreadContext::GetCurrentGraphAnnotationId() const { + return current_graph_annotation_id_; +} + +void NvExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); +} + +void NvExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); +} + +bool NvExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(cuda_graph_annotation_id); +} + +Status NvExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) { + return cuda_graph_.Replay(cuda_graph_annotation_id, sync_status_flag); +} + +void NvExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id) { + graph_id_to_run_count_[cuda_graph_annotation_id]++; +} + bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { @@ -958,11 +866,11 @@ bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string f nvinfer1::IExecutionContext& NvExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { - return *(it->second); // dereference shared pointer + return *(it->second.get()); // dereference shared pointer } - auto context = std::make_unique(); + auto context = tensorrt_ptr::unique_pointer_exec_ctx(); trt_context_map_[fused_node] = std::move(context); - return *(trt_context_map_[fused_node]); // dereference shared pointer + return *(trt_context_map_[fused_node].get()); // dereference shared pointer } void NvExecutionProvider::ReleasePerThreadContext() const { @@ -1039,10 +947,21 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) cudaDeviceProp prop; CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); - compute_capability_ = GetComputeCapacity(prop); + auto cc = prop.major * 10 + prop.minor; + if (!(cc == 86 || cc == 89 || cc >= 120)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] The execution provider only supports RTX devices with compute capabilities 86, 89, 120 and above")); + } + compute_capability_ = GetComputeCapability(prop); if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + } else if (cuda_graph_enable_) { + external_stream_ = false; + CUDA_CALL_THROW(cudaStreamCreate(&stream_)); + } else { + external_stream_ = false; + stream_ = nullptr; // Will be created in compute function } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1060,6 +979,20 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) max_shared_mem_size_ = info.max_shared_mem_size; dump_subgraphs_ = info.dump_subgraphs; weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; + // make runtime cache path absolute and create directory if it doesn't exist + if (!info.runtime_cache_path.empty()) { + std::filesystem::path p(info.runtime_cache_path); + std::filesystem::path abs_path = std::filesystem::absolute(p); + const auto& env = GetDefaultEnv(); + auto status = env.CreateFolder(abs_path.string()); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The runtime cache directory could not be created at: " << abs_path + << ". Runtime cache is disabled."; + } else { + runtime_cache_ = abs_path; + } + } + onnx_model_folder_path_ = info.onnx_model_folder_path; onnx_model_bytestream_ = info.onnx_bytestream; onnx_model_bytestream_size_ = info.onnx_bytestream_size; @@ -1069,6 +1002,15 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) "When providing either 'trt_onnx_bytestream_size' or " "'trt_onnx_bytestream' both have to be provided")); } + use_external_data_initializer_ = info.use_external_data_initializer; + onnx_external_data_bytestream_ = info.external_data_bytestream; + onnx_external_data_bytestream_size_ = info.external_data_bytestream_size; + if ((onnx_external_data_bytestream_ != nullptr && onnx_external_data_bytestream_size_ == 0) || + (onnx_external_data_bytestream_ == nullptr && onnx_external_data_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'onnx_external_data_bytestream_size' or " + "'onnx_external_data_bytestream' both have to be provided")); + } detailed_build_log_ = info.detailed_build_log; dump_ep_context_model_ = info.dump_ep_context_model; ep_context_file_path_ = info.ep_context_file_path; @@ -1081,7 +1023,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) engine_decryption_lib_path_ = info.engine_decryption_lib_path; } force_sequential_engine_build_ = info.force_sequential_engine_build; - context_memory_sharing_enable_ = info.context_memory_sharing_enable; sparsity_enable_ = info.sparsity_enable; auxiliary_streams_ = info.auxiliary_streams; profile_min_shapes = info.profile_min_shapes; @@ -1183,13 +1124,13 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); if (handle == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not open shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not open shared library from " + engine_decryption_lib_path_)); } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); if (engine_decryption_ == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); } } @@ -1199,7 +1140,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) // external stream: // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. // So, no need to synchronize different streams after enqueueV3. - if (cuda_graph_enable_ || external_stream_) { + if (external_stream_) { sync_stream_after_enqueue_ = false; } @@ -1225,16 +1166,23 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_engine_decryption_enable: " << engine_decryption_enable_ << ", nv_engine_decryption_lib_path: " << engine_decryption_lib_path_ << ", nv_force_sequential_engine_build: " << force_sequential_engine_build_ - << ", nv_context_memory_sharing_enable: " << context_memory_sharing_enable_ << ", nv_sparsity_enable: " << sparsity_enable_ << ", nv_auxiliary_streams: " << auxiliary_streams_ - << ", nv_cuda_graph_enable: " << cuda_graph_enable_ + << ", enable_cuda_graph: " << cuda_graph_enable_ << ", nv_dump_ep_context_model: " << dump_ep_context_model_ << ", nv_ep_context_file_path: " << ep_context_file_path_ << ", nv_ep_context_embed_mode: " << ep_context_embed_mode_ << ", nv_cache_prefix: " << cache_prefix_ << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ - << ", nv_op_types_to_exclude: " << op_types_to_exclude_; + << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_ + << ", nv_use_external_data_initializer_: " << use_external_data_initializer_ + << ", nv_op_types_to_exclude: " << op_types_to_exclude_ + << ", nv_runtime_cache_path: " << runtime_cache_; +} + +Status NvExecutionProvider::Sync() const { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + return Status::OK(); } NvExecutionProvider::~NvExecutionProvider() { @@ -1248,7 +1196,7 @@ NvExecutionProvider::~NvExecutionProvider() { } } - if (!external_stream_ && stream_) { + if (!external_stream_ && stream_ != nullptr) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } ReleaseTensorRTCustomOpDomainList(info_.custom_op_domain_list); @@ -1260,47 +1208,94 @@ NvExecutionProvider::~NvExecutionProvider() { } } +void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, + CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture) { + graph_replay_on_this_run = false; + should_start_capture = false; + + // Case 1: CUDA Graph capture is enabled AND IO binding is required. + // In this case, we force graph re-capture by resetting warmup runs. + // If a graph for this annotation ID already exists, delete it before proceeding. + if (require_io_binding && cuda_graph_enable_) { + GetPerThreadContext().ResetWarmupRuns(cuda_graph_annotation_id); + + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Graph already captured and required_io_binding is true, resetting warmup runs and deleting graph"; + GetPerThreadContext().DeleteCapturedGraph(cuda_graph_annotation_id); + } + // Case 2: CUDA Graph capture is enabled AND IO binding is NOT required + } else if (cuda_graph_enable_ && !require_io_binding) { + // If the graph is not yet captured, increment the regular run counter + if (cuda_graph_annotation_id != kCudaGraphAnnotationSkip && + !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(cuda_graph_annotation_id); + } + + // If capture is allowed and graph not already captured, + // set the stream and begin capture + if (!GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().SetCudaGraphStream(stream); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); + should_start_capture = true; + } + + // If a graph is already captured for this ID, mark it for replay in this run. + if (GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + graph_replay_on_this_run = true; + } + } +} + bool NvExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool NvExecutionProvider::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + // This is hardcoded to always return false because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return false; } -void NvExecutionProvider::CaptureBegin(int) { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(0); +Status NvExecutionProvider::ReplayGraph(int graph_annotation_id) { + // This is hardcoded to always return OK because we are not allowing the ORT framework to have the CUDA graph control. + (void)graph_annotation_id; + return Status::OK(); } -void NvExecutionProvider::CaptureEnd(int) { - cuda_graph_.CaptureEnd(0); - is_graph_captured_ = true; -} +Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + if (cuda_graph_enable_) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + GetPerThreadContext().SetCurrentGraphAnnotationId(cuda_graph_annotation_id); + } -bool NvExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; + if (multi_profile_enable_ == true) { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); + TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); + } + return Status::OK(); } -Status NvExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); - // Please note that CUDAGraph::Replay() is not thread safe. - // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), - // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_.Replay(0); -} +Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + (void)run_options; -void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { - // Please note that this function is not thread safe. - // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), - // therefore following increment is guaranteed to be thread safe. - ++regular_run_count_before_graph_capture_; + if (sync_stream && external_stream_) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + } + return Status::OK(); } std::vector NvExecutionProvider::CreatePreferredAllocators() { + OrtArenaCfg arena_cfg(0, static_cast(ArenaExtendStrategy::kNextPowerOfTwo), + -1, -1, -1, -1); AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, - narrow(device_id_)); + narrow(device_id_), + true, + arena_cfg, + // make it stream aware + true); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { @@ -1315,22 +1310,6 @@ std::unique_ptr NvExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { - if (multi_profile_enable_ == true) { - auto graph_annotation_str = - run_options.GetConfigOptions().GetConfigEntry(nv::run_option_names::kProfileIndex); - TryParseStringWithClassicLocale(*graph_annotation_str, nv_profile_index_); - } - return Status::OK(); -} - -Status NvExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (sync_stream && external_stream_) { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); - } - return Status::OK(); -} - // Get the pointer to the IBuilder instance. // Note: This function is not thread safe. Calls to this function from different threads must be serialized // even though it doesn't make sense to have multiple threads initializing the same inference session. @@ -1339,6 +1318,9 @@ nvinfer1::IBuilder* NvExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) { auto lock = GetApiLock(); builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + unsigned int num_threads = std::thread::hardware_concurrency(); + builder_->setMaxThreads(num_threads / 2); + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Set threads that the builder can use to:" << builder_->getMaxThreads(); } } return builder_.get(); @@ -1649,8 +1631,11 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t SetAllGraphInputs(graph_build); } - ORT_ENFORCE(graph_build.Resolve().IsOK()); - + auto status = graph_build.Resolve(); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << status.ErrorMessage(); + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX graph resolve failed: " + status.ErrorMessage())); + } // Add parent graph output to the subgraph int i = 0; std::vector subgraph_outputs; @@ -1701,7 +1686,37 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + + // save user provided external data in memory instead of writing to ModelProto + // needed for models > 2GB + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_viewer->GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_viewer->GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(std::move(full_init->name()), std::move(full_init->raw_data())); + } + } + } + + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -1720,11 +1735,25 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t auto network_flags = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + bool is_model_supported = false; + // limit the scope of trt_parser so that model gets unloaded from memory asap { auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + is_model_supported = trt_parser->parseModelProto(); +#else + ORT_THROW("'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + } // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. auto num_subgraphs = trt_parser->getNbSubgraphs(); @@ -1907,21 +1936,33 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, #endif model_path_[sizeof(model_path_) - 1] = '\0'; - // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and - // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. - // So, simply return the ComputeCapability here. - if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) { - SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)), 0); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - return result; - } + const int number_of_ort_nodes = graph.NumberOfNodes(); + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); // Generate unique kernel name for TRT graph HashValue model_hash = TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)); - // Get supported node list from TensorRT parser - const int number_of_ort_nodes = graph.NumberOfNodes(); + // If there are "EPContext" contrib op nodes, it means TRT EP can fetch the precompiled engine info from the node and + // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT + // parser and engine compilation. So, simply return subgraphs consists of single ep context nodes here. + int subgraph_idx = 0; + for (size_t node_idx : node_index) { + const auto& node = graph.GetNode(node_idx); + const bool is_context_node = node && !node->OpType().empty() && node->OpType() == EPCONTEXT_OP; + if (is_context_node) { + SubGraph_t supported_node_vector(std::make_pair(std::vector{node_idx}, true)); + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, model_hash, subgraph_idx++); + + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } + } + // return early if context nodes where found + if (!result.empty()) { + return result; + } + + // For regular ONNX nodes, get supported node list from TensorRT parser + std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); @@ -1940,12 +1981,12 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); bool new_subgraph = true; /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. It's a DDS op. + * 3. It has unsupported data types. */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); @@ -1985,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, supported_node = false; } + // Check data types and print warnings for unsupported types + if (supported_node) { + if (!CheckNodeDataTypes(node)) { + supported_node = false; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") excluded due to unsupported data types"; + } + } + if (supported_node) { if (new_subgraph) { parser_nodes_vector.emplace_back(); @@ -2131,14 +2182,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, */ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log) { bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; std::filesystem::path onnx_model_path{onnx_model_folder_path}; if (refit_from_file) { if (!onnx_model_filename.empty()) { @@ -2175,34 +2228,145 @@ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (refit_from_file) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from file on disk: " << onnx_model_path.string(); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + + // New refit APIs + if (refit_with_external_data) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + // A valid model bytestream must be passed. + if (refit_from_file) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's refit with external data must be called with a valid ONNX model bytestream"); } - } else { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; - if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not load model from provided onnx_model_bytestream"); + } + + // Extract weight information from the Refitter. + int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names_prealocated(required_weights); + refitter->getAllWeights(required_weights, refit_names_prealocated.data()); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitter requires " << required_weights << " weights"; + std::unordered_set refit_names(std::make_move_iterator(refit_names_prealocated.begin()), + std::make_move_iterator(refit_names_prealocated.end())); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + if (refit_names.find(proto_name) != refit_names.end()) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. + int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); + } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } else { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Proto: " + proto_name + " has no raw nor external data."; + ++missing_initializer_data; + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializer with name: " << proto_name << " was not marked as refittable"; + } + } + if (missing_initializer_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."); + } + + // Load extracted initializers into the parser + if (!names.empty()) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Number of initializers submitted to refitter " << names.size(); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); + "NvTensorRTRTX EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."); + } + refit_complete = true; +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Refit with external data is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } + + // If new refit flow was not completed, then fallback to refit_from_file. + if (!refit_complete) { + if (refit_from_file) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } } } if (refitter->refitCudaEngine()) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully refitted the weight-stripped engine."; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); } - // serialize the refitted engine to disk - if (serialize_refitted_engine) { - std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); - nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); - std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); - engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Serialize the refitted engine to " << refitted_engine_cache; - } return Status::OK(); } @@ -2228,8 +2392,10 @@ common::Status NvExecutionProvider::Compile(const std::vector } Status status; - if (GraphHasCtxNode(graph_body_viewer)) { + size_t node_idx = 0; + if (GraphHasCtxNode(graph_body_viewer, node_idx)) { status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer, + node_idx, fused_node, input_map, output_map, @@ -2244,6 +2410,106 @@ common::Status NvExecutionProvider::Compile(const std::vector return Status::OK(); } +/** + * @brief Determines whether I/O binding is required for TensorRT execution. + * + * This function optimizes TensorRT inference performance by determining when tensor + * input/output binding operations can be skipped. Binding is an expensive operation + * that involves setting up tensor pointers in the TensorRT execution context, so + * avoiding unnecessary rebinding can significantly improve inference throughput. + * + * The function implements a three-tier decision logic: + * 1. First run: Always requires binding to establish initial tensor mappings + * 2. Subsequent runs with optimization allowed: Only rebind if tensors have changed + * 3. Subsequent runs without optimization: Always rebind for safety + * + * @tparam TRTState The TensorRT state type (TensorrtFuncState or TensorrtShortFuncState) + * @param trt_state Pointer to the TensorRT execution state containing tensor cache + * and configuration flags + * @param ctx ONNX Runtime kernel context providing access to current input tensors + * + * @return true if I/O binding is required (tensors changed or safety conditions apply), + * false if binding can be safely skipped (optimization enabled and tensors unchanged) + * + * @note This function modifies trt_state by: + * - Setting is_first_run to false after first execution + * - Caching current tensor parameters in input_tensors vector + * - Updating cached tensors when changes are detected + * + * @warning The skip_io_binding_allowed flag must be carefully managed as incorrect + * usage can lead to inference with stale tensor bindings and incorrect results. + */ +template +static bool IsIOBindingRequired(TRTState* const trt_state, const Ort::KernelContext& ctx) { + // Check if input tensors have changed since the last run + // If so, we need to bind input tensors again + bool require_io_binding = false; + + if (trt_state->is_first_run) { + // If this is the first run, we always bind input tensors + require_io_binding = true; + auto input_tensor_count = ctx.GetInputCount(); + auto output_tensor_count = ctx.GetOutputCount(); + trt_state->input_tensors.resize(input_tensor_count); + trt_state->output_tensors.resize(output_tensor_count); + for (size_t input_index = 0; input_index < input_tensor_count; ++input_index) { + const auto& input_tensor = ctx.GetInput(input_index); + const auto& tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + + trt_state->input_tensors[input_index] = TensorParams{input_tensor.GetTensorRawData(), tensor_info.GetShape()}; + } + trt_state->is_first_run = false; + } else if (trt_state->skip_io_binding_allowed) { + // If skip_io_binding_allowed is true, we can skip binding if input tensors are the same as before + auto input_tensor_count = ctx.GetInputCount(); + for (size_t input_index = 0; input_index < input_tensor_count; ++input_index) { + const auto& input_tensor = ctx.GetInput(input_index); + const auto& tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + + TensorParams ip_tensor{input_tensor.GetTensorRawData(), tensor_info.GetShape()}; + + if (ip_tensor != trt_state->input_tensors[input_index]) { + require_io_binding = true; + trt_state->input_tensors[input_index] = ip_tensor; + } + } + } else { + // If this is not the first run and skip_io_binding_allowed is false, we need to bind input tensors + require_io_binding = true; + } + + if (!require_io_binding) { + // no need to bind inputs, but check outputs as well + auto output_tensor_count = ctx.GetOutputCount(); + + for (size_t output_index = 0; output_index < output_tensor_count; ++output_index) { + const auto& prev_output_tensor = trt_state->output_tensors[output_index]; + + if (prev_output_tensor.dims.nbDims != -1) { + const auto& new_output_tensor = ctx.GetOutput(output_index, prev_output_tensor.dims.d, prev_output_tensor.dims.nbDims); + + // different output tensor data means we need to bind outputs again + if (prev_output_tensor.data != new_output_tensor.GetTensorRawData()) { + require_io_binding = true; + break; + } + } + } + } + + return require_io_binding; +} + +const InlinedVector NvExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_nodes; + if (ep_context_model_) { + for (auto* node : ep_context_model_->MainGraph().Nodes()) { + ep_context_nodes.push_back(node); + } + } + return ep_context_nodes; +} + Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer, const Node& fused_node, std::unordered_map& input_map, @@ -2253,11 +2519,38 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); + // exclude weights if external + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_body_viewer.GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_body_viewer.GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(TensorrtUserWeights(std::move(full_init->name()), std::move(full_init->raw_data()))); + } + } + } + // ORT's default topological sort is using reversed DFS. // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); @@ -2274,7 +2567,21 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + trt_parser->parseModelProto(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + } + if (max_workspace_size_ > 0) { trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); } @@ -2349,21 +2656,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr ShapeRangesMap input_explicit_shape_ranges; ShapeRangesMap input_implicit_shape_ranges; - auto tensor_is_dynamic = [&](nvinfer1::ITensor* tensor) -> bool { - if (tensor->isShapeTensor()) { - return true; - } else { - nvinfer1::Dims dims = tensor->getDimensions(); - // Execution tensor - for (int j = 0, end = dims.nbDims; j < end; ++j) { - if (dims.d[j] == -1) { - return true; - } - } - } - return false; - }; - bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { has_explicit_profile = true; @@ -2375,7 +2667,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } else { for (unsigned int i = 0, end = num_inputs; i < end; ++i) { auto input = trt_network->getInput(i); - has_dynamic_shape |= tensor_is_dynamic(input); + has_dynamic_shape |= checkTrtTensorIsDynamic(input); } if (has_dynamic_shape) { LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] No explicit optimization profile was specified. " @@ -2399,7 +2691,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_opt_shapes_.find(input_name) != profile_opt_shapes_.end() && profile_max_shapes_.find(input_name) != profile_max_shapes_.end(); if (has_explicit_profile && tensor_has_profile) { - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } else { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Creating implicit profile for tensor " << input_name; profile_min_shapes_[input_name] = std::vector>{{}}; @@ -2426,7 +2718,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr profile_max_shapes_[input_name][0][idx_dim] = dim_value; } } - apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges); + apply_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges, cuda_graph_enable_); } if (!apply_profile) { std::ostringstream msg; @@ -2453,7 +2745,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr ; } } - std::string trt_node_name_with_precision = fused_node.Name() + "_strong_typed"; // enable sparse weights if (sparsity_enable_) { @@ -2480,33 +2771,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // // Otherwise engine will be handled at inference time. std::unique_ptr trt_engine; - std::unique_ptr trt_context; - - std::string cache_path = ""; - std::string cache_suffix = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - // Generate cache suffix in case user would like to customize cache prefix - cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision); - cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; - } else { - cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } + tensorrt_ptr::unique_pointer_exec_ctx trt_context; + std::unique_ptr trt_runtime_cache; + std::unique_ptr trt_runtime_config; + std::string runtime_cache_file = ""; // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { @@ -2522,49 +2790,82 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; if (serialized_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to create engine from network for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to deserialize engine for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to deserialize engine for fused node: " + fused_node.Name()); + } + + trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / fused_node.Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } + if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << fused_node.Name() << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } // dump EP context node model if (dump_ep_context_model_) { // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_path = GetCachePath(cache_path_, cache_prefix_) + fused_node.Name() + ".engine"; + ; + } else { + cache_path = GetCachePath(cache_path_, fused_node.Name()) + ".engine"; + ; + } + // NV TRT EP per default generates hardware compatible engines for any RTX device with compute capability > 80 + std::string compute_capability_hw_compat = "80+"; + if (!ep_context_model_) { + ep_context_model_ = Model::Create("nv_trt_rtx_ep_context_model", false, *GetLogger()); + } + + auto status = CreateCtxNode(graph_body_viewer, + ep_context_model_->MainGraph(), + cache_path, + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + fused_node.Name(), + trt_version_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } - std::string compute_capability_hw_compat = compute_capability_ + "+"; - std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, - ep_cache_context_attr_, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - compute_capability_hw_compat, - model_path_, - GetLogger())}; - DumpCtxModel(model_proto.get(), ctx_model_path_); } } if (weight_stripped_engine_refit_) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refit engine from main ONNX file after engine build"; - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, - engine_cache_path, false /* path check for security */, - onnx, - onnx_size, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, trt_engine.get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -2574,31 +2875,20 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySizeV2(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } + bool is_dynamic_shape_context = false; // Create input to index map for (int i = 0; i < num_inputs; ++i) { auto input = trt_network->getInput(i); const std::string& input_name = input->getName(); + is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(input_name.c_str())); const auto& iter = input_map.find(input_name); if (iter != input_map.end()) { input_indexes[input_name] = iter->second; @@ -2618,7 +2908,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); @@ -2634,15 +2923,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], + &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, trt_node_name_with_precision, + input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], - context_memory_sharing_enable_, &max_ctx_mem_size_, engine_decryption_enable_, engine_decryption_, engine_encryption_, detailed_build_log_, sparsity_enable_, - auxiliary_streams_, cuda_graph_enable_, cache_prefix_, cache_suffix}; + auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_}; *state = p.release(); return 0; }; @@ -2666,93 +2954,38 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; auto fused_node_name = trt_state->fused_node_name; - // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. - // The info is used for both shape tensor and execution tensor: - // tensor name->(dimension->[min, max, opt]) - auto& shape_ranges = trt_state->input_shape_ranges; + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_inputs = static_cast(input_indexes.size()); - int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); - - if (multi_profile_enable_ == true) { - if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed"); - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - // Prepare cache name - std::string cache_path = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; } else { - cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision); - } - - // Enable hardware compatility mode if assigned - std::string cache_hw_compat = "_sm" + compute_capability_; - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path + cache_hw_compat; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } - - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); - - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP failed to parse input tensor and generate optimization profiles."); - } + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + stream_ = stream; } - if (weight_stripped_engine_refit_) { - auto status = RefitEngine(model_path_, - onnx_model_folder_path_, - engine_cache_path, - false /* path check for security */, - onnx_model_bytestream_, - onnx_model_bytestream_size_, - trt_engine, - false /* serialize refitted engine to disk */, - detailed_build_log_); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); - } + if (multi_profile_enable_ == true) { + if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP select an optimization profile for the current context failed"); } // Check before using trt_engine @@ -2760,6 +2993,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); } + bool require_io_binding = IsIOBindingRequired(trt_state, ctx); + // Get input and output binding names int total_bindings = trt_engine->getNbIOTensors(); std::vector input_binding_names, output_binding_names; @@ -2776,86 +3011,89 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr /* * Set input shapes and bind input buffers */ - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - char const* input_name = input_binding_names[i]; - - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } - auto input_tensor = ctx.GetInput(input_index); - auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); - const auto tensor_shapes = tensor_info.GetShape(); + auto& scratch_buffers = trt_state->scratch_buffers; + if (require_io_binding) { + scratch_buffers.clear(); + bool skip_input_binding_allowed = true; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } - auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream, skip_input_binding_allowed); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } } + trt_state->skip_io_binding_allowed = skip_input_binding_allowed; } /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); + if (require_io_binding) { + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } + nvinfer1::Dims dims; + void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + + trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } } // Set execution context memory - if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif + if (require_io_binding) { size_t mem_size = trt_engine->getDeviceMemorySizeV2(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; + if (trt_state->is_dynamic_shape) { + mem_size = trt_context->updateDeviceMemorySizeForShapes(); + } + if (trt_state->context_memory_size != mem_size) { + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size; + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); + trt_state->context_memory_size = mem_size; + trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; + + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context enqueue failed."); + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -2872,10 +3110,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -2894,33 +3137,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + auto status = BindKernelOutput(ctx, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } - } - } - - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); } } @@ -2932,12 +3152,13 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, std::vector& node_compute_funcs) { std::unique_ptr trt_engine; - std::unique_ptr trt_context; + tensorrt_ptr::unique_pointer_exec_ctx trt_context; std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index std::unordered_map output_types; // TRT engine output name -> ORT output tensor type @@ -2951,43 +3172,53 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra onnx_model_folder_path_, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, detailed_build_log_); - auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); + auto status = trt_cache_model_handler.GetEpContextFromGraph(*graph_body_viewer.GetNode(node_idx)); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } + std::unique_ptr trt_runtime_cache; + auto trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); + } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + std::string runtime_cache_file = ""; + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / graph_body_viewer.GetNode(node_idx)->Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } + // Build context // // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_memory_sharing_enable_) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif - size_t mem_size = trt_engine->getDeviceMemorySizeV2(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > max_ctx_mem_size_) { - max_ctx_mem_size_ = mem_size; - } - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - - } else { - trt_context = std::unique_ptr(trt_engine->createExecutionContext()); - } + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } + bool is_dynamic_shape_context = false; // Create input/output to index maps for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { auto const& name = trt_engine->getIOTensorName(i); auto const& mode = trt_engine->getTensorIOMode(name); if (mode == nvinfer1::TensorIOMode::kINPUT) { + is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(name)); const auto& iter = input_map.find(name); if (iter != input_map.end()) { input_indexes[name] = iter->second; @@ -3027,9 +3258,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra &contexts_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - context_memory_sharing_enable_, - &max_ctx_mem_size_, - &tensorrt_mu_}; + &tensorrt_mu_, + is_dynamic_shape_context}; *state = p.release(); return 0; }; @@ -3056,28 +3286,35 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); - auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; - int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } OrtAllocator* alloc = alloc_; - void* cuda_stream; - Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); - cudaStream_t stream = static_cast(cuda_stream); + cudaStream_t stream; + if (stream_ != nullptr) { + // Use our existing stream (either user's or our early-created) + stream = stream_; + } else { + // Create stream now (lazy creation case) + void* cuda_stream; + Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream)); + stream = static_cast(cuda_stream); + } // Check before using trt_engine if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "No engine is found."); } + bool require_io_binding = IsIOBindingRequired(trt_state, ctx); + // Get input and output binding names int total_bindings = trt_engine->getNbIOTensors(); std::vector input_binding_names, output_binding_names; @@ -3094,83 +3331,90 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra /* * Set input shapes and bind input buffers */ - std::vector> scratch_buffers; - for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { - char const* input_name = input_binding_names[i]; - - size_t input_index = 0; - const auto iter = input_indexes.find(input_name); - if (iter != input_indexes.end()) { - input_index = iter->second; - } + auto& scratch_buffers = trt_state->scratch_buffers; + if (require_io_binding) { + scratch_buffers.clear(); + bool skip_input_binding_allowed = true; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } - Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, shape_tensor_values_int64, scratch_buffers, alloc, stream, skip_input_binding_allowed); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } } + trt_state->skip_io_binding_allowed = skip_input_binding_allowed; } /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); + if (require_io_binding) { + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; - for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { - char const* output_name = output_binding_names[i]; + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } - size_t output_index = 0; - const auto& index_iter = output_indexes.find(output_name); - if (index_iter != output_indexes.end()) { - output_index = index_iter->second; - } + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } - size_t output_type = 0; - const auto type_iter = output_types.find(output_name); - if (type_iter != output_types.end()) { - output_type = type_iter->second; - } + nvinfer1::Dims dims; + void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); + } + + trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } } // Set execution context memory - if (trt_state->context_memory_sharing_enable) { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif + if (require_io_binding) { size_t mem_size = trt_engine->getDeviceMemorySizeV2(); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif - if (mem_size > *max_context_mem_size_ptr) { - *max_context_mem_size_ptr = mem_size; + if (trt_state->is_dynamic_shape) { + mem_size = trt_context->updateDeviceMemorySizeForShapes(); + } + if (trt_state->context_memory_size != mem_size) { + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size; + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); + // trt_state->context_memory = IAllocator::MakeUniquePtr(alloc, mem_size, false /*use_reserve*/, stream); + trt_state->context_memory_size = mem_size; + trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } - trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } - // Start CUDA graph capture. - // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because - // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - cuda_graph_.SetStream(stream); - CaptureBegin(0); - } + // Start CUDA graph capture with the correct stream + // Note: We need to set the stream and start capture here because this is where we have access to the actual compute stream + // Get the graph annotation ID that was stored during OnRunStart + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCurrentGraphAnnotationId(); + bool graph_replay_on_this_run = false; + bool should_start_capture = false; - // Run TRT inference - if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context enqueue failed."); + HandleCudaGraphStart(stream, require_io_binding, cuda_graph_annotation_id, + graph_replay_on_this_run, should_start_capture); + + if (!graph_replay_on_this_run) { + if (!trt_context->enqueueV3(stream)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); + } + } else { + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); } /* @@ -3187,10 +3431,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. */ + + if (cuda_graph_enable_ && should_start_capture) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id, sync_stream_after_enqueue_)); + } + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } - // Assign TRT output back to ORT output // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output @@ -3209,33 +3458,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + auto status = BindKernelOutput(ctx, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } - } - } - - // End CUDA graph capture. - // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture - // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. - // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { - CaptureEnd(0); - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph(0)); - } else { - IncrementRegularRunCountBeforeGraphCapture(); } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 7a0c47d28c81d..bb8f687db094f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -12,10 +12,11 @@ typedef void* cublasHandle_t; typedef void* cudnnStatus_t; #endif #include "core/providers/nv_tensorrt_rtx/nv_includes.h" - +#include "core/session/onnxruntime_run_options_config_keys.h" #include #include "core/providers/cuda/cuda_graph.h" #include "nv_execution_provider_info.h" +#include "core/providers/nv_tensorrt_rtx/nv_file_utils.h" namespace onnxruntime { @@ -58,6 +59,26 @@ class TensorrtLogger : public nvinfer1::ILogger { }; namespace tensorrt_ptr { +/* + * custom deleter that will dump the optimized runtime cache when the execution context is destructed + */ +struct IExecutionContextDeleter { + IExecutionContextDeleter() = default; + IExecutionContextDeleter(const std::string& runtime_cache_path, std::unique_ptr&& runtime_cache) : runtime_cache_path_(runtime_cache_path), runtime_cache_(std::move(runtime_cache)) {}; + void operator()(nvinfer1::IExecutionContext* context) { + if (context != nullptr) { + if (!runtime_cache_path_.empty()) { + auto serialized_cache_data = std::unique_ptr(runtime_cache_->serialize()); + file_utils::WriteFile(runtime_cache_path_, serialized_cache_data->data(), serialized_cache_data->size()); + } + delete context; + } + } + + private: + std::string runtime_cache_path_; + std::unique_ptr runtime_cache_; +}; struct TensorrtInferDeleter { template @@ -70,6 +91,7 @@ struct TensorrtInferDeleter { template using unique_pointer = std::unique_ptr; +using unique_pointer_exec_ctx = std::unique_ptr; }; // namespace tensorrt_ptr // @@ -78,6 +100,9 @@ using unique_pointer = std::unique_ptr; // class OutputAllocator : public nvinfer1::IOutputAllocator { public: + OutputAllocator() = delete; + OutputAllocator(OrtAllocator* allocator) : alloc_(allocator) {}; + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; @@ -95,10 +120,11 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { } ~OutputAllocator() override { - cudaFree(outputPtr); + alloc_->Free(alloc_, outputPtr); } private: + OrtAllocator* alloc_; void* outputPtr{nullptr}; uint64_t allocated_size = 0; std::vector output_shapes; @@ -110,6 +136,80 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { */ using ShapeRangesMap = std::unordered_map>>>; +/** + * @brief Container for tensor data and their shape. + * + */ +struct TensorParams { + const void* data{nullptr}; + nvinfer1::Dims dims; + + TensorParams() = default; + + TensorParams(const void* data_ptr, const std::vector& shape) { + // Initialize data and dims from the Ort::ConstValue + data = data_ptr; + + dims.nbDims = static_cast(shape.size()); + for (int i = 0; i < dims.nbDims; ++i) { + dims.d[i] = static_cast(shape[i]); + } + } + + TensorParams(const void* data_ptr, nvinfer1::Dims& shape) { + // Initialize data and dims from the Ort::ConstValue + data = data_ptr; + + dims = shape; + } + + bool operator!=(const TensorParams& other) const { + if (data != other.data || dims.nbDims != other.dims.nbDims) + return true; + + for (int i = 0; i < dims.nbDims; ++i) { + if (dims.d[i] != other.dims.d[i]) + return true; + } + return false; + } +}; + +// Data structure to hold user weights when ModelProtos are serialized with external data +class TensorrtUserWeights { + public: + TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), + data_cpy_(data) { + }; + + TensorrtUserWeights(const std::string& name, const void* data, size_t size) : name_(name), data_(data), size_(size) { + }; + + const char* Name() const { + return name_.c_str(); + }; + + const void* Data() const { + if (!data_cpy_.empty()) { + return data_cpy_.data(); + } + return data_; + } + + int64_t Size() const { + if (!data_cpy_.empty()) { + return static_cast(data_cpy_.size()); + } + return static_cast(size_); + } + + private: + std::string name_{}; + std::string data_cpy_{}; + void const* data_; + size_t size_; +}; + // Information to construct kernel function state. struct TensorrtFuncState { AllocateFunc test_allocate_func = nullptr; @@ -117,21 +217,17 @@ struct TensorrtFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; nvinfer1::IBuilder* builder; - tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; - std::string trt_node_name_with_precision; bool engine_cache_enable = false; std::string engine_cache_path; nvinfer1::IRuntime* runtime = nullptr; std::vector profiles; - bool context_memory_sharing_enable = false; - size_t* max_context_mem_size_ptr = nullptr; bool engine_decryption_enable = false; int (*engine_decryption)(const char*, char*, size_t*) = nullptr; int (*engine_encryption)(const char*, char*, size_t) = nullptr; @@ -139,8 +235,17 @@ struct TensorrtFuncState { bool sparsity_enable = false; int auxiliary_streams = -1; bool cuda_graph_enable = 0; + bool is_dynamic_shape = false; std::string cache_prefix; std::string cache_suffix; + // runtime parameters + std::vector> scratch_buffers; + std::vector input_tensors; + std::vector output_tensors; + bool is_first_run = true; // Indicates if this is the first run of the engine + bool skip_io_binding_allowed = false; // Indicates if input/output binding can be skipped + IAllocatorUniquePtr context_memory = nullptr; + size_t context_memory_size = 0; }; // Minimum information to construct kernel function state for direct engine load code path @@ -150,12 +255,19 @@ struct TensorrtShortFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::vector> input_info; std::vector> output_info; - bool context_memory_sharing_enable = false; - size_t* max_context_mem_size_ptr = nullptr; std::mutex* tensorrt_mu_ptr = nullptr; + bool is_dynamic_shape = false; + // runtime parameters + std::vector> scratch_buffers; + std::vector input_tensors; + std::vector output_tensors; + bool is_first_run = true; // Indicates if this is the first run of the engine + bool skip_io_binding_allowed = false; // Indicates if input/output binding can be skipped + IAllocatorUniquePtr context_memory = nullptr; + size_t context_memory_size = 0; }; // Holds important information for building valid ORT graph. @@ -195,6 +307,7 @@ class NvExecutionProvider : public IExecutionProvider { IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } + Status Sync() const; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; @@ -214,20 +327,24 @@ class NvExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; + // CUDA Graph support bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + void HandleCudaGraphStart(cudaStream_t stream, bool require_io_binding, CudaGraphAnnotation_t cuda_graph_annotation_id, bool& graph_replay_on_this_run, bool& should_start_capture); static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log); + const InlinedVector GetEpContextNodes() const override; + private: mutable NvExecutionProviderInfo info_; bool external_stream_ = false; @@ -244,6 +361,9 @@ class NvExecutionProvider : public IExecutionProvider { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + bool use_external_data_initializer_ = false; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; bool sparsity_enable_ = false; int auxiliary_streams_ = -1; std::string cache_path_, engine_decryption_lib_path_; @@ -251,9 +371,7 @@ class NvExecutionProvider : public IExecutionProvider { std::mutex tensorrt_mu_; int device_id_; std::string compute_capability_; - bool context_memory_sharing_enable_ = false; size_t max_ctx_mem_size_ = 0; - IAllocatorUniquePtr context_memory_ = nullptr; mutable char model_path_[4096] = {}; // Reserved for max path length bool engine_decryption_enable_ = false; int (*engine_decryption_)(const char*, char*, size_t*) = nullptr; @@ -261,9 +379,11 @@ class NvExecutionProvider : public IExecutionProvider { bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; bool multi_profile_enable_ = false; + std::filesystem::path runtime_cache_; std::string cache_prefix_; std::string op_types_to_exclude_; int nv_profile_index_ = 0; + std::unique_ptr ep_context_model_; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; @@ -278,7 +398,6 @@ class NvExecutionProvider : public IExecutionProvider { std::string ep_context_file_path_; int ep_context_embed_mode_ = 0; std::string ctx_model_path_; - std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; @@ -290,9 +409,8 @@ class NvExecutionProvider : public IExecutionProvider { // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. - std::unordered_map> parsers_; std::unordered_map> engines_; - std::unordered_map> contexts_; + std::unordered_map contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -311,15 +429,6 @@ class NvExecutionProvider : public IExecutionProvider { // Call cudaStreamSynchronize() after TRT enqueueV3() mutable bool sync_stream_after_enqueue_ = true; - CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; - // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: - // (1) memory pattern is enabled. (2) arena allocation for stream. - // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs - // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. - // [Note] We don't use PerThreadContext for now since it has issue with multithreading // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. @@ -339,19 +448,23 @@ class NvExecutionProvider : public IExecutionProvider { bool IsTensorRTContextInMap(std::string fused_node); nvinfer1::IExecutionContext& GetTensorRTContext(std::string fused_node); - bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context); + bool UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context); void ResetTensorRTContext(std::string fused_node); - bool CompareProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges); - void UpdateProfileShapes(std::string fused_node, ShapeRangesMap& shape_ranges); - - void InitCUDAGraph(); - void SetGraphStream(cudaStream_t stream); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - bool IsGraphCaptured(int graph_annotation_id) const; - Status ReplayGraph(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); + + // CUDA Graph management + void SetCudaGraphStream(cudaStream_t stream) { cuda_graph_.SetStream(stream); } + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + void SetCurrentGraphAnnotationId(CudaGraphAnnotation_t cuda_graph_annotation_id); + CudaGraphAnnotation_t GetCurrentGraphAnnotationId() const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag); + void IncrementRegularRunCountBeforeGraphCapture(CudaGraphAnnotation_t cuda_graph_annotation_id); + void ResetWarmupRuns(CudaGraphAnnotation_t cuda_graph_annotation_id); + void DeleteCapturedGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); private: cudnnHandle_t external_cudnn_handle_ = nullptr; @@ -365,7 +478,7 @@ class NvExecutionProvider : public IExecutionProvider { // See more details here: // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36 - std::unordered_map> trt_context_map_; + std::unordered_map trt_context_map_; // The profile shape ranges for the engine that the execution context maintained by the PerThreadContext is built with. // TRT EP needs this info to determine whether to rebuild the execution context. @@ -374,13 +487,18 @@ class NvExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ is put under PerThreadContext. // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph instance is enough (no need to maintain one CUDAGraph instance per TRT subgraph) CUDAGraph cuda_graph_; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; + // Current graph annotation ID for this run + CudaGraphAnnotation_t current_graph_annotation_id_ = kCudaGraphAnnotationDefault; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. + // https://github.com/NVIDIA/TensorRT/blob/main/samples/common/sampleInference.cpp#L1258-L1291 Based on the trtexec code }; using PerThreadContextMap = std::unordered_map>; @@ -499,6 +617,7 @@ class NvExecutionProvider : public IExecutionProvider { * going through the time-consuming processes of model parsing and engine building. */ Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, @@ -513,11 +632,6 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map& output_map, std::vector& node_compute_funcs); - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); - /** * Get the pointer to the IBuilder instance. * This function only creates the instance at the first time it's being called." diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f90bf24ef4975..f25718114891b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -17,6 +17,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi NvExecutionProviderInfo info{}; void* user_compute_stream = nullptr; void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -48,21 +49,15 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer) .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) - .AddValueParser( - nv::provider_option_names::kONNXBytestream, - [&onnx_bytestream](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - onnx_bytestream = reinterpret_cast(address); - return Status::OK(); - }) - .AddAssignmentToReference(nv::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) + .AddAssignmentToReference(nv::provider_option_names::kRuntimeCacheFile, info.runtime_cache_path) .Parse(options)); // add new provider option here. info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; // EP context settings // when EP context is enabled, default is to embed the engine in the context model @@ -73,7 +68,8 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.dump_ep_context_model = false; } else if (ep_context_enable == "1") { info.dump_ep_context_model = true; - info.weight_stripped_engine_enable = true; + // We want to reenable weightless engines as soon constant initializers are supported as inputs + info.weight_stripped_engine_enable = false; } else { ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1"); } @@ -110,9 +106,8 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, - {nv::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, - {nv::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, - }; + {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}, + {nv::provider_option_names::kRuntimeCacheFile, MakeStringWithClassicLocale(info.runtime_cache_path)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 2a67f3c3bec4d..372e8196f38c2 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -31,11 +31,13 @@ struct NvExecutionProviderInfo { std::string onnx_model_folder_path{""}; const void* onnx_bytestream{nullptr}; size_t onnx_bytestream_size{0}; + bool use_external_data_initializer{false}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; - bool context_memory_sharing_enable{false}; - std::string timing_cache_path{""}; + std::string runtime_cache_path{""}; bool detailed_build_log{false}; bool sparsity_enable{false}; int auxiliary_streams{-1}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index 22e5eea6924de..c564fe65c3d5c 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -386,22 +386,11 @@ std::string GetCachePath(const std::string& root, const std::string& name) { * Get compute capability * */ -std::string GetComputeCapacity(const cudaDeviceProp& prop) { +std::string GetComputeCapability(const cudaDeviceProp& prop) { const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); return compute_capability; } -/* - * Get Timing by compute capability - * - */ -std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { - // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache - const std::string timing_cache_name = "NvExecutionProvider_cache_sm" + - compute_cap + ".timing"; - return GetCachePath(root, timing_cache_name); -} - /* * Get cache by type * @@ -683,4 +672,29 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string } return ""; } + +/* + * Checks if there is a an element with value `-1` in nvinfer1::Dims + */ +static bool checkTrtDimIsDynamic(nvinfer1::Dims dims) { + for (int j = 0, end = dims.nbDims; j < end; ++j) { + if (dims.d[j] == -1) { + return true; + } + } + return false; +} + +/* + * Checks if an nvinfer1::ITensor signales a dynamic shape, + * either due to dynamic shapes or due to it being a shape tensor + */ +static bool checkTrtTensorIsDynamic(nvinfer1::ITensor* tensor) { + if (tensor->isShapeTensor()) { + return true; + } else { + // Execution tensor + return checkTrtDimIsDynamic(tensor->getDimensions()); + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h new file mode 100644 index 0000000000000..159aba0507ffb --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h @@ -0,0 +1,52 @@ +#pragma once +#include +#include +#include +#include +#include +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace file_utils { + +inline std::vector ReadFile(const std::string& path) { + if (!std::filesystem::exists(path)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX could not find the file and will create a new one " << path << std::endl; + return {}; + } + std::ifstream file(path, std::ios::in | std::ios::binary); + if (!file) { + ORT_THROW("Failed to open file: " + path); + } + file.seekg(0, std::ios::end); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector buffer(size); + if (size > 0 && !file.read(buffer.data(), size)) { + ORT_THROW("Failed to read file: " + path); + } + return buffer; +} + +inline void WriteFile(const std::string& path, const void* data, size_t size) { + if (std::filesystem::exists(path)) { + std::ofstream file(path, std::ios::out | std::ios::binary | std::ios::trunc); + if (!file) { + ORT_THROW("Failed to open file for writing: " + path); + } + file.write(static_cast(data), size); + } else { + LOGS_DEFAULT(INFO) << "TensorRT RTX a new file cache was written to " << path << std::endl; + // Create new file + std::ofstream file(path, std::ios::out | std::ios::binary); + if (!file) { + ORT_THROW("Failed to create file: " + path); + } + file.write(static_cast(data), size); + } +} + +inline void WriteFile(const std::string& path, const std::vector& data) { WriteFile(path, data.data(), data.size()); } + +} // namespace file_utils +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e236cccaaaa77..c3fbccef84883 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -431,7 +431,7 @@ struct NvTrtRtxSyncNotificationImpl : OrtSyncNotificationImpl { Release = ReleaseImpl; } - cudaStream_t& stream_; + cudaStream_t stream_; cudaEvent_t event_; const OrtApi& ort_api; @@ -477,9 +477,9 @@ struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { *notification_impl = nullptr; std::unique_ptr notification; - cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + cudaStream_t cuda_stream = static_cast(impl.stream_.GetHandle()); - RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(cuda_stream, impl.ort_api, notification)); *notification_impl = notification.release(); return nullptr; @@ -557,6 +557,67 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return ORT_VERSION; } + /** + * @brief Checks if a given OrtHardwareDevice is a supported NVIDIA GPU. + * + * This function verifies if the provided hardware device corresponds to a physical + * NVIDIA GPU that meets the minimum compute capability requirements for this execution provider. + * + * The check is performed by: + * 1. Extracting the LUID (Locally Unique Identifier) from the device's metadata. + * 2. Converting the string LUID to a 64-bit integer. + * 3. Iterating through all available CUDA devices on the system. + * 4. For each CUDA device, constructing its 64-bit LUID from its properties. + * 5. Comparing the LUIDs. If a match is found, it checks if the device's + * compute capability is at least 8.0 (Ampere) or newer. + * + * @param device The OrtHardwareDevice to check. + * @return True if the device is a supported NVIDIA GPU, false otherwise. + */ + bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) { + const auto& metadata_entries = device.metadata.Entries(); + const auto it = metadata_entries.find("LUID"); + if (it == metadata_entries.end()) { + return false; + } + + uint64_t target_luid; + try { + target_luid = std::stoull(it->second); + } catch (const std::exception&) { + return false; + } + + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + return false; + } + + for (int i = 0; i < device_count; ++i) { + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, i) != cudaSuccess) { + continue; + } + + // The LUID is an 8-byte value, valid on Windows when luidDeviceNodeMask is non-zero. + // We reconstruct the 64-bit integer representation from the raw bytes. + if (prop.luidDeviceNodeMask == 0) { + continue; + } + + // Ensure the LUID is 8 bytes and reinterpret it directly as a uint64_t for comparison. + static_assert(sizeof(prop.luid) == sizeof(uint64_t), "cudaDeviceProp::luid should be 8 bytes"); + uint64_t current_luid = *reinterpret_cast(prop.luid); + + if (current_luid == target_luid) { + // Ampere architecture or newer is required. + return prop.major >= 8; + } + } + + return false; + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among @@ -579,11 +640,12 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id && + factory->IsOrtHardwareDeviceSupported(device)) { OrtKeyValuePairs* ep_options = nullptr; OrtKeyValuePairs* ep_metadata = nullptr; - factory->ort_api.CreateKeyValuePairs(&ep_options); factory->ort_api.CreateKeyValuePairs(&ep_metadata); factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 21d964b0c341f..c1626fa4f36ad 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -20,10 +20,11 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); * * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc */ -bool GraphHasCtxNode(const GraphViewer& graph_viewer) { +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx) { for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) { auto node = graph_viewer.GetNode(i); if (node != nullptr && node->OpType() == EPCONTEXT_OP) { + node_idx = i; return true; } } @@ -63,19 +64,18 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, } /* - * Create "EP context node" model where engine information is embedded + * Create EP context node where engine information is embedded */ -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger) { - auto model_build = graph_viewer.CreateModel(*logger); - auto& graph_build = model_build->MainGraph(); - +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int32_t trt_version) { // Get graph inputs and outputs std::vector inputs, outputs; for (auto input : graph_viewer.GetInputs()) { @@ -89,55 +89,71 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, } // Create EP context node attributes - auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode - auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context - auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture - auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename + auto attr_embed_mode = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_main_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_ep_cache_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_sdk_version = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_hw_architecture = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_onnx_filename = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_partition_name = ONNX_NAMESPACE::AttributeProto::Create(); std::string engine_data_str = ""; - attr_0->set_name(EMBED_MODE); - attr_0->set_type(onnx::AttributeProto_AttributeType_INT); - attr_0->set_i(embed_mode); - attr_1->set_name(EP_CACHE_CONTEXT); - attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_main_context->set_name(MAIN_CONTEXT); + attr_main_context->set_type(onnx::AttributeProto_AttributeType_INT); + attr_main_context->set_i(0); // we do not support a main context node but each has it's own engine payload + attr_embed_mode->set_name(EMBED_MODE); + attr_embed_mode->set_type(onnx::AttributeProto_AttributeType_INT); + attr_embed_mode->set_i(embed_mode); + attr_ep_cache_context->set_name(EP_CACHE_CONTEXT); + attr_ep_cache_context->set_type(onnx::AttributeProto_AttributeType_STRING); if (embed_mode) { if (size > 0) { engine_data_str.assign(engine_data, size); } - attr_1->set_s(engine_data_str); - // TODO(maximilianm) we might want to disable this warning as we only support weightless engines that are really small - // the reason we had this was that the field will be hashed and storing a large bytestream has significant overhead - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + attr_ep_cache_context->set_s(engine_data_str); } else { - attr_1->set_s(engine_cache_path); + std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + attr_ep_cache_context->set_s(engine_cache_filename); + std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); + if (engine_cache_file.is_open()) { + engine_cache_file.write(engine_data, size); + engine_cache_file.close(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP could not write cache to ", engine_cache_path); + } } - attr_2->set_name(COMPUTE_CAPABILITY); - attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_2->set_s(compute_capability); - attr_3->set_name(ONNX_MODEL_FILENAME); - attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_hw_architecture->set_name(COMPUTE_CAPABILITY); + attr_hw_architecture->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_hw_architecture->set_s(compute_capability); + + attr_partition_name->set_name(PARTITION_NAME); + attr_partition_name->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_partition_name->set_s(ep_context_node_name); // includes hash of the subgraph that was built + + attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); + attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_sdk_version->set_name(SDK_VERSION); + attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_sdk_version->set_s(std::to_string(trt_version)); auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); constexpr int num_attributes = 4; node_attributes->reserve(num_attributes); - node_attributes->emplace(EMBED_MODE, *attr_0); - node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); - node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); - node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3); + node_attributes->emplace(MAIN_CONTEXT, *attr_main_context); + node_attributes->emplace(EMBED_MODE, *attr_embed_mode); + node_attributes->emplace(EP_CACHE_CONTEXT, *attr_ep_cache_context); + node_attributes->emplace(COMPUTE_CAPABILITY, *attr_hw_architecture); + node_attributes->emplace(PARTITION_NAME, *attr_partition_name); + node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_onnx_filename); + node_attributes->emplace(SDK_VERSION, *attr_sdk_version); // Create EP context node - graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); + graph_build.AddNode(ep_context_node_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); ORT_ENFORCE(graph_build.Resolve().IsOK()); - - // Serialize modelproto to string - auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); - auto model = new_graph_viewer->CreateModel(*logger, metadata); - auto model_proto = model->ToProto(); - new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - return model_proto.release(); + return Status::OK(); } /* @@ -206,17 +222,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path, return ctx_model_path; } -/* - * Dump "EP context" model - * - */ -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path) { - std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Dumped " + ctx_model_path; -} - bool IsAbsolutePath(const std::string& path_string) { #ifdef _WIN32 onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); @@ -248,38 +253,12 @@ bool IsRelativePathToParentPath(const std::string& path_string) { #endif } -/* - * Get the weight-refitted engine cache path from a weight-stripped engine cache path - * - * Weight-stipped engine: - * An engine with weights stripped and its size is smaller than a regualr engine. - * The cache name of weight-stripped engine is NvExecutionProvider_TRTKernel_XXXXX.stripped.engine - * - * Weight-refitted engine: - * An engine that its weights have been refitted and it's simply a regular engine. - * The cache name of weight-refitted engine is NvExecutionProvider_TRTKernel_XXXXX.engine - */ -std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { - std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); - std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; - return refitted_engine_cache_path; -} - -bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { - // The weight-stripped engine cache has the naming of xxx.stripped.engine - return engine_cache_path.stem().extension().string() == ".stripped"; -} - -Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { - if (!ValidateEPCtxNode(graph_viewer)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); - } - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) { + auto& attrs = node.GetAttributes(); const int64_t embed_mode = attrs.at(EMBED_MODE).i(); // Only make path checks if model not provided as byte buffer - bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + bool make_secure_path_checks = ep_context_model_path_.empty(); if (embed_mode) { // Get engine from byte stream. @@ -294,15 +273,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph if (weight_stripped_engine_refit_) { const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); - std::string placeholder; auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - placeholder, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -327,34 +305,25 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph auto engine_cache_path = ctx_model_dir.append(cache_path); LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); - // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled - if (!weight_stripped_engine_refit_) { - weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); - } - - // If the serialized refitted engine is present, use it directly without refitting the engine again - if (weight_stripped_engine_refit_) { - const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); - if (std::filesystem::exists(refitted_engine_cache_path)) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " + refitted_engine_cache_path.string() + " exists."; - engine_cache_path = refitted_engine_cache_path.string(); - weight_stripped_engine_refit_ = false; - } - } - if (!std::filesystem::exists(engine_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP can't find engine cache: " + engine_cache_path.string() + ". Please make sure engine cache is in the same directory or sub-directory of context model."); } - std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + size_t file_length = 0; + auto path_str = ToPathString(engine_cache_path.string()); + + Env::MappedMemoryPtr engine_buf; + const auto& env = GetDefaultEnv(); + ORT_RETURN_IF_ERROR(env.GetFileLength(path_str.c_str(), file_length)); + if (!file_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "Nv EP could not read engine from cache: " + engine_cache_path.string()); + } + ORT_RETURN_IF_ERROR(env.MapFileIntoMemory(path_str.c_str(), 0, file_length, engine_buf)); + + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), file_length)); if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP could not deserialize engine from cache: " + engine_cache_path.string()); @@ -366,12 +335,12 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph std::string weight_stripped_engine_cache = engine_cache_path.string(); auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - weight_stripped_engine_cache, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - true /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -384,11 +353,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph /* * The sanity check for EP context contrib op. */ -bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) { - assert(graph_viewer.NumberOfNodes() == 1); - assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP); - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +bool TensorRTCacheModelHandler::ValidateEPCtxNode(const Node& node) { + auto& attrs = node.GetAttributes(); // Show the warning if compute capability is not matched if (attrs.count(COMPUTE_CAPABILITY) > 0) { @@ -413,7 +379,7 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe const int64_t embed_mode = attrs.at(EMBED_MODE).i(); if (embed_mode == 1) { // engine binary data - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } return true; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index f0a05c42414e5..7c52f26cc9177 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/providers/nv_tensorrt_rtx/nv_includes.h" #include "core/providers/shared_library/provider_api.h" @@ -14,33 +15,32 @@ namespace onnxruntime { static const std::string EPCONTEXT_OP = "EPContext"; +static const std::string MAIN_CONTEXT = "main_context"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; +static const std::string PARTITION_NAME = "partition_name"; +static const std::string SDK_VERSION = "ep_sdk_version"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; -static const std::string EPCONTEXT_WARNING = - "It's suggested to set the ORT graph optimization level to 0 and \ - make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ - for the best model loading time"; -bool GraphHasCtxNode(const GraphViewer& graph_viewer); +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx); const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger); +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int trt_version); std::string GetCtxModelPath(const std::string& ep_context_file_path, const std::string& original_model_path); bool IsAbsolutePath(const std::string& path_string); bool IsRelativePathToParentPath(const std::string& path_string); -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, char* engine_data, size_t size); @@ -55,6 +55,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, bool detailed_build_log) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), @@ -64,13 +66,15 @@ class TensorRTCacheModelHandler { onnx_model_folder_path_(onnx_model_folder_path), onnx_model_bytestream_(onnx_model_bytestream), onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); - bool ValidateEPCtxNode(const GraphViewer& graph_viewer); + bool ValidateEPCtxNode(const Node& node); - Status GetEpContextFromGraph(const GraphViewer& graph_viewer); + Status GetEpContextFromGraph(const Node& node); private: std::unique_ptr* trt_engine_; @@ -81,6 +85,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be59b1ae07020..68d15bdfdcee0 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -90,7 +90,12 @@ BackendManager::BackendManager(SessionContext& session_context, "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; ORT_THROW(exception_str); } - model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.onnx_model_path_name.replace_extension("xml").string(), subgraph); + } else { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + } + } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); } @@ -236,7 +241,9 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::ofstream blob_file(blob_filename, std::ios::out | std::ios::trunc | std::ios::binary); if (!blob_file) { - ORT_THROW("Unable to open file for epctx model dump."); + std::ostringstream err_msg; + err_msg << "Unable to open file for epctx model dump: " << blob_filename; + ORT_THROW(err_msg.str()); } compiled_model.export_model(blob_file); model_blob_str = blob_filename.filename().string(); @@ -375,6 +382,56 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } +static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) { + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (std::size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + for (auto& output : node->OutputDefs()) { + if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + return true; + } + } + return false; +} + +static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) { + const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr; + return type_proto && type_proto->has_tensor_type() && + (type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16 || + type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16); +} + +// Check to see if the graph has Q/DQ nodes with int16 or uint16 quantization +static bool IsQDQGraphWithUint16OrInt16(const onnxruntime::GraphViewer& graph_viewer) { + std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + for (size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + + if (qdq_ops.find(node->OpType()) != qdq_ops.end()) { + const auto& input_defs = node->InputDefs(); + + if (node->OpType() == "DequantizeLinear") { + // DequantizeLinear: [quantized_input, scale, zero_point] -> [float_output] + // Check quantized input tensor and optional zero point + if (Is16BitTensor(input_defs.empty() ? nullptr : input_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } else if (node->OpType() == "QuantizeLinear") { + // QuantizeLinear: [float_input, scale, zero_point] -> [quantized_output] + const auto& output_defs = node->OutputDefs(); + if (Is16BitTensor(output_defs.empty() ? nullptr : output_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } + } + } + return false; +} + static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { @@ -433,6 +490,10 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, } #endif + // Check if the graph is QDQ and has int16 or uint16 quantization + // If so, we will apply the QDQ scales fix transformation (for GPU device only) + bool is_qdq_graph_uint16_or_int16 = IsQDQGraphWithUint16OrInt16(subgraph); + const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU if ((session_context_.device_type.find("NPU") != std::string::npos) && @@ -446,7 +507,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; } else if ((session_context_.device_type.find("GPU") != std::string::npos) && - enable_ovep_qdq_optimizer) { + is_qdq_graph_uint16_or_int16) { // Create a copy of the model std::unique_ptr model; Status status = qdq_scales_fix::Transform(subgraph, logger, model); @@ -456,6 +517,16 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; + } else if (IsModelBF16(subgraph)) { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled"; + std::unique_ptr model; + Status status = bfloat16_fix::Transform(subgraph, logger, model); + auto model_proto = model->ToProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + print_model_proto_duration(); + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 73fbe9a0fa76f..7027861f0c4dc 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -150,6 +150,11 @@ CreateOVModel(std::string&& model, LOGS_DEFAULT(INFO) << log_tag << "Reshaping the ov tensor to specified shape"; ov_model->reshape(session_context.reshape); } + + if (!session_context.layout.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Setting the ov tensor layout to specified layout"; + ov_model = Set_Layout(ov_model, session_context.layout); + } // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; @@ -199,6 +204,41 @@ GetOutputTensor(Ort::KernelContext& context, return context.GetOutput(index, output_shape); } +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout) { + ov::preprocess::PrePostProcessor preproc(ov_model); + + const auto& inputs = ov_model->inputs(); + const auto& outputs = ov_model->outputs(); + + auto find_tensor_index = [](const std::vector>& tensors, const std::string& name) -> std::optional { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto& tensor = tensors[i]; + if (tensor.get_any_name() == name || tensor.get_tensor().get_names().count(name) > 0) { + return i; + } + } + return std::nullopt; + }; + + for (const auto& [tensor_name, layout_value] : layout) { + bool tensor_found = false; + + if (auto input_idx = find_tensor_index(inputs, tensor_name)) { + preproc.input(*input_idx).tensor().set_layout(layout_value); + tensor_found = true; + } else if (auto output_idx = find_tensor_index(outputs, tensor_name)) { + preproc.output(*output_idx).tensor().set_layout(layout_value); + tensor_found = true; + } + + if (!tensor_found) { + LOGS_DEFAULT(WARNING) << "Tensor '" << tensor_name << "' not found in model inputs or outputs"; + } + } + + return preproc.build(); +} + int GetFirstAvailableDevice(SessionContext& session_context) { int i = 0; // Get the first available VAD-M device and set the device to busy diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 15145df651fa2..27f791c7a5bd1 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -79,6 +79,8 @@ int GetFirstAvailableDevice(SessionContext& session_context); void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout); + template void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8b7309e6a5a98..2f174110dd31b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -59,7 +59,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr }; // If the EPContext node with OVIR Encapsulation, then create // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() - exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, + exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream->stream_, hw_target, device_config, enable_causallm, @@ -98,6 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr !subgraph_context_.has_dynamic_input_shape && !session_context_.so_context_enable && session_context_.reshape.empty() && + session_context_.layout.empty() && !enable_causallm && !eligible_for_cpu_fallback && auto_unified_compile); @@ -213,122 +214,29 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!session_context_.load_config.empty()) { const std::map& target_config = session_context_.load_config; - if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) { - if (target_config.find("NPU") != target_config.end()) { - auto npu_genai_config = target_config.at("NPU"); - CausalLMConfig().ApplyConfig(npu_genai_config, device_config); - } else { - LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found."; - } - } + // Extract device names from device string and apply their configs + // Examples: "GPU" -> ["GPU"], "AUTO:GPU.0,CPU" -> ["AUTO", "GPU", "CPU"] + auto apply_device_config = [&](std::string_view device) { + if (device.empty()) return; - if (session_context_.device_type.find("NPU") != std::string::npos) { - auto npuw_config = target_config.at("NPU"); - - // Check if "NPU_USE_NPUW" exists and is set to "YES" - auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW"); - if (npu_use_npuw_it != npuw_config.end() && - npu_use_npuw_it->second.is() && - npu_use_npuw_it->second.as() == "YES") { - // Only add NPUW-related keys if NPU_USE_NPUW is "YES" - for (const auto& [key, value] : npuw_config) { - if (key.find("NPUW") != std::string::npos) { - if (!value.is()) { - LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key; - continue; - } - device_config[key] = value; - } - } - } else { - // Check if there are any "NPUW" keys and log a warning - if (std::any_of(npuw_config.begin(), npuw_config.end(), - [&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) { - LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'."; - } - } - } - auto find_device_type_mode = [&](const std::string& device_type) -> std::string { - std::string device_mode = ""; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(0, delimiter_pos)); - std::getline(str_stream, device_mode, ','); - } - return device_mode; - }; + // Remove device index: "GPU.0" -> "GPU" + auto base_device = device.substr(0, device.find('.')); - // Parse device types like "AUTO:CPU,GPU" and extract individual devices - auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { - std::vector devices; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(delimiter_pos + 1)); - std::string device; - while (std::getline(str_stream, device, ',')) { - devices.emplace_back(device); + if (auto config_it = target_config.find(std::string(base_device)); config_it != target_config.end()) { + for (const auto& [key, value] : config_it->second) { + device_config[key] = value; } - } else { - devices.emplace_back(device_type); } - return devices; }; - // Check if a property is supported and mutable - auto is_supported_and_mutable = [&](const std::string& key, - const std::vector& supported_config) -> bool { - auto it = std::find_if(supported_config.begin(), supported_config.end(), [&](const ov::PropertyName& property) { - return property == key && property.is_mutable(); - }); - return it != supported_config.end(); - }; - - // Set properties if they are valid, else log a warning if the property is missing or immutable by skipping the same - auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, - const std::vector& supported_properties) { - for (const auto& [key, value] : config_options) { - if ((key.find("NPUW") != std::string::npos) || - ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { - continue; - } - if (is_supported_and_mutable(key, supported_properties)) { - OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); - } else { - LOGS_DEFAULT(WARNING) << "WARNING: Property \"" << key - << "\" is either unsupported in current OpenVINO version" - << " or property is immutable for target device \"" - << device << "\". Skipping setting this property."; - } - } - }; - - // Check if the device type is AUTO, HETERO, or MULTI - if (session_context_.device_type.find("AUTO") == 0 || - session_context_.device_type.find("HETERO") == 0 || - session_context_.device_type.find("MULTI") == 0) { - //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") - std::unordered_set supported_mode = {"AUTO", "HETERO", "MULTI"}; - auto device_mode = find_device_type_mode(session_context_.device_type); - ORT_ENFORCE(supported_mode.find(device_mode) != supported_mode.end(), " Invalid device mode is passed : ", session_context_.device_type); - // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) - auto individual_devices = parse_individual_devices(session_context_.device_type); - if (!device_mode.empty()) individual_devices.emplace_back(device_mode); - - // Set properties only for individual devices (e.g., "CPU", "GPU") - for (const std::string& device : individual_devices) { - if (target_config.count(device)) { - // Get supported properties for each individual device - auto device_properties = OVCore::Get()->core.get_property(device, ov::supported_properties); - // Set properties for the device - set_target_properties(device, target_config.at(device), device_properties); + // Parse device string by splitting on ':' and ',' delimiters + const auto& device_str = session_context_.device_type; + for (size_t start = 0, pos = 0; pos <= device_str.size(); ++pos) { + if (pos == device_str.size() || device_str[pos] == ':' || device_str[pos] == ',') { + if (pos > start) { + apply_device_config(std::string_view(device_str).substr(start, pos - start)); } - } - } else { - if (target_config.count(session_context_.device_type)) { - auto supported_properties = OVCore::Get()->core.get_property(session_context_.device_type, - ov::supported_properties); - set_target_properties(session_context_.device_type, - target_config.at(session_context_.device_type), supported_properties); + start = pos + 1; } } } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 6a2b375d733f9..07b09899ac214 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -70,6 +70,7 @@ class SharedContext : public WeakSingleton { using config_t = std::map; using reshape_t = std::map; +using layout_t = std::map; struct ProviderInfo { std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and @@ -88,6 +89,7 @@ struct ProviderInfo { // (GPU) feature. If blob files are already present, // it will be directly loaded. reshape_t reshape{}; // Used for reshaping the ov input tensor shape at runtime. + layout_t layout{}; // Used for specifying the ov input/output tensor layout at runtime. std::string model_priority{"DEFAULT"}; // High-level OpenVINO model priority hint // Defines what model should be provided with more performant // bounded resource first @@ -110,7 +112,7 @@ struct ProviderInfo { const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input", "layout"}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index ec38425f602eb..365a4625815d6 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -19,7 +19,7 @@ class IBackend { virtual ~IBackend() = default; virtual void RewindKVCache(size_t index) {} }; -using ptr_stream_t = std::unique_ptr; +using ptr_stream_t = std::unique_ptr; class BackendFactory { public: static std::shared_ptr diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 9e70756a254aa..051a39bd4f205 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -100,7 +100,8 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::unique_ptr +EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -113,10 +114,11 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); std::unique_ptr result; + std::filesystem::path blob_filepath{}; if (embed_mode) { result.reset((std::istream*)new std::istringstream(ep_cache_context)); } else { - auto blob_filepath = so_context_file_path; + blob_filepath = so_context_file_path; if (blob_filepath.empty() && !graph_viewer.ModelPath().empty()) { blob_filepath = graph_viewer.ModelPath(); } @@ -126,16 +128,18 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy } bool isXML = backend_utils::IsModelStreamXML(*result); + std::filesystem::path native_blob_path{}; if (!isXML) { // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. + native_blob_path = std::move(blob_filepath); ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; - return result; + return std::make_unique(std::move(result), native_blob_path); } bool EPCtxHandler::CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index b9ddb40a7a233..f207f5014ca1f 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -12,6 +12,12 @@ namespace onnxruntime { namespace openvino_ep { +struct ModelBlobWrapper { + ModelBlobWrapper(std::unique_ptr stream, const std::filesystem::path& native_blob_path) : stream_(std::move(stream)), maybe_native_blob_path_(native_blob_path) {} + std::unique_ptr stream_; + std::filesystem::path maybe_native_blob_path_; +}; + // Utilities to handle EPContext node export and parsing of an EPContext node // to create the compiled_model object to infer on static const char EPCONTEXT_OP[] = "EPContext"; @@ -31,7 +37,7 @@ class EPCtxHandler { const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; InlinedVector GetEPCtxNodes() const; bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 1b19517b07363..a0fa885cbfc38 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -94,18 +94,23 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); + bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); session_context_.onnx_opset_version = graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); + + // OVIR wrapped in epctx should be treated as source but this code does not + // This corner case is not in use and will be addressed in a future commit + is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } // The block below is executed during EP context model inference auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory if (session_context_.so_share_ep_contexts && - !session_context_.so_context_enable && + is_epctx_model && metadata.empty()) { fs::path context_model_file_path = session_context_.so_context_file_path; if (context_model_file_path.empty()) { diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index 21fc7f935da23..a290fea73e0e8 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -236,5 +236,79 @@ ov::Dimension OpenVINOParserUtils::ParseDimensionRange(const std::string& range_ return ov::Dimension(range_start, range_end); } +layout_t OpenVINOParserUtils::ParseLayout(const std::string& layout_definition) { + layout_t parsed_layout_map; + + // Return empty map for empty input + if (layout_definition.empty()) { + ORT_THROW("Empty layout definition provided in layout parameter"); + } + + // Regular expression for parsing layout definitions + const std::regex layout_pattern(R"(([^\[\],]+)\s*\[(.*?)\])"); // e.g. "input_1[NC],data[CHW]" + + // Find all tensor layout definitions using regex + auto layout_begin = std::sregex_iterator( + layout_definition.begin(), + layout_definition.end(), + layout_pattern); + auto layout_end = std::sregex_iterator(); + + // If no matches found, throw error + if (layout_begin == layout_end) { + ORT_THROW("Invalid layout definition format: " + layout_definition); + } + + // Process each tensor definition + for (std::sregex_iterator i = std::move(layout_begin); i != layout_end; ++i) { + std::smatch layout_match = *i; + + // Extract tensor name and trim whitespace + std::string tensor_name = layout_match[1].str(); // Group 1: tensor name e.g. "input_1" + tensor_name = TrimWhitespace(tensor_name); + + if (tensor_name.empty()) { + ORT_THROW("Empty tensor name provided in layout parameter"); + } + + // Extract dimensions string + std::string dimensions_str = layout_match[2].str(); // Group 2: dimensions string [e.g. "NC", "CHW"] + + if (!Check_Valid_Layout(dimensions_str, tensor_name)) { + ORT_THROW("Invalid dimensions string provided in layout parameter"); + } + + // Store parsed shape in result map + parsed_layout_map[tensor_name] = ov::Layout(dimensions_str); + } + + return parsed_layout_map; +} + +bool OpenVINOParserUtils::Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name) { + // Check if the layout string is empty + if (layout_str.empty()) { + return false; + } + + std::unordered_set seen_alphabets; + for (char c : layout_str) { + if (std::isalpha(c)) { + char upper_c = static_cast(std::toupper(c)); // Convert to uppercase for case-insensitive comparison + if (seen_alphabets.find(upper_c) != seen_alphabets.end()) { + ORT_THROW("Repeated Dim '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + seen_alphabets.insert(upper_c); + } else if (c != '?') { + // Only '?' is allowed as non-alphabetic character + ORT_THROW("Invalid character '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + } + + return true; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h index e6aa0e0a46a3b..a0936d627df40 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.h +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -18,8 +18,10 @@ class OpenVINOParserUtils { std::string& device_type, const std::string& option_name); static reshape_t ParseInputShape(const std::string& reshape_input_definition); + static layout_t ParseLayout(const std::string& layout_definition); static std::string TrimWhitespace(const std::string& str); static ov::Dimension ParseDimensionRange(const std::string& range_str, const std::string& tensor_name); + static bool Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name); }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 9dba8623031d0..1a10d9849d5cc 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -171,7 +171,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; for (const auto& dev_str : devices_to_check) { - const auto default_dev = split(dev_str, '.')[0]; + const std::string default_dev = split(dev_str, '.')[0]; if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; @@ -230,6 +230,10 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); } + if (provider_options.contains("layout")) { + pi.layout = OpenVINOParserUtils::ParseLayout(provider_options.at("layout")); + } + if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { // If the config string is empty, return an empty map and skip processing @@ -526,7 +530,7 @@ struct OpenVINO_Provider : Provider { std::string ov_device_string; if (is_meta_device_factory) { // Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU - ov_device_string = ov_meta_device_type; + ov_device_string = std::move(ov_meta_device_type); ov_device_string += ":"; } @@ -539,7 +543,7 @@ struct OpenVINO_Provider : Provider { prepend_comma = true; } - provider_options["device_type"] = ov_device_string; + provider_options["device_type"] = std::move(ov_device_string); // Parse provider info with the device type ProviderInfo pi; diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 8860405338409..2853cc17726ab 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -105,7 +105,7 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* std::string ov_device_name; auto get_gpu_device_id = [&](const std::string& ov_device) { try { - auto device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); + const std::string device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); return static_cast(std::stoul(device_id_str, nullptr, 0)); } catch (ov::Exception&) { return 0u; // If we can't get the GPU_DEVICE_ID info, we won't have a device ID. diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 2d29df8eb4197..899845d4890cf 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -11,6 +11,7 @@ #include "core/providers/openvino/backend_utils.h" #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/openvino/onnx_ctx_model_helper.h" namespace onnxruntime { namespace openvino_ep { @@ -191,14 +192,23 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, "Exception while Loading Network for graph {}", name); } -OVExeNetwork OVCore::ImportModel(std::istream& model_stream, +OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name) { return OvExceptionBoundary([&]() { ov::CompiledModel obj; - obj = core.import_model(model_stream, hw_target, device_config); +#if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) + if (!model_blob.maybe_native_blob_path_.empty()) { + obj = core.import_model(ov::read_tensor_data(model_blob.maybe_native_blob_path_), hw_target, device_config); + } else { + obj = core.import_model(*model_blob.stream_, hw_target, device_config); + } +#else + obj = core.import_model(*model_blob.stream_, hw_target, device_config); +#endif OVExeNetwork exe(obj, hw_target); + #ifndef NDEBUG printDebugInfo(exe.Get()); #endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 6d1db4366410b..38ea883078e85 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -26,6 +26,7 @@ namespace openvino_ep { class OVCore; class OVInferRequest; class OVExeNetwork; +struct ModelBlobWrapper; typedef ov::Tensor OVTensor; typedef ov::ProfilingInfo OVProfilingInfo; @@ -82,7 +83,7 @@ struct OVCore : WeakSingleton { ov::AnyMap& device_config, const std::string& name); // OV Interface for Import model Stream - OVExeNetwork ImportModel(std::istream& model_stream, + OVExeNetwork ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name); @@ -126,29 +127,16 @@ class OVInferRequest { OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); - // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. + // Set tensor call infer req tensor if ort_ptr differs from last set ptr. void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) { auto& cached_binding = bindings_cache_[name]; - if (cached_binding.ort_ptr != ort_ptr) { - auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } else if (ort_ptr == nullptr) { - // a null ort_ptr is expected for a tensor that has 0 elements. - // for example, a tensor of shape=[1, 8, 0, 64], which is valid. - // So, we check to see if at least one shape entry is 0. - auto contains_zero = [](const ov::Shape& shape) { - for (auto& s : shape) - if (s == 0) return true; - return false; - }; - if (contains_zero(shape)) { - // if there are zero elements (i.e. at least one shape entry is 0), - // then create and set the tensor anyway. - auto tensor_ptr = std::make_shared(type, shape); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } + if (cached_binding.ort_ptr != ort_ptr || + !cached_binding.tensor_ptr || + cached_binding.tensor_ptr->get_shape() != shape) { + cached_binding.tensor_ptr.reset(); + auto ov_tensor = std::make_shared(type, shape, const_cast(ort_ptr)); + ovInfReq.set_tensor(name, *ov_tensor); + cached_binding = {std::move(ov_tensor), ort_ptr}; } } diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 2309ff3de751b..1893700cab09c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,17 +166,28 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - + size_t cluster_index = 0; + size_t total_clusters = connected_clusters.size(); for (auto this_cluster : connected_clusters) { - // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (this_cluster.size() < 3) { - bool is_epctx_node = false; - for (auto node_idx : this_cluster) { - if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") - is_epctx_node = true; + bool omit_subgraph = false; + + if (this_cluster.size() == 1) { + // check next cluster + auto index = this_cluster.at(0); + size_t j = cluster_index; + if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { + omit_subgraph = false; + } else if (j < total_clusters - 1) { + bool append_node = false; + while (j < total_clusters && !append_node) { + j = j + 1; + append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[j]); + } + if (append_node) { + connected_clusters[j].emplace_back(index); + } + omit_subgraph = true; } - if (!is_epctx_node) - continue; } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -188,7 +199,6 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - bool omit_subgraph = false; // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); @@ -217,15 +227,17 @@ std::vector> GetCapability::Execute() { } } } - if (omit_subgraph) - continue; /* In scenarios, when there are no inputs or all inputs being initializers, ConstantFolding optimization in onnxruntime pre-computes the value.*/ - if (!cluster_inputs.empty()) { - AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); - no_of_clusters++; + if (!omit_subgraph) { + if (!cluster_inputs.empty()) { + AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); + no_of_clusters++; + } } + + cluster_index = cluster_index + 1; } LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 336b294117cba..f848b89ed10c8 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -121,6 +121,7 @@ std::vector supported_op_mode = { {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2024_4, {"NPU"}}, + {"DynamicQuantizeLinear", V_2025_2, {"CPU", "GPU"}}, {"DynamicQuantizeMatMul", V_2025_0, {"CPU", "GPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, @@ -172,6 +173,7 @@ std::vector supported_op_mode = { {"LSTM", V_2020_4, {"CPU", "GPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2025_2, {"GPU"}}, {"MatMulNBits", V_2024_5, {"CPU", "GPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, @@ -191,7 +193,7 @@ std::vector supported_op_mode = { {"Pad", V_2020_4, {"CPU", "GPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2022_3, {"CPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"QuickGelu", V_2025_0, {"CPU", "GPU"}}, {"RNN", V_2023_1, {"CPU", "GPU"}}, @@ -361,6 +363,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Clip", V_2022_1, {"All"}}); no_dimension_supported_.push_back({"Div", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"DynamicQuantizeLinear", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); @@ -374,6 +377,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Max", V_2024_4, {"All"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"MatMulInteger", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Pow", V_2023_0, {"CPU", "GPU"}}); @@ -469,15 +473,7 @@ void DataOps::populate_op_mode_supported() { } } - // check for input dimensions const auto& x_arg = node->InputDefs()[0]; - auto shape = x_arg->Shape(); - if (shape != nullptr) { - // input tensor rank cannot be of one dimension - if (shape->dim_size() == 1 || shape->dim_size() == 4) { - return true; - } - } // x_arg supports only float, int8 and float16 type if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || @@ -563,8 +559,13 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { return false; } + auto dtype = type_proto->tensor_type().elem_type(); + // Enable bfloat16 -> float16 on-the-fly conversion + if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16) + return true; if (is_initializer) { - auto dtype = type_proto->tensor_type().elem_type(); for (auto const& var : supported_types_initializer_) { if ((var.first <= version_id_) && (var.second == dtype)) { @@ -579,8 +580,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { #endif return false; } else { - auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { for (auto const& var : supported_types_npu_) { @@ -617,9 +616,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { (var.second == dtype)) { return true; } - // experimentally for GPU and qdq stripping mode allow int16 types - if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) - return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index f924fa0c8205c..791341218913f 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -153,6 +153,24 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector& search_cluster) { + for (auto index : search_cluster) { + auto curr_node = graph_viewer.GetNode(index); + for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } + + for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } + } + return false; +} + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 34aa762ba9b67..bdad047a422c1 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -40,6 +40,10 @@ void IdentifyConnectedNodes( std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); +bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer, + const NodeIndex index, + const std::vector& search_cluster); + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index d159930d52845..3a39152b5d17d 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -3,6 +3,8 @@ #include "qdq_scales_fix.h" #include "core/providers/openvino/ov_protobuf_utils.h" +#include "core/framework/ort_value.h" +#include "core/framework/float16.h" #include #include @@ -903,22 +905,11 @@ Status copy_model(const GraphViewer& src_graph_viewer, } for (auto& [name, tensor_proto] : src_graph.GetAllInitializedTensors()) { - dst_graph.AddInitializedTensor(*tensor_proto); - } - - for (auto node_arg : src_graph.GetInputsIncludingInitializers()) { - auto check_inputs = [node_arg](auto input_node_arg) { - return input_node_arg->Name() == node_arg->Name(); - }; - if (std::find_if(dst_graph_inputs.begin(), dst_graph_inputs.end(), check_inputs) != dst_graph_inputs.end()) - continue; - - auto src_tensor_proto = src_graph.GetConstantInitializer(node_arg->Name(), true); - if (src_tensor_proto) { - auto dst_tensor_proto = onnx::TensorProto::Create(); - dst_tensor_proto->copy_from(src_tensor_proto); - dst_graph.AddInitializedTensor(*dst_tensor_proto); - } + auto ort_value = OrtValue(); + if (src_graph.GetOrtValueInitializer(name, ort_value)) + ORT_RETURN_IF_ERROR(dst_graph.AddInitializedOrtValue(*tensor_proto, ort_value)); + else + dst_graph.AddInitializedTensor(*tensor_proto); } ORT_RETURN_IF_ERROR(dst_graph.Resolve()); @@ -940,5 +931,54 @@ Status Transform(const GraphViewer& src_graph_viewer, return status; } } // namespace qdq_scales_fix + +namespace bfloat16_fix { +void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { + for (auto& const_node : gen_graph.original_graph.Nodes()) { + auto node = const_cast(const_node); + if (node->OpType() == "Cast") { + for (auto& [name, const_attribute] : node->GetAttributes()) { + auto& attribute = const_cast(const_attribute); + if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT) + if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + for (auto& output : node->OutputDefs()) { + auto& output_proto = const_cast(output->ToProto().type()); + if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + + const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); + for (auto& [key, const_tensor_proto] : init_set) { + auto tensor_proto = const_cast(const_tensor_proto); + auto dt = tensor_proto->data_type(); + if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast(tensor_proto->mutable_raw_data()->data()) : nullptr; + if (raw_data) { + tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + std::int64_t size = 1; + for (int i = 0; i < tensor_proto->dims_size(); ++i) + size *= tensor_proto->dims()[i]; + for (std::int64_t i = 0; i < size; ++i) { + raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val; + } + } + } + } +} + +Status Transform(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model) { + auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model); + auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph()); + + replace_bf16_with_fp16(g); + return status; +} +} // namespace bfloat16_fix } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h index c54c531e1bd40..2182850d96c43 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -15,5 +15,10 @@ Status Transform(const GraphViewer& src_graph, const logging::Logger& logger, /*out*/ std::unique_ptr& model); } +namespace bfloat16_fix { +Status Transform(const GraphViewer& src_graph, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model); +} } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 24e8892622175..e010851f22e50 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -677,6 +677,27 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, } } +// To check if the input parameters of a DQ or Q node are quantization parameters +// Scale and Zero point parameters are quantization parameters +static bool IsQuantizationParameter(const std::string& initializer_name, + const onnxruntime::GraphViewer& src_graph) { + // Check if this initializer is used as scale or zero_point in any DQ/Q node + for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { + const auto* node = src_graph.GetNode(node_idx); + if (node->OpType() == "DequantizeLinear" || node->OpType() == "QuantizeLinear") { + const auto& input_defs = node->InputDefs(); + // Check if this initializer is used as scale (input 1) or zero_point (input 2) + if (input_defs.size() >= 2 && input_defs[1]->Name() == initializer_name) { + return true; // This is a scale parameter + } + if (input_defs.size() >= 3 && input_defs[2]->Name() == initializer_name) { + return true; // This is a zero_point parameter + } + } + } + return false; +} + // Creates a new model without the DQ/Q operators in the src graph. Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, @@ -845,10 +866,20 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { - insert_metadata(initializer_tensor); + // Only convert to input if it's not a quantization parameter + bool is_quant_param = IsQuantizationParameter(name, src_graph); + + if (!is_quant_param) { + // This is actual weight data - so to convert to input for weight sharing + insert_metadata(initializer_tensor); + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); + } else { + // This is a quantization parameter - keep as initializer even if external - // Add initializer with external data as input - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); + if (initializers_to_keep.count(name) > 0) { + dst_graph.AddInitializedTensor(initializer_tensor); + } + } } else { // Add as an initialized tensor if it does not have external data if (initializers_to_keep.count(name) > 0) { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0152ad27c0ba2..d99e322641199 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -273,11 +273,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, // Check if we need to add a cast node for int64 bool needs_int64_cast = false; if (is_graph_output) { - for (const auto& input_name : input_names) { - if (input_name.find("_cast_int32") != std::string::npos) { - needs_int64_cast = true; - break; - } + if (supported_qnn_data_type == output_info.qnn_data_type && + (output_info.qnn_data_type == QNN_DATATYPE_INT_64 || output_info.qnn_data_type == QNN_DATATYPE_UINT_64)) { + supported_qnn_data_type = supported_qnn_data_type == QNN_DATATYPE_INT_64 ? QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32; + needs_int64_cast = true; } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index e910afcbcf6c6..dbdb2d828f039 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -236,7 +236,7 @@ class BaseOpBuilder : public IOpBuilder { } // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end] - void ReArranagePads(std::vector& pads) const { + void ReArrangePads(std::vector& pads) const { auto pads_size = pads.size(); auto middle_pos = pads_size / 2; std::vector first_half(pads.begin(), pads.begin() + middle_pos); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index b80d9db5d3560..dba4fbdbe0872 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -24,7 +24,6 @@ static Status GetOnnxConvType(const std::string& onnx_op_type, OnnxConvType& con } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unsupported ONNX convolution op type: ", onnx_op_type.c_str()); } - return Status::OK(); } @@ -171,7 +170,7 @@ Status ConvOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return ProcessConv2D3DInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation); } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D(rank 5), 2D (rank 4) or 1D (rank 3) inputs."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D (rank 5), 2D (rank 4) or 1D (rank 3) inputs."); } Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, @@ -713,7 +712,7 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } } - ReArranagePads(pads); + ReArrangePads(pads); uint32_t pad_size = narrow(pads.size() / 2); QnnParamWrapper pad_amount_paramwrapper(node_unit.Index(), node_unit.Name(), QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, {pad_size, 2}, std::move(pads)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 404d3c402c21e..d2b1434c1c896 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -193,7 +193,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap [](int64_t item) { return SafeInt(item); }); // Onnx format is begin_0, begin_1, ..., end_0, end_1, ... // Qnn format is begin_0, end_0, begin_1, end_1, ... - ReArranagePads(pad_amount); + ReArrangePads(pad_amount); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 21947a22e2b92..78ab047a560a7 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -93,15 +93,16 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -static std::vector AmendOutputShapeForRank3Pool( +Status AmendOutputShapeForRank3Pool( gsl::span input_shape, // {N, H, W, C} gsl::span kernel_shape, // {k_h, k_w} gsl::span strides, // {s_h, s_w} - gsl::span pads) { - assert(input_shape.size() == 4 && - kernel_shape.size() == 2 && - strides.size() == 2 && - pads.size() == 4); + gsl::span pads, + std::vector& output_shape) { + ORT_RETURN_IF_NOT(input_shape.size() == 4, "Expecting input rank 4 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(kernel_shape.size() == 2, "Expecting kernel size 2 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(strides.size() == 2, "Expecting strides size 2 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(pads.size() == 4, "Expecting pad size 4 for amending 1D Pool output shape."); const uint32_t N = input_shape[0]; const uint32_t H = input_shape[1]; @@ -120,7 +121,13 @@ static std::vector AmendOutputShapeForRank3Pool( ? 0 : (padded_W - kernel_shape[1]) / strides[1] + 1; - return {N, out_H, out_W, C}; + output_shape.resize(4); + output_shape[0] = N; + output_shape[1] = out_H; + output_shape[2] = out_W; + output_shape[3] = C; + + return Status::OK(); } Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, @@ -177,10 +184,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, if (auto_pad.compare("NOTSET") != 0) { if (output_shape.size() == 3) { // Calculate rank-4 output shape for rank-3 input. - output_shape = AmendOutputShapeForRank3Pool(input_shape, - filter_size, - stride, - pad_amount); + ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(input_shape, filter_size, stride, pad_amount, output_shape)); } for (size_t axis = 0; axis < rank - 2; ++axis) { @@ -195,7 +199,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, } } } - ReArranagePads(pad_amount); + ReArrangePads(pad_amount); // Param: rounding_mode. rounding_mode = node_helper.Get("ceil_mode", rounding_mode); @@ -365,14 +369,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::move(output_shape))); } - // Calculate rank-4 output shape for rank-3 input. - std::vector onnx_in_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); - if (onnx_in_shape.size() == 3) { - onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]}; - } - auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount); - // Construct param wrappers. ORT_RETURN_IF_NOT(SetPoolParam(node_unit, param_filter_size, @@ -443,6 +439,16 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra return Status::OK(); } + + // Calculate rank-4 output shape for rank-3 input. + std::vector onnx_in_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); + if (onnx_in_shape.size() == 3) { + onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]}; + } + std::vector pooled_shape; + ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount, pooled_shape)); + const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); const std::string pool_out = real_out + "_reshape_after"; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3dc103046424e..5bcb8ca394346 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -787,10 +787,12 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord std::vector context_params_list; std::vector context_paramsv1_list; - std::vector context_params_ptr_list(context_bin_map.size() + 1); + std::vector context_params_ptr_list; std::vector> buffer_list; - size_t idx = 0; + context_params_list.reserve(context_bin_map.size()); + context_params_ptr_list.reserve(context_bin_map.size() + 1); + for (auto& it : context_bin_map) { auto context_bin_filepath = it.first; @@ -821,9 +823,9 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord buffer_list.push_back(std::move(buffer)); context_params_list.push_back(std::move(context_params)); context_paramsv1_list.push_back(std::move(context_params_v1)); - context_params_ptr_list[idx++] = &context_params_list.back(); + context_params_ptr_list.push_back(&context_params_list.back()); } - context_params_ptr_list[idx] = nullptr; + context_params_ptr_list.push_back(nullptr); auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, device_handle_, context_params_ptr_list.data(), @@ -1178,6 +1180,14 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) if (vtcm_backup_buffer_sharing_enabled_) { + // If a context bin filepath has not been processed yet, + // then a new context must be created for the set of context bins + auto first_mapping_it = ep_context_handle_map_.find(context_bin_map.begin()->first); + if (first_mapping_it == ep_context_handle_map_.end()) { + LOGS(logger, VERBOSE) << "Creating context for new set of context binaries"; + return CreateContextVtcmBackupBufferSharingEnabled(context_bin_map); + } + LOGS(logger, VERBOSE) << "Mapping contexts to new EP main context nodes"; for (auto& it : context_bin_map) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index e1a74b9e35370..ee5f52289d779 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -158,7 +158,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na return false; } - // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor + // During graph partitioning, we only need to do op validation, it's not required to create Qnn graph tensor // We only need to create the Qnn graph tensor during Compile to create Qnn graph if (!do_op_validation) { std::string error_string; diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 8641952f27ee5..a7e553848fb4d 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -219,9 +219,11 @@ struct QnnEpFactory : OrtEpFactory { OrtKeyValuePairs* ep_options = nullptr; factory->ort_api.CreateKeyValuePairs(&ep_options); factory->ort_api.AddKeyValuePair(ep_options, "backend_path", factory->qnn_backend_path.c_str()); - ORT_API_RETURN_IF_ERROR( - factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, - &ep_devices[num_ep_devices++])); + OrtStatus* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, + &ep_devices[num_ep_devices++]); + + factory->ort_api.ReleaseKeyValuePairs(ep_options); + ORT_API_RETURN_IF_ERROR(status); } } diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 71d51c4c2992d..a7fd83f10fe18 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -326,6 +326,12 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); +inline std::string GetEnvironmentVar(std::string_view var_name) { + return GetEnvironmentVar(std::string{var_name}); +} +inline std::string GetEnvironmentVar(const char* var_name) { + return GetEnvironmentVar(std::string{var_name}); +} namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 031a4df59d83f..d690cf31072d2 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -790,12 +790,12 @@ Status LoadDynamicLibrary(onnxruntime::PathString library_name) { #endif #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { - return g_host->ToUTF8String(s); +std::string ToUTF8String(std::wstring_view s) { + return g_host->ToUTF8String(std::wstring{s}); } -std::wstring ToWideString(const std::string& s) { - return g_host->ToWideString(s); +std::wstring ToWideString(std::string_view s) { + return g_host->ToWideString(std::string{s}); } #endif // _WIN32 } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5c9c1a0ae163f..9a0bcb53c9ad7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1011,6 +1011,8 @@ struct ProviderHost { virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; // We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0; + virtual bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool check_outer_scope) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, NodeAttributes&& attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0; @@ -1074,6 +1076,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::TensorProto* GraphViewer__GetConstantInitializer(const GraphViewer* p, const std::string& name, bool check_outer_scope) const = 0; + virtual bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) = 0; virtual const Node* GraphViewer__ParentNode(const GraphViewer* p) = 0; virtual int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept = 0; virtual int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 23fbead1e9707..19b4636c3766d 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1041,6 +1041,10 @@ struct Graph final { Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) { return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value, + bool check_outer_scope = false) const { + return g_host->Graph__GetOrtValueInitializer(this, tensor_name, ort_value, check_outer_scope); + } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); } Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); } @@ -1124,6 +1128,9 @@ class GraphViewer final { bool check_outer_scope = true) const { return g_host->GraphViewer__GetConstantInitializer(this, name, check_outer_scope); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value) const { + return g_host->GraphViewer__GetOrtValueInitializer(this, tensor_name, ort_value); + } const Node* ParentNode() const { return g_host->GraphViewer__ParentNode(this); } int NumberOfNodes() const noexcept { return g_host->GraphViewer__NumberOfNodes(this); } diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 20ae1cfbfa2c1..c6bf29dafa184 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -15,7 +15,7 @@ #include "vaip/node.h" #include "vaip/node_arg.h" - +#include "./tensor_proto.h" namespace vaip { struct NodeEdgeT { @@ -286,7 +286,14 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold) cloned_tensor->add_dims(dim); size = size * dim; } - if (size >= external_data_threshold) { + auto ORT_MEM_ADDR_tag = process_ext_address(*original_tensor); + if (!ORT_MEM_ADDR_tag.empty()) { + cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto external_data = cloned_tensor->mutable_external_data(); + auto p = external_data->Add(); + *p->mutable_key() = "location"; + *p->mutable_value() = std::string("<") + graph_ptr; + } else if (size >= external_data_threshold) { cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); auto external_data = cloned_tensor->mutable_external_data(); auto p = external_data->Add(); diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index bb942c69003a1..2f1478bf1326b 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -10,7 +10,7 @@ namespace vaip { using namespace onnxruntime; -static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { auto tensor_proto = const_cast(&tensor); auto file = std::string(); uintptr_t offset = 0; diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 73015d3411a54..a7c90ac18b44e 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -37,4 +37,5 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_fp16(const std::string& name, cons const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_doubles(const std::string& name, const std::vector& shape, const std::vector& data); +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor); } // namespace vaip diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index a22d21d8d798b..bdeea726a2cf5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << ","; } - auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << "\n " << alignment << name << ": "; + // The actual variable type for the uniform variable depends on the data type (T) and length (N). + // + // For T in [i32, u32, f32]: + // - If N == 1, the type is simply i32, u32, or f32. + // - If 2 < N <= 4, the type is vecN, vecN, or vecN where N is the length. + // - If N > 4, the type is array, ceil(N / 4)>. + // + // For T is f16: + // - If N == 1 or N == 2, the type is u32. + // - If 2 < N <= 8, the type is vecX where X is ceil(N / 2). + // - If N > 8, the type is array, X> where X is ceil(N / 8). + // + // Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent + // 2 f16 values. + + if (data_type == ProgramUniformVariableDataType::Float16) { + data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32 + length = (length + 1) / 2; // each u32 can hold 2 f16 values + } + ss << "\n " << name << ": "; if (length > 4) { - if (data_type == ProgramUniformVariableDataType::Float16) { - size_t array_size = (length + 7) / 8; - ss << "array, " << array_size << ">"; - } else { - size_t array_size = (length + 3) / 4; - ss << "array, " << array_size << ">"; - } + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; } else if (length > 1) { ss << "vec" << length << "<" << data_type << ">"; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 2aba2a59d157f..78c98ab26f5b8 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -17,18 +17,34 @@ template || std::is_same_v>> std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { - // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniforms.", 0) == 0) { - if (rank > 4) { - if constexpr (std::is_integral_v) { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + if (var.starts_with("uniforms.")) { + if (is_f16) { + if (rank > 8) { + // array, N> + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]"); } else { - return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]"); + } + } else if (rank > 2) { + // vecN + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 2, "])[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]"); } } else { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + // u32 + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, ")[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, ")[(", idx, ") % 2]"); + } + } + } else { + if (rank > 4) { + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); } else { return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bd79a627df22..562d54d1bf977 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -210,7 +210,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); + !strcmp(tensor->Location().name.c_str(), WEBGPU_BUFFER); }), "All inputs must be tensors on WebGPU buffers."); @@ -219,7 +219,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); + !strcmp(tensor->Location().name.c_str(), WEBGPU_BUFFER); }), "All outputs must be tensors on WebGPU buffers."); } @@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { continue; } - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // Calculate the size and alignment of the uniform variable. + // // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + // + // For f16: + // - length > 8 : array, N> (align 16) (size 16 * N, N = ceil(length / 8)) + // - length == 7 or 8: vec4 (align 16) (size 16) + // - length == 5 or 6: vec3 (align 16) (size 12) + // - length == 3 or 4: vec2 (align 8) (size 8) + // - length == 1 or 2: u32 (align 4) (size 4) + // + // For other types (i32, u32, f32): + // - length > 4 : array, N> (align 16) (size 16 * N, N = ceil(length / 4)) + // - length == 4 : vec4 (align 16) (size 16) + // - length == 3 : vec3 (align 16) (size 12) + // - length == 2 : vec2 (align 8) (size 8) + // - length == 1 : T (align 4) (size 4) + // + + const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t variable_alignment = 4; // default alignment for scalar types + size_t variable_size = 4; // default size for scalar types + + if (is_f16) { + if (length > 6) { + variable_alignment = 16; + variable_size = 16 * ((length + 7) / 8); + } else if (length > 4) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 2) { + variable_alignment = 8; + variable_size = 8; + } + } else { + if (length > 3) { + variable_alignment = 16; + variable_size = 16 * ((length + 3) / 4); + } else if (length > 2) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 1) { + variable_alignment = 8; + variable_size = 8; + } + } + current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment; uniform_and_offsets.emplace_back(uniform, current_offset); - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + current_offset += variable_size; } // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 7d739439b7a27..72d5007b84e6f 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -18,11 +18,11 @@ struct OrtKeyValuePairs { CopyFromMap(other.entries_); } - OrtKeyValuePairs(OrtKeyValuePairs&& other) : OrtKeyValuePairs{} { + OrtKeyValuePairs(OrtKeyValuePairs&& other) noexcept : OrtKeyValuePairs{} { swap(*this, other); } - OrtKeyValuePairs& operator=(OrtKeyValuePairs other) { // handles copy and move assignment + OrtKeyValuePairs& operator=(OrtKeyValuePairs other) noexcept { // handles copy and move assignment swap(*this, other); return *this; } diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 59b0992d827e1..b9a54ea7104e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -64,7 +64,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelPath, API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(input_model_path); + std::filesystem::path model_path = input_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: path string is empty"); @@ -113,7 +113,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(output_model_path); + std::filesystem::path model_path = output_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model path: path is empty"); } @@ -136,17 +136,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInf #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string output_dir = PathToUTF8String(output_directory); - if (output_dir.empty()) { + std::filesystem::path output_directory_path = output_directory; + if (output_directory_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); } - std::string model_name_str = ToUTF8String(model_name); - if (model_name_str.empty()) { + std::filesystem::path model_name_path = model_name; + if (model_name_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_directory_path, + model_name_path)); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); @@ -163,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna size_t external_initializer_size_threshold) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) - std::string initializers_file_path = PathToUTF8String(external_initializers_file_path); + std::filesystem::path initializers_file_path = external_initializers_file_path; if (initializers_file_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid external initializer file: path is empty"); } @@ -214,6 +215,50 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (write_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtWriteBufferFunc function for output model is null"); + } + + model_compile_options->SetOutputModelWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (get_initializer_location_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtGetInitializerLocationFunc function for output model is null"); + } + + model_compile_options->SetOutputModelGetInitializerLocationFunc(get_initializer_location_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(get_initializer_location_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -231,7 +276,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode } ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, - _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) { + _In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); @@ -245,6 +290,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ GraphOptimizationLevel graph_optimization_level) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(graph_optimization_level); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -278,6 +339,9 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetFlags, &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, + &OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 93cc5dbf20fce..34fa06340a7f9 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -29,8 +29,17 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, - size_t flags); + uint32_t flags); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ GraphOptimizationLevel graph_optimization_level); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 450a8bad09392..9c40eb75780ee 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,10 +16,10 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_library_internal.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/utils.h" @@ -72,21 +72,23 @@ ProviderInfo_CUDA& GetProviderInfo_CUDA(); #endif // defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) namespace { -// Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison +// Ignore whether there is an arena wrapping the allocator by excluding OrtMemoryInfo.alloc_type from the comparison. static bool AreOrtMemoryInfosEquivalent( const OrtMemoryInfo& left, const OrtMemoryInfo& right, - bool match_name = true) { + bool match_name = true, + bool ignore_alignment = false) { return left.mem_type == right.mem_type && - left.device == right.device && - (!match_name || strcmp(left.name, right.name) == 0); + (ignore_alignment ? left.device.EqualIgnoringAlignment(right.device) : left.device == right.device) && + (!match_name || left.name == right.name); } std::vector::const_iterator FindExistingAllocator(const std::vector& allocators, const OrtMemoryInfo& mem_info, - bool match_name = true) { + bool match_name = true, + bool ignore_alignment = false) { auto ite = std::find_if(std::begin(allocators), std::end(allocators), - [&mem_info, match_name](const AllocatorPtr& alloc_ptr) { + [&mem_info, match_name, ignore_alignment](const AllocatorPtr& alloc_ptr) { // We want to do the equality checking of 2 OrtMemoryInfos sans the OrtAllocatorType field. // This is because we want to avoid registering two allocators for the same device that just // differ on OrtAllocatorType. @@ -96,7 +98,8 @@ std::vector::const_iterator FindExistingAllocator(const std::vecto // OrtDeviceAllocator (which is the only accepted value while registering a custom allocator). // If we allowed this, it could potentially cause a lot of confusion as to which shared allocator // to use for that device and we want to avoid having any ugly logic around this. - return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, match_name); + return AreOrtMemoryInfosEquivalent(alloc_ptr->Info(), mem_info, + match_name, ignore_alignment); }); return ite; @@ -179,11 +182,6 @@ Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool shared_ort_allocators_.erase(it2); } - // also remove an arena wrapped allocator from an EP if the user called CreateSharedAllocator to create one - if (auto it3 = arena_ort_allocators_.find(&mem_info); it3 != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it3); - } - if (found_shared_allocator) { shared_allocators_.erase(it); } @@ -428,8 +426,29 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ } Environment::~Environment() { - // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed + // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed. + // this is because any entry in shared_allocators_ wrapping an OrtAllocator from a plugin EP owns the OrtAllocator + // instance and will call Release on it. If the plugin EP has been freed the Release will fail. shared_allocators_.clear(); + + // and as any OrtAllocator instances in shared_ort_allocators_ were owned by values in shared_allocators_ and have + // now been released we need to clear that too before calling UnregisterExecutionProviderLibrary(). + shared_ort_allocators_.clear(); + +#if !defined(ORT_MINIMAL_BUILD) + // unregister any remaining EP libraries so they're cleaned up in a determistic way. + while (!ep_libraries_.empty()) { + auto it = ep_libraries_.begin(); + ORT_IGNORE_RETURN_VALUE(UnregisterExecutionProviderLibrary(it->first)); + } +#endif +} + +AllocatorPtr Environment::GetRegisteredSharedAllocator(const OrtMemoryInfo& mem_info) const { + std::lock_guard lock{mutex_}; + + auto it = FindExistingAllocator(shared_allocators_, mem_info, /*match_name*/ false, /*ignore_alignment*/ true); + return it != shared_allocators_.end() ? *it : nullptr; } Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) { @@ -653,11 +672,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, shared_ort_allocators_.erase(it); } - // if a previous call created an arena wrapped allocator for the EP's memory_info we also need to remove that - if (auto it = arena_ort_allocators_.find(&memory_info); it != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it); - } - // we only want one shared allocator for an OrtDevice in the shared_allocators_ so that it's deterministic which // one will be used for an inference session. ignore the name so that is the case. if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc deleted file mode 100644 index 986ccb1fa17fc..0000000000000 --- a/onnxruntime/core/session/ep_library_internal.cc +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_internal.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/ortmemoryinfo.h" -#include "core/framework/session_options.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_logger.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api.h" -#include "core/session/ort_apis.h" - -#if defined(USE_DML) -#include "core/providers/dml/dml_provider_factory_creator.h" -#endif - -#if defined(USE_WEBGPU) -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" -#endif - -namespace onnxruntime { - -class CpuEpFactory : public EpFactoryInternalImpl { - public: - CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "CPU EP factory currently only supports one device at a time."); - } - - CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; - *ep = std::make_unique(epi); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } -}; - -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - auto cpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} - -#if defined(USE_DML) -class DmlEpFactory : public EpFactoryInternalImpl { - public: - DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - std::unique_ptr ep_options; - - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is - // associated with a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? - int32_t device_id = 0; // If no device_id was found default to 0 - if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { - ep_options = std::make_unique(); - device_id = std::stoi(it->second); - } - - ep_options->Add("device_id", std::to_string(device_id)); - - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices]); - - if (device_memory_infos.size() < device_id + 1) { - device_memory_infos.resize(device_id + 1); - device_allocators.resize(device_id + 1); - } - - if (device_memory_infos[device_id] == nullptr) { - // Create memory info for the device if it doesn't already exist - device_memory_infos[device_id] = std::make_unique( - "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, - narrow(device_id))); - } - - // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. - // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], - // device_memory_infos[device_id].get()); - - if (api_status != nullptr) { - return api_status; - } - - ++num_ep_devices; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "DML EP factory currently only supports one device at a time."); - } - - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, - ep_options); - - *ep = dml_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept override { - // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That - // requires pulling lots of things out of the DML EP to get the D3D12 device and create a - // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp - //*allocator = device_allocators[memory_info->device.Id()].get(); - *allocator = nullptr; - return nullptr; - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - - std::vector> device_memory_infos; // memory info for each device - std::vector> device_allocators; // allocators for each device -}; - -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - auto dml_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(dml_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -#if defined(USE_WEBGPU) -class WebGpuEpFactory : public EpFactoryInternalImpl { - public: - WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: any metadata or options to add? - ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "WebGPU EP factory currently only supports one device at a time."); - } - - auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); - *ep = webgpu_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of - an InferenceSession. - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - *allocator = device_allocators[memory_info->device.Id()].get(); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - */ -}; - -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - auto webgpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -std::vector> EpLibraryInternal::CreateInternalEps() { - std::vector> internal_eps; - internal_eps.reserve(4); - - // CPU EP - internal_eps.push_back(CreateCpuEp()); - -#if defined(USE_WEBGPU) - internal_eps.push_back(CreateWebGpuEp()); -#endif - -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - - return internal_eps; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc deleted file mode 100644 index ae553891beaa7..0000000000000 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_provider_bridge.h" - -#include "core/common/status.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/session_options.h" -#include "core/providers/cuda/cuda_provider_options.h" -#include "core/providers/shared_library/provider_host_api.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_factory_internal.h" - -namespace onnxruntime { -class ProviderBridgeEpFactory : public EpFactoryInternalImpl { - public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) - : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), - ep_factory.GetVendor(&ep_factory), - ep_factory.GetVendorId(&ep_factory)), - ep_factory_{ep_factory}, - provider_library_{provider_library} { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept override { - ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, - max_ep_devices, num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. - for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = &ep_factory; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - // get the provider specific options - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto& provider = provider_library_.Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *session_logger, *ep); - - return ToOrtStatus(status); - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); - } - - void ReleaseAllocator(OrtAllocator* allocator) noexcept override { - ep_factory_.ReleaseAllocator(&ep_factory_, allocator); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); - } - - bool IsStreamAware() const noexcept override { - return ep_factory_.IsStreamAware(&ep_factory_); - } - - OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, - const OrtKeyValuePairs* stream_options, - OrtSyncStreamImpl** stream) noexcept override { - return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); - } - - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP -}; - -Status EpLibraryProviderBridge::Load() { - std::lock_guard lock{mutex_}; - - if (!factories_.empty()) { - // already loaded - return Status::OK(); - } - - // if we have been unloaded we can't just be reloaded. - if (!ep_library_plugin_ || !provider_library_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EpLibraryProviderBridge has been unloaded. " - "Please create a new instance using LoadPluginOrProviderBridge."); - } - - // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. - // use GetSupportedDevices from the library's factory. - // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. - // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can - // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); - auto internal_factory = std::make_unique(std::move(factory_impl)); - - factory_ptrs_.push_back(internal_factory.get()); - internal_factory_ptrs_.push_back(internal_factory.get()); - factories_.push_back(std::move(internal_factory)); - } - - return Status::OK(); -} - -Status EpLibraryProviderBridge::Unload() { - std::lock_guard lock{mutex_}; - - internal_factory_ptrs_.clear(); - factory_ptrs_.clear(); - factories_.clear(); - - // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. - ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); - ep_library_plugin_ = nullptr; - - provider_library_->Unload(); - provider_library_ = nullptr; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f4f76a389030e..c0900c5ad28a0 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1421,6 +1421,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool } } + // We choose to convert initializers into OrtValues before partitioning here so plug-in EPs could + // take advantage of the initializers being in OrtValue format and not to deal with protobuf. + // + // The initializers data is transferred to an OrtValue. The original TensorProto is replaced + // with a TensorProto that has the same data type, shape and name. However, its external data + // is used in a non-standard way. The location is set to a string constant utils::kTensorProtoMemoryAddressTag, + // The file offset is set to the address of the OrtValue's data buffer, and the length is set to the size of the + // OrtValue's data buffer. Because this external location is non-standard, onnx code can not handle it, so we choose + // to do it as late as possible but before the partitioning so type and shape inference accesses the initializers + // before they are converted to OrtValues. + // + // If any transformations are applied later, they would not introduce any in-memory initializers, + // type and shape inference would run only on any newly added nodes and any new initializers + // will be converted at session finalization time. + // + // The conversion is performed using the following steps (within ConvertInitializersIntoOrtValues()) + // constexpr const bool use_tensor_buffer_true = true; + // auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get(), tensor_proto.name(), + // use_tensor_buffer_true); + // ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value)); + + ORT_RETURN_IF_ERROR_SESSIONID_(graph.ConvertInitializersIntoOrtValues()); + // Do partitioning based on execution providers' capabilities. ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, session_options_.config_options, *session_logger_, @@ -1984,13 +2007,15 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { // For now, this function only checks for invalid combination of DML EP with other EPs. // TODO: extend this function to check for other invalid combinations of EPs. common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const { - // DML EP is only allowed with CPU EP + // DML EP is not allowed with other GPU or NPU EPs. + // historical reason for this is unknown. relaxing the limit that it must only be used with the CPU EP to support + // scenarios where alternative EPs are CPU based (e.g. openvino). bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr; if (has_dml_ep) { - const auto& ep_list = execution_providers_.GetIds(); - for (const auto& ep : ep_list) { - if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue; - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP."); + for (const auto& ep : execution_providers_) { + if (ep->Type() != kDmlExecutionProvider && ep->GetDevice().Type() != OrtDevice::CPU) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can only be used with CPU EPs."); + } } } return Status::OK(); diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc index 85ea958981e2c..124d748029fd4 100644 --- a/onnxruntime/core/session/lora_adapters.cc +++ b/onnxruntime/core/session/lora_adapters.cc @@ -53,11 +53,11 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) { static std::unique_ptr GetDataTransfer(const OrtMemoryInfo& mem_info) { std::unique_ptr data_transfer; - if (strcmp(mem_info.name, onnxruntime::CPU) == 0) { + if (mem_info.name == onnxruntime::CPU) { return data_transfer; } - if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) { + if (mem_info.name == onnxruntime::CUDA) { #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) auto* cuda_provider_info = TryGetProviderInfo_CUDA(); if (cuda_provider_info != nullptr) { diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index bbb110033f54c..84f41771cb62b 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -7,8 +7,11 @@ #include #include #include +#include +#include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/framework/ep_context_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -22,14 +25,16 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& // defaulting to kGenerateModel to support wider usage. session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes = - EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK()); + + session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only } -void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) { +void ModelCompilationOptions::SetInputModelPath(const std::filesystem::path& input_model_path) { ResetInputModelSettings(); input_model_path_ = input_model_path; } @@ -40,17 +45,16 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da input_model_data_size_ = input_model_data_size; } -Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); - +Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { ConfigOptions& config_options = session_options_.value.config_options; - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path = output_model_path; + ep_context_gen_options.output_model_location = output_model_path; - if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) { - Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ep_context_gen_options.output_model_file_path.c_str()); + std::string output_model_path_str = PathToUTF8String(output_model_path); + + if (output_model_path_str.size() <= ConfigOptions::kMaxValueLength) { + Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path_str.c_str()); ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string lengths // required by ConfigOptions::AddConfigEntry(). } else { @@ -71,7 +75,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() + LOGS(logger, WARNING) << "Output model path length (" << output_model_path_str.size() << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; @@ -80,40 +84,58 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod return Status::OK(); } -void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, - size_t external_initializer_size_threshold) { - session_options_.value.ep_context_gen_options.output_external_initializers_file_path = external_initializers_path; - session_options_.value.ep_context_gen_options.output_external_initializer_size_threshold = - external_initializer_size_threshold; +void ModelCompilationOptions::SetOutputModelExternalInitializersFile( + const std::filesystem::path& external_initializers_path, + size_t external_initializer_size_threshold) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::ExternalInitializerFileInfo{ + external_initializers_path, + external_initializer_size_threshold, + }; } Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{ + output_model_buffer_ptr, + output_model_buffer_size_ptr, + std::move(allocator), + }; - session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); return Status::OK(); } -Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, - const std::string& model_name) { +void ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferWriteFuncHolder{ + write_func, + state, + }; +} + +void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::InitializerHandler{ + get_initializer_location_func, + state, + }; +} + +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); } - std::filesystem::path output_dir_path(output_directory); - if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + if (output_directory.has_filename() && output_directory.extension() == "") { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); } - std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + std::filesystem::path ctx_model_path = output_directory / model_name; + std::string ctx_model_path_str = PathToUTF8String(ctx_model_path); - if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { + if (ctx_model_path_str.size() <= ConfigOptions::kMaxValueLength) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ctx_model_path.string().c_str())); + ctx_model_path_str.c_str())); } else { logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { @@ -135,12 +157,12 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m return Status::OK(); } -Status ModelCompilationOptions::SetFlags(size_t flags) { - EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options; +Status ModelCompilationOptions::SetFlags(uint32_t flags) { + epctx::ModelGenOptions& options = session_options_.value.ep_context_gen_options; options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS; options.action_if_no_compiled_nodes = - (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError - : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError + : epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; return Status::OK(); } @@ -152,7 +174,7 @@ bool ModelCompilationOptions::InputModelComesFromFile() const { return !input_model_path_.empty(); } -const std::string& ModelCompilationOptions::GetInputModelPath() const { +const std::filesystem::path& ModelCompilationOptions::GetInputModelPath() const { return input_model_path_; } @@ -170,77 +192,106 @@ void ModelCompilationOptions::ResetInputModelSettings() { input_model_data_size_ = 0; } -Status ModelCompilationOptions::ResetOutputModelSettings() { - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path.clear(); - ep_context_gen_options.output_model_buffer_ptr = nullptr; - ep_context_gen_options.output_model_buffer_size_ptr = nullptr; - ep_context_gen_options.output_model_buffer_allocator = nullptr; +Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + switch (graph_optimization_level) { + case ORT_DISABLE_ALL: + // TransformerLevel::Default means that we only run required transformers. + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default; + break; + case ORT_ENABLE_BASIC: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1; + break; + case ORT_ENABLE_EXTENDED: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2; + break; + case ORT_ENABLE_LAYOUT: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3; + break; + case ORT_ENABLE_ALL: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel; + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ", + static_cast(graph_optimization_level), " is invalid. Valid values are: ", + "ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ", + "ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99)."); + } + return Status::OK(); } -Status ModelCompilationOptions::CheckInputModelSettings() const { - const bool comes_from_file = !input_model_path_.empty(); - const bool comes_from_memory = input_model_data_ != nullptr; +Status ModelCompilationOptions::Check() const { + const ConfigOptions& config_options = session_options_.value.config_options; + + ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); + ORT_ENFORCE(config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - if (!comes_from_file && !comes_from_memory) { + // Check input model settings. + const bool input_from_file = !input_model_path_.empty(); + const bool input_from_memory = input_model_data_ != nullptr; + + if (!input_from_file && !input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer"); } - if (comes_from_file && comes_from_memory) { + if (input_from_file && input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer, ", "but not both."); } - if (comes_from_file && !std::filesystem::exists(input_model_path_)) { + if (input_from_file && !std::filesystem::exists(input_model_path_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_); } - if (comes_from_memory && input_model_data_size_ == 0) { + if (input_from_memory && input_model_data_size_ == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } - return Status::OK(); -} + // Check output model settings. + const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + bool has_no_output_model_location = std::holds_alternative( + ep_context_gen_options.output_model_location); -Status ModelCompilationOptions::CheckOutputModelSettings() const { - const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - - const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; - - if (!explicit_writes_to_file && !writes_to_buffer) { - // User did not specify an output file or an output buffer. We default to generating an output file - // with a name based on the input file name, so do not return an error. + if (has_no_output_model_location && input_from_file) { + // User did not specify an output file, an output buffer, or an output write function. We default to generating an + // output file with a name based on the input file name, so do not return an error. return Status::OK(); } - if (explicit_writes_to_file && writes_to_buffer) { + if (has_no_output_model_location) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Output model to compile must be saved either to a file or to a buffer, but not both."); + "Unable to generate an output model path: require an input model path if the location " + "of the output model (e.g., file, buffer, or stream) is not specified."); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer(); + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer configuration for output model: buffer pointer is null"); + } + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } - return Status::OK(); -} + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + + if (output_write_func_holder != nullptr && output_write_func_holder->write_func == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer writing function for output model: function pointer is null"); + } -Status ModelCompilationOptions::Check() const { - ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); - ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - ORT_RETURN_IF_ERROR(CheckInputModelSettings()); - ORT_RETURN_IF_ERROR(CheckOutputModelSettings()); return Status::OK(); } + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 2824df863013d..45323e6cb13c5 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) #pragma once +#include #include #include #include "core/common/status.h" @@ -34,7 +35,7 @@ class ModelCompilationOptions { /// Overrides any previous call to SetInputModelPath() or SetInputModelFromBuffer(). /// /// The input model's path - void SetInputModelPath(const std::string& input_model_path); + void SetInputModelPath(const std::filesystem::path& input_model_path); /// /// Sets the buffer that stores the input ONNX model to compile. @@ -50,7 +51,7 @@ class ModelCompilationOptions { /// /// /// Status indicating potential error - Status SetOutputModelPath(const std::string& output_model_path); + Status SetOutputModelPath(const std::filesystem::path& output_model_path); /// /// Sets the file path to the file that will store external ONNX initializers for the compiled model. @@ -58,7 +59,7 @@ class ModelCompilationOptions { /// /// Path to the external initializers file to generate /// Initializers that exceed this threshold are external - void SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, + void SetOutputModelExternalInitializersFile(const std::filesystem::path& external_initializers_path, size_t external_initializer_size_threshold); /// @@ -72,6 +73,21 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets an output stream (write function + state) used to write out the compiled model bytes. + /// + /// Write function + /// The user's state + void SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + + /// + /// Sets a user-provided function to handle serialization of ONNX initializers. + /// + /// The user-provided function called for every initializer + /// The user's state. + void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. @@ -80,7 +96,8 @@ class ModelCompilationOptions { /// The folder path to the generated context binary file /// Model name used to decide the context binary file name: [model_name]_[ep].bin /// Status indicating potential error - Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + Status SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name); /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext @@ -95,7 +112,7 @@ class ModelCompilationOptions { /// /// unsigned integer set to the bitwise OR of enabled flags. /// Status indicating success or an error - Status SetFlags(size_t flags); + Status SetFlags(uint32_t flags); /// /// Returns a reference to the session options object. @@ -107,7 +124,7 @@ class ModelCompilationOptions { /// Returns the file path to the input ONNX model. /// /// input model's path - const std::string& GetInputModelPath() const; + const std::filesystem::path& GetInputModelPath() const; /// /// Returns true if the input model is read from a file. @@ -129,6 +146,13 @@ class ModelCompilationOptions { /// input model buffer's size in bytes size_t GetInputModelDataSize() const; + /// + /// Sets the graph optimization level for the underlying session that compiles the model. + /// + /// The optimization level + /// + Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); + /// /// Checks if the compilation options described by this object are valid. /// @@ -137,13 +161,10 @@ class ModelCompilationOptions { private: void ResetInputModelSettings(); - Status ResetOutputModelSettings(); - Status CheckInputModelSettings() const; - Status CheckOutputModelSettings() const; const onnxruntime::Environment& env_; OrtSessionOptions session_options_; - std::string input_model_path_; + std::filesystem::path input_model_path_; const void* input_model_data_ = nullptr; size_t input_model_data_size_ = 0; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 27f81b18be0c9..21d09df5cc4db 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,8 +38,8 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" -#include "core/session/ep_api.h" -#include "core/session/ep_library_internal.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/IOBinding.h" @@ -1094,7 +1094,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Out API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _Inout_ const OrtValue* value, _Outptr_ const void** output) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** output) { TENSOR_READ_API_BEGIN *output = tensor.DataRaw(); return nullptr; @@ -1378,9 +1378,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ cons return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out); } -char* onnxruntime::StrDup(const std::string& str, OrtAllocator* allocator) { - char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); - memcpy(output_string, str.c_str(), str.size()); +char* onnxruntime::StrDup(std::string_view str, OrtAllocator* allocator) { + char* output_string = static_cast(allocator->Alloc(allocator, str.size() + 1)); + memcpy(output_string, str.data(), str.size()); + output_string[str.size()] = '\0'; + return output_string; +} + +wchar_t* onnxruntime::StrDup(std::wstring_view str, OrtAllocator* allocator) { + auto* output_string = static_cast(allocator->Alloc(allocator, (str.size() + 1) * sizeof(wchar_t))); + memcpy(output_string, str.data(), str.size() * sizeof(wchar_t)); output_string[str.size()] = '\0'; return output_string; } @@ -2531,6 +2538,23 @@ ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExtern delete static_cast(info); } +ORT_API_STATUS_IMPL(OrtApis::CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, + _In_ int64_t file_offset, _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto ext_data_info = std::make_unique(filepath, file_offset, byte_size); + *out = ext_data_info.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalInitializerInfo() is not supported in this build."); +#endif + API_IMPL_END +} + ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) { return info->GetRelPath().c_str(); } @@ -2619,6 +2643,16 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + *out = reinterpret_cast(graph->GetModelMetadata().release()); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) { API_IMPL_BEGIN if (model_path == nullptr) { @@ -2754,7 +2788,8 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_ const OrtNode** node) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, + _Outptr_result_maybenull_ const OrtNode** node) { API_IMPL_BEGIN if (node == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node' argument is NULL"); @@ -2993,7 +3028,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_result_maybenull_ const OrtOpAttr** attribute) { API_IMPL_BEGIN if (attribute == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); @@ -3004,14 +3040,62 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); } - *attribute = ep_node->GetAttribute(attribute_name); + bool is_unset_optional_attr = false; + *attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr); - if (*attribute) { + if (*attribute || is_unset_optional_attr) { return nullptr; } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + std::ostringstream oss; + oss << "Node attribute does not exist: " << attribute_name; + return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str()); } + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); + } + + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& tensor_proto = attr_proto->t(); + // Check that TensorProto is valid. + if (!utils::HasDataType(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type."); + } + + if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type."); + } + + if (utils::HasExternalData(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Tensor proto with external data for value attribute is not supported."); + } + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + // The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file. + // Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path. + std::filesystem::path model_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + *attr_tensor = tensor_attribute_value.release(); + + return nullptr; API_IMPL_END } @@ -3052,6 +3136,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + *type = OrtOpAttrType::ORT_OP_ATTR_TENSOR; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -3181,7 +3269,7 @@ ORT_API(void, OrtApis::GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, *num_entries = kvps->Entries().size(); } -ORT_API(void, OrtApis::RemoveKeyValuePair, _Frees_ptr_opt_ OrtKeyValuePairs* kvps, _In_ const char* key) { +ORT_API(void, OrtApis::RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key) { kvps->Remove(key); } @@ -3190,7 +3278,7 @@ ORT_API(void, OrtApis::ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs* k } #if !defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path)); @@ -3198,7 +3286,7 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name) { +ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().UnregisterExecutionProviderLibrary(registration_name)); return nullptr; @@ -3384,25 +3472,86 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, API_IMPL_END } +// Validate compiled model compatibility info for specific EP device(s) +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status) { + API_IMPL_BEGIN + if (ep_devices == nullptr || num_ep_devices == 0 || compatibility_info == nullptr || out_status == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to GetModelCompatibilityForEpDevices."); + } + + // Validate inputs and ensure all devices belong to the same EP/factory + const OrtEpFactory* first_factory = nullptr; + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_devices contains a null entry."); + } + const OrtEpFactory* f = ep_devices[i]->GetMutableFactory(); + if (i == 0) { + first_factory = f; + } else if (f != first_factory) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All ep_devices must be from the same execution provider."); + } + } + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + OrtStatus* ort_status = nullptr; + OrtEpFactory* factory = ep_devices[0]->GetMutableFactory(); + if (factory && factory->ValidateCompiledModelCompatibilityInfo) { + // collect hardware devices corresponding to the ep_devices + InlinedVector hardware_devices; + hardware_devices.reserve(num_ep_devices); + for (size_t i = 0; i < num_ep_devices; ++i) { + hardware_devices.push_back(ep_devices[i]->device); + } + ort_status = factory->ValidateCompiledModelCompatibilityInfo(factory, + hardware_devices.data(), + hardware_devices.size(), + compatibility_info, + &status); + } + if (ort_status != nullptr) { + return ToOrtStatus(ToStatusAndRelease(ort_status)); + } + + *out_status = status; + return nullptr; + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RegisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, - const char* /*registration_name*/) { + _In_ const char* /*registration_name*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "UnregisterExecutionProviderLibrary is not supported in a minimal build."); API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetEpDevices, _In_ const OrtEnv* /*env*/, _Outptr_ const OrtEpDevice* const** /*ep_devices*/, _Out_ size_t* /*num_ep_devices*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetEpDevices is not supported in a minimal build."); + API_IMPL_END +} + +// Minimal build stub for GetModelCompatibilityForEpDevices to satisfy symbol references from the API table +ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, + _In_ size_t /*num_ep_devices*/, + _In_ const char* /*compatibility_info*/, + _Out_ OrtCompiledModelCompatibility* /*out_status*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetModelCompatibilityForEpDevices is not supported in a minimal build."); API_IMPL_END } @@ -3414,7 +3563,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* /*ep_option_vals*/, size_t /*num_ep_options*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionOptionsAppendExecutionProvider_V2 is not supported in a minimal build."); API_IMPL_END } @@ -3427,7 +3576,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, _In_ size_t /*num_values*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForInputs is not supported in a minimal build."); API_IMPL_END } @@ -3435,7 +3584,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_ OrtSyncStream** /*ort_stream*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateSyncStreamForEpDevice is not supported in a minimal build."); API_IMPL_END } @@ -3454,7 +3603,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, _In_opt_ OrtSyncStream* /*stream*/, _In_ size_t /*num_tensors*/) { API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CopyTensors is not supported in a minimal build."); API_IMPL_END } @@ -3522,7 +3671,7 @@ OrtStatus* GetInputOutputMemoryInfo(const OrtSession* ort_session, InlinedVector mem_info; ORT_API_RETURN_IF_STATUS_NOT_OK( - session->GetInputOutputMemoryInfo(InferenceSession::SessionInputOutputType::kInput, mem_info)); + session->GetInputOutputMemoryInfo(type, mem_info)); auto num_found = mem_info.size(); if (num_found > num_values) { @@ -4034,6 +4183,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, + &OrtApis::OpAttr_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, @@ -4066,6 +4216,10 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSyncStream, &OrtApis::CopyTensors, + + &OrtApis::Graph_GetModelMetadata, + &OrtApis::GetModelCompatibilityForEpDevices, + &OrtApis::CreateExternalInitializerInfo, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..78616c7b3973e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,7 +562,7 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path); ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); @@ -635,6 +635,14 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); +ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); + +// EP Compatibility Info APIs +ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices, + _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, + _In_ size_t num_ep_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* out_status); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); @@ -652,7 +660,7 @@ ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, _Out_writes_(num_initializers) const OrtValueInfo** initializers, _In_ size_t num_initializers); ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* num_nodes); -ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, +ORT_API_STATUS_IMPL(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, @@ -678,7 +686,9 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); @@ -690,6 +700,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may // OrtExternalInitializerInfo ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info); +ORT_API_STATUS_IMPL(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset, + _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc similarity index 99% rename from onnxruntime/core/session/ep_api.cc rename to onnxruntime/core/session/plugin_ep/ep_api.cc index 8fd1fc198374f..cae0b086af66c 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_api.h" +#include "core/session/plugin_ep/ep_api.h" #include #include diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h similarity index 100% rename from onnxruntime/core/session/ep_api.h rename to onnxruntime/core/session/plugin_ep/ep_api.h diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc new file mode 100644 index 0000000000000..7e6d0dd2ae5df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_cpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* CpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* CpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CPU EP factory currently only supports one device at a time."); + } + + CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; + *ep = std::make_unique(epi); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h new file mode 100644 index 0000000000000..fba9bac976bb2 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc new file mode 100644 index 0000000000000..2f12ffa394537 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_dml.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* DmlEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + auto ep_options = std::make_unique(); + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? + int32_t device_id = 0; // If no device_id was found default to 0 + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { + device_id = std::stoi(it->second); + } + + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. + // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); + + if (api_status != nullptr) { + return api_status; + } + + ++num_ep_devices; + } + } + + return nullptr; +} + +OrtStatus* DmlEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "DML EP factory currently only supports one device at a time."); + } + + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, + ep_options); + + *ep = dml_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* +// TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That +// requires pulling lots of things out of the DML EP to get the D3D12 device and create a +// BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp +OrtStatus* DmlEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { +} + +// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. +OrtStatus* DmlEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { +} +*/ +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.h b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h new file mode 100644 index 0000000000000..1cdd172901942 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; + +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc similarity index 58% rename from onnxruntime/core/session/ep_factory_internal.cc rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 9804aa6a5c42d..f3e30caf07e81 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api_utils.h" +#include "core/session/plugin_ep/forward_to_factory_impl.h" #include "core/session/ort_apis.h" -#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { - -using Forward = ForwardToFactory; +using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) : impl_{std::move(impl)} { @@ -25,6 +23,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; + OrtEpFactory::ValidateCompiledModelCompatibilityInfo = Forward::ValidateCompiledModelCompatibilityInfo; OrtEpFactory::CreateAllocator = Forward::CreateAllocator; OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; @@ -32,38 +31,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } -const char* EpFactoryInternal::GetVersion() const noexcept { - return ORT_VERSION; -} - -OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t /*num_devices*/, - const OrtSessionOptions* /*api_session_options*/, - const OrtLogger* /*api_logger*/, - OrtEp** /*ep*/) { - ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); -} - -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. -ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} - InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices) : ep_factory_{ep_factory} { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h similarity index 50% rename from onnxruntime/core/session/ep_factory_internal.h rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae450efa394e8..093bfce462d32 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -7,85 +7,16 @@ #include #include "core/common/common.h" -#include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { -class EpFactoryInternal; -class EpLibraryInternal; struct SessionOptions; - -// class with virtual methods that are implemented for each internal EP -class EpFactoryInternalImpl { - public: - EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) - : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { - } - - const char* GetName() const noexcept { return ep_name_.c_str(); } - const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } - const char* GetVersion() const noexcept; - - virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept = 0; - - virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ std::unique_ptr* ep) = 0; - - virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, - _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, - _Outptr_ OrtAllocator** allocator) noexcept { - // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned - // so this should never be called - *allocator = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); - } - - virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { - // we don't create any allocators so we don't need to release any - } - - virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { - *data_transfer = nullptr; - return nullptr; // Default implementation does nothing - } - - virtual bool IsStreamAware() const { - return false; - } - - virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, - _In_opt_ const OrtKeyValuePairs* /*stream_options*/, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { - *stream = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "CreateSyncStreamForDevice is not implemented for this EP factory."); - } - - // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* ep); - - virtual ~EpFactoryInternalImpl() = default; - - protected: - ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; - - private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID -}; +class EpFactoryInternalImpl; // this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. class EpFactoryInternal : public OrtEpFactory { @@ -95,7 +26,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return impl_->GetName(); } const char* GetVendor() const noexcept { return impl_->GetVendor(); } uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } - const char* GetVersion() const noexcept; + const char* GetVersion() const noexcept { return ORT_VERSION; } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, @@ -106,11 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { } // we don't implement this. CreateIExecutionProvider should be used. - OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep); + OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); + } // same input args as CreateEp in case we need something from device or ep_metadata_pairs in the future. OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -132,24 +66,30 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ReleaseAllocator(allocator); } - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { return impl_->CreateDataTransfer(data_transfer); } - bool IsStreamAware() const { + bool IsStreamAware() const noexcept { return impl_->IsStreamAware(); } OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + return impl_->ValidateCompiledModelCompatibilityInfo(devices, num_devices, compatibility_info, model_compatibility); + } + // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* /*ep*/) { + void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); } private: diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc new file mode 100644 index 0000000000000..e61804d842859 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" + +namespace onnxruntime { + +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. +ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } + } + + return ep_options; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h new file mode 100644 index 0000000000000..f29154d19c53c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +class EpFactoryInternal; +struct SessionOptions; + +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { + public: + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } + + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } + const char* GetVersion() const noexcept; + + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const noexcept { + return false; + } + + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo( + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); + ORT_UNUSED_PARAMETER(compatibility_info); + // Default implementation: mark as not applicable + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc new file mode 100644 index 0000000000000..42b65239de92c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +#include "core/providers/shared_library/provider_host_api.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +namespace onnxruntime { +OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + + // Add library path to EP metadata if available + if (library_path_.has_value()) { + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); + } + } + } + + return nullptr; +} + +OrtStatus* ProviderBridgeEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h new file mode 100644 index 0000000000000..8c5ef526baba1 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/provider_bridge_library.h" + +namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library, + std::optional library_path = std::nullopt) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library}, + library_path_{std::move(library_path)} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; + ProviderLibrary& provider_library_; + std::optional library_path_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc new file mode 100644 index 0000000000000..0f955e0bab248 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* WebGpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "WebGPU EP factory currently only supports one device at a time."); + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); + *ep = webgpu_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. +OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); +} + +OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; +} +*/ +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h new file mode 100644 index 0000000000000..06ecfa744bbda --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h similarity index 99% rename from onnxruntime/core/session/ep_library.h rename to onnxruntime/core/session/plugin_ep/ep_library.h index 24ab74e1c77fc..af5bc23143e33 100644 --- a/onnxruntime/core/session/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -23,6 +23,7 @@ class EpLibrary { virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } + virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_internal.cc b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc new file mode 100644 index 0000000000000..d4015e0bbd366 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_factory_cpu.h" +#include "core/session/plugin_ep/ep_factory_dml.h" +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +namespace onnxruntime { + +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} + +#if defined(USE_DML) + +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +#if defined(USE_WEBGPU) +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +std::vector> EpLibraryInternal::CreateInternalEps() { + std::vector> internal_eps; + internal_eps.reserve(4); + + // CPU EP + internal_eps.push_back(CreateCpuEp()); + +#if defined(USE_WEBGPU) + internal_eps.push_back(CreateWebGpuEp()); +#endif + +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + + return internal_eps; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.h b/onnxruntime/core/session/plugin_ep/ep_library_internal.h similarity index 94% rename from onnxruntime/core/session/ep_library_internal.h rename to onnxruntime/core/session/plugin_ep/ep_library_internal.h index ab529edc2507f..1587f01360e26 100644 --- a/onnxruntime/core/session/ep_library_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.h @@ -4,8 +4,8 @@ #pragma once #include "core/common/common.h" -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/provider_bridge_library.h" diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc similarity index 98% rename from onnxruntime/core/session/ep_library_plugin.cc rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.cc index 32ddd8a765b4c..ebfa364f4f1df 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_plugin.h" #include "core/common/logging/logging.h" #include "core/framework/error_code_helper.h" diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h similarity index 96% rename from onnxruntime/core/session/ep_library_plugin.h rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e2b02ccc654da..e044e91b61e37 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" namespace onnxruntime { /// diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc new file mode 100644 index 0000000000000..da94a9f12ba9d --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_provider_bridge.h" + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" +#include "core/session/plugin_ep/ep_library_plugin.h" + +namespace onnxruntime { +Status EpLibraryProviderBridge::Load() { + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can + // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + + for (const auto& factory : ep_library_plugin_->GetFactories()) { + auto factory_impl = std::make_unique(*factory, *provider_library_, library_path_); + auto internal_factory = std::make_unique(std::move(factory_impl)); + + factory_ptrs_.push_back(internal_factory.get()); + internal_factory_ptrs_.push_back(internal_factory.get()); + factories_.push_back(std::move(internal_factory)); + } + + return Status::OK(); +} + +Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. + ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + + provider_library_->Unload(); + provider_library_ = nullptr; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h similarity index 83% rename from onnxruntime/core/session/ep_library_provider_bridge.h rename to onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 0717ccd957de7..45277b2828f56 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -5,8 +5,8 @@ #include #include -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_bridge_library.h" namespace onnxruntime { @@ -21,9 +21,11 @@ namespace onnxruntime { class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, - std::unique_ptr ep_library_plugin) + std::unique_ptr ep_library_plugin, + std::optional library_path = std::nullopt) : provider_library_{std::move(provider_library)}, - ep_library_plugin_{std::move(ep_library_plugin)} { + ep_library_plugin_{std::move(ep_library_plugin)}, + library_path_{std::move(library_path)} { } const char* RegistrationName() const override { @@ -53,6 +55,9 @@ class EpLibraryProviderBridge : public EpLibrary { // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; + // Library path for EP metadata + std::optional library_path_; + std::vector> factories_; std::vector factory_ptrs_; // for convenience std::vector internal_factory_ptrs_; // for convenience diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc similarity index 92% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.cc rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c7d7ea2e8a4ec..c8829423fbe26 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include #include @@ -644,4 +644,42 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost); } } + +std::string PluginExecutionProvider::GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + if (ort_ep_->GetCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. + return Base::GetCompiledModelCompatibilityInfo(graph_viewer); + } + std::unique_ptr ep_graph = nullptr; + auto ort_status = EpGraph::Create(graph_viewer, ep_graph); + if (!ort_status.IsOK()) { + LOGS(*GetLogger(), ERROR) << "Failed to create EpGraph: " << ort_status.ToString(); + return {}; + } + // Call EP plugin's OrtEp::GenerateCompiledModelCompatibilityInfo() function. + std::string compatibility_info_string; + compatibility_info_string = ort_ep_->GetCompiledModelCompatibilityInfo(ort_ep_.get(), ep_graph.get()); + return compatibility_info_string; +} + +Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const { + if (ep_factory_.ValidateCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. + return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + // Delegate to the EP factory's validation method, passing hardware devices derived from our ep_devices_ + std::vector hardware_devices; + hardware_devices.reserve(ep_devices_.size()); + for (const auto* ep_device : ep_devices_) { + hardware_devices.push_back(ep_device->device); + } + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + hardware_devices.data(), + hardware_devices.size(), + compatibility_info.c_str(), + &model_compatibility))); + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h similarity index 94% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.h rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 728f959ad67cb..622bbb3f97b24 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -101,6 +101,11 @@ class PluginExecutionProvider : public IExecutionProvider { // needed based on matching against allocator_mem_infos_. std::vector CreatePreferredAllocators() override; + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; + + Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override; + private: struct FusedNodeState { FusedNodeState() = default; diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h similarity index 83% rename from onnxruntime/core/session/ep_api_utils.h rename to onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 77528565eced7..2cceb1d08d536 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -7,7 +7,7 @@ namespace onnxruntime { // helper to forward a call from the C API to an instance of the factory implementation. // used by EpFactoryInternal and EpFactoryProviderBridge. template -struct ForwardToFactory { +struct ForwardToFactoryImpl { static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } @@ -45,6 +45,15 @@ struct ForwardToFactory { session_options, logger, ep); } + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(devices, num_devices, + compatibility_info, model_compatibility); + } + static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryInfo* memory_info, _In_opt_ const OrtKeyValuePairs* allocator_options, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 01b70db6d940e..41cf8be1d1412 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1258,6 +1258,10 @@ struct ProviderHostImpl : ProviderHost { void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); } Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); } + bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool check_outer_scope) override { + return p->GetOrtValueInitializer(tensor_name, value, check_outer_scope); + } Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) override { return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain); } @@ -1356,6 +1360,10 @@ struct ProviderHostImpl : ProviderHost { bool check_outer_scope) const override { return p->GetConstantInitializer(name, check_outer_scope); } + bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) override { + return p->GetOrtValueInitializer(tensor_name, value); + } const Node* GraphViewer__ParentNode(const GraphViewer* p) override { return p->ParentNode(); } int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept override { return p->NumberOfNodes(); } int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept override { return p->MaxNodeIndex(); } @@ -2108,8 +2116,13 @@ std::shared_ptr NvProviderFactoryCreator::Create( return nullptr; } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { - return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); +} + +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* options) { + const auto provider_options{s_library_migraphx.Get().GetProviderOptions(options)}; + return s_library_migraphx.Get().CreateExecutionProviderFactory(&provider_options); } // Adapter to convert the legacy OrtOpenVINOProviderOptions to ProviderOptions diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -11,8 +11,8 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 18a463ef69943..48d52ae3cf428 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -101,6 +101,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, VitisAI, CoreML, NvTensorRtRtx, // TensorRt EP for RTX GPUs. + MIGraphX }; struct EpToAppend { @@ -109,7 +110,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, const char* canonical_name = nullptr; }; - static std::array supported_eps = { + static std::array supported_eps = { EpToAppend{EpID::DML, "DML", kDmlExecutionProvider}, EpToAppend{EpID::QNN, "QNN", kQnnExecutionProvider}, EpToAppend{EpID::OpenVINO, "OpenVINO", kOpenVINOExecutionProvider}, @@ -121,7 +122,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, EpToAppend{EpID::JS, "JS", kJsExecutionProvider}, EpToAppend{EpID::VitisAI, "VitisAI", kVitisAIExecutionProvider}, EpToAppend{EpID::CoreML, "CoreML", kCoreMLExecutionProvider}, - EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}}; + EpToAppend{EpID::NvTensorRtRtx, "NvTensorRtRtx", kNvTensorRTRTXExecutionProvider}, + EpToAppend{EpID::MIGraphX, "MIGraphX", kMIGraphXExecutionProvider}}; ProviderOptions provider_options; OrtStatus* status = ParseProviderOptions(provider_options_keys, @@ -279,6 +281,14 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); +#endif + break; + } + case EpID::MIGraphX: { +#if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) + options->provider_factories.push_back(MIGraphXProviderFactoryCreator::Create(provider_options)); +#else + status = create_not_supported_status(); #endif break; } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 69039beb49363..444027692903c 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -17,12 +17,13 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" #endif // !defined(ORT_MINIMAL_BUILD) @@ -135,13 +136,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file if (options && model_path == nullptr) { - EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case // the user used the older SessionOptions' configuration entries to generate a compiled model. - if (ep_ctx_gen_options.enable && - ep_ctx_gen_options.output_model_file_path.empty() && - ep_ctx_gen_options.output_model_buffer_ptr == nullptr) { + if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) { return OrtApis::CreateStatus(ORT_FAIL, "Inference session was configured with EPContext model generation enabled but " "without a valid location (e.g., file or buffer) for the output model. " @@ -206,6 +205,117 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess); } +#if !defined(ORT_MINIMAL_BUILD) +static const char* GetCompatibilityStatusString(OrtCompiledModelCompatibility status) { + switch (status) { + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + return "SUPPORTED_OPTIMAL"; + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + return "SUPPORTED_PREFER_RECOMPILATION"; + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + return "UNSUPPORTED"; + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + return "NOT_APPLICABLE"; + default: + return "UNKNOWN"; + } +} + +static Status ValidateCompiledModelCompatibility(InferenceSession& sess) { + // Get model metadata + auto [status, model_metadata] = sess.GetModelMetadata(); + if (!status.IsOK() || !model_metadata) { + // No metadata available, skip validation + return Status::OK(); + } + + const auto& custom_metadata = model_metadata->custom_metadata_map; + if (custom_metadata.empty()) { + // No custom metadata available, skip validation + return Status::OK(); + } + + // Check if user wants to fail on suboptimal models + bool fail_on_suboptimal = sess.GetSessionOptions().config_options.GetConfigEntry( + kOrtSessionOptionsFailOnSuboptimalCompiledModel) == "1"; + + const auto& registered_provider_types = sess.GetRegisteredProviderTypes(); + + // Access the execution providers through the session state (available after Initialize) + const auto& execution_providers = sess.GetSessionState().GetExecutionProviders(); + + for (const auto& ep_type : registered_provider_types) { + // Construct the full metadata key using the prefix + EP type + const std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + auto metadata_it = custom_metadata.find(metadata_key); + if (metadata_it != custom_metadata.end()) { + const std::string& compatibility_info = metadata_it->second; + + // Get the actual EP instance to call validation + const IExecutionProvider* ep = execution_providers.Get(ep_type); + + if (ep != nullptr) { + // Call the EP's validation method (virtual method with default implementation) + OrtCompiledModelCompatibility compatibility_status; + Status validation_result = ep->ValidateCompiledModelCompatibilityInfo( + compatibility_info, compatibility_status); + + if (validation_result.IsOK()) { + // Log the compatibility status + const char* status_str = GetCompatibilityStatusString(compatibility_status); + LOGS(*sess.GetLogger(), INFO) + << "EP " << ep_type << " compiled model compatibility: " << status_str; + + // Enforce compatibility based on status + switch (compatibility_status) { + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + // Continue execution + break; + + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + // Always fail for unsupported models + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is not supported by execution provider: " + ep_type); + + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + // Behavior depends on user setting + if (fail_on_suboptimal) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is suboptimal for execution provider: " + ep_type + + ". Recompilation recommended for better performance."); + } + // Otherwise continue with warning + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " reports compiled model is supported but suboptimal. " + << "Consider recompiling for better performance."; + break; + + default: + // Handle any unknown status values + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " returned unknown compatibility status: " << compatibility_status; + break; + } + } else { + // Validation failed - this should cause session initialization to fail + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to validate compiled model compatibility for EP " + ep_type + + ": " + validation_result.ErrorMessage()); + } + } + } else { + // No compatibility info found for this EP - normal for non-compiled models + LOGS(*sess.GetLogger(), VERBOSE) + << "No compiled model compatibility info found for EP " << ep_type; + } + } + + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { @@ -253,6 +363,12 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); +#if !defined(ORT_MINIMAL_BUILD) + // Validate compiled model compatibility for all registered execution providers + // This must be done after Initialize() so the session state is available + ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess)); +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } @@ -265,7 +381,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); if (model_compile_options.InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + const std::filesystem::path& input_model_path = model_compile_options.GetInputModelPath(); ORT_RETURN_IF_ERROR(ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session))); @@ -303,13 +419,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); + auto ep_library_plugin = std::make_unique(registration_name, resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { // wrap the EpLibraryPlugin with EpLibraryProviderBridge to add to directly create an IExecutionProvider auto ep_library_provider_bridge = std::make_unique(std::move(provider_library), - std::move(ep_library_plugin)); + std::move(ep_library_plugin), + resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_provider_bridge->Load()); internal_factories = ep_library_provider_bridge->GetInternalFactories(); ep_library = std::move(ep_library_provider_bridge); diff --git a/onnxruntime/core/util/shape_checker.h b/onnxruntime/core/util/shape_checker.h index 9c975275c45b9..89c20deb8f649 100644 --- a/onnxruntime/core/util/shape_checker.h +++ b/onnxruntime/core/util/shape_checker.h @@ -27,6 +27,8 @@ TensorShape make_shape(Args... args) { } \ } +#define CHECK_TENSOR_SHAPE ASSERT_TENSOR_DIMS + // This assumes the tensor is optional, and check wether its shape is expected. #define ASSERT_TENSOR_SHAPE(tensor, shape) \ if (tensor != nullptr) { \ @@ -60,4 +62,31 @@ TensorShape make_shape(Args... args) { } \ } +#define ASSERT_TENSOR_DIMENSION(tensor, dim) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const auto tensor_dimensions = tensor->Shape().NumDimensions(); \ + if (tensor_dimensions != dim) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have " #dim " dimensions, got ", \ + tensor_dimensions); \ + } \ + } + +#define ASSERT_TENSOR_DIMENSION_2_CHOICES(tensor, choice1, choice2) \ + if ((tensor) != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const auto tensor_dimensions = tensor->Shape().NumDimensions(); \ + if (tensor_dimensions != choice1 && tensor_dimensions != choice2) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, \ + "Input '" #tensor "' is expected to have " #choice1 " or ", #choice2, " dimensions, got ", \ + tensor_dimensions); \ + } \ + } + +#define ASSERT_TENSOR_2D(tensor) ASSERT_TENSOR_DIMENSION(tensor, 2) +#define ASSERT_TENSOR_3D(tensor) ASSERT_TENSOR_DIMENSION(tensor, 3) +#define ASSERT_TENSOR_2D_OR_3D(tensor) ASSERT_TENSOR_DIMENSION_2_CHOICES(tensor, 2, 3) + } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index e8e51db13bcd3..35abad5760c32 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,7 @@ import os import typing import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any from onnxruntime.capi import _pybind_state as C @@ -21,7 +21,7 @@ import onnxruntime -def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice: +def get_ort_device_type(device_type: str) -> int: if device_type == "cuda": return C.OrtDevice.cuda() elif device_type == "cann": @@ -32,8 +32,10 @@ def get_ort_device_type(device_type: str, device_index) -> C.OrtDevice: return C.OrtDevice.dml() elif device_type == "webgpu": return C.OrtDevice.webgpu() - elif device_type == "ort": - return C.get_ort_device(device_index).device_type() + elif device_type == "gpu": + return C.OrtDevice.gpu() + elif device_type == "npu": + return C.OrtDevice.npu() else: raise Exception("Unsupported device type: " + device_type) @@ -618,6 +620,36 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options, C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1]) +def make_get_initializer_location_func_wrapper( + get_initializer_location_func: GetInitializerLocationFunc, +) -> GetInitializerLocationWrapperFunc: + """ + Wraps a user's "get initializer location" function. The returned wrapper function adheres to the + signature expected by ORT. + + Need this wrapper to: + - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more + convenient for the user's function to use. + - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy) + """ + + def get_initializer_location_func_wrapper( + initializer_name: str, + initializer_value: C.OrtValue, + external_info: C.OrtExternalInitializerInfo | None, + ) -> C.OrtExternalInitializerInfo | None: + ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func( + initializer_name, OrtValue(initializer_value), external_info + ) + if ret_val is not None and ret_val == external_info: + # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be + # a new instance (that it deletes), so make a copy. + ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size) + return ret_val + + return get_initializer_location_func_wrapper + + class ModelCompiler: """ This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each @@ -645,6 +677,8 @@ def __init__( external_initializers_file_path: str | os.PathLike | None = None, external_initializers_size_threshold: int = 1024, flags: int = C.OrtCompileApiFlags.NONE, + graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL, + get_initializer_location_func: GetInitializerLocationFunc | None = None, ): """ Creates a ModelCompiler instance. @@ -661,6 +695,27 @@ def __init__( is None or empty. Initializers larger than this threshold are stored in the external initializers file. :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of flags in onnxruntime.OrtCompileApiFlags. + :param graph_optimization_level: The graph optimization level. + Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL. + :param get_initializer_location_func: Optional function called for every initializer to allow user to specify + whether an initializer should be stored within the model or externally. Example: + ``` + def get_initializer_location( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. + value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + ``` """ input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -683,6 +738,18 @@ def __init__( else: external_initializers_file_path = "" + if get_initializer_location_func is not None: + if external_initializers_file_path: + raise ValueError( + "Cannot initialize ModelCompiler with both `external_initializers_file_path` " + "and `get_initializer_location_func`" + ) + self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper( + get_initializer_location_func + ) + else: + self.get_initializer_location_func_wrapper = None + if input_model_path: self._model_compiler = C.ModelCompiler( sess_options, @@ -692,6 +759,8 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, + self.get_initializer_location_func_wrapper, ) else: self._model_compiler = C.ModelCompiler( @@ -702,6 +771,8 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, + self.get_initializer_location_func_wrapper, ) def compile_to_file(self, output_model_path: str | None = None): @@ -731,6 +802,14 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() + def compile_to_stream(self, write_function: Callable[[bytes], None]): + """ + Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function. + Raises an 'InvalidArgument' exception if the compilation options are invalid. + :param write_function: A callable that accepts a bytes buffer to write. + """ + self._model_compiler.compile_to_stream(write_function) + class IOBinding: """ @@ -765,7 +844,7 @@ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_p self._iobinding.bind_input( name, C.OrtDevice( - get_ort_device_type(device_type, device_id), + get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id, ), @@ -812,7 +891,7 @@ def bind_output( self._iobinding.bind_output( name, C.OrtDevice( - get_ort_device_type(device_type, device_id), + get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id, ), @@ -823,7 +902,7 @@ def bind_output( self._iobinding.bind_output( name, C.OrtDevice( - get_ort_device_type(device_type, device_id), + get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id, ), @@ -889,7 +968,7 @@ def _get_c_value(self) -> C.OrtValue: return self._ortvalue @classmethod - def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0) -> OrtValue: + def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device_id=0, vendor_id=-1) -> OrtValue: """ Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu @@ -897,6 +976,7 @@ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device :param numpy_obj: The Numpy object to construct the OrtValue from :param device_type: e.g. cpu, cuda, cann, cpu by default :param device_id: device id, e.g. 0 + :param vendor_id: The device's PCI vendor id. If provided, the device_type should be "gpu" or "npu". """ # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue # is backed directly by the data buffer of the numpy object and so the numpy object @@ -904,11 +984,7 @@ def ortvalue_from_numpy(cls, numpy_obj: np.ndarray, /, device_type="cpu", device return cls( C.OrtValue.ortvalue_from_numpy( numpy_obj, - C.OrtDevice( - get_ort_device_type(device_type, device_id), - C.OrtDevice.default_memory(), - device_id, - ), + OrtDevice.make(device_type, device_id, vendor_id)._get_c_device(), ), numpy_obj if device_type.lower() == "cpu" else None, ) @@ -929,7 +1005,7 @@ def ortvalue_from_numpy_with_onnx_type(cls, data: np.ndarray, /, onnx_element_ty @classmethod def ortvalue_from_shape_and_type( - cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0 + cls, shape: Sequence[int], element_type, device_type: str = "cpu", device_id: int = 0, vendor_id: int = -1 ) -> OrtValue: """ Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type @@ -938,7 +1014,11 @@ def ortvalue_from_shape_and_type( :param element_type: The data type of the elements. It can be either numpy type (like numpy.float32) or an integer for onnx type (like onnx.TensorProto.BFLOAT16). :param device_type: e.g. cpu, cuda, cann, cpu by default :param device_id: device id, e.g. 0 + :param vendor_id: If provided the device type should be "gpu" or "npu". """ + + device = OrtDevice.make(device_type, device_id, vendor_id)._get_c_device() + # Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html). # This is helpful for some data type (like TensorProto.BFLOAT16) that is not available in numpy. if isinstance(element_type, int): @@ -946,11 +1026,7 @@ def ortvalue_from_shape_and_type( C.OrtValue.ortvalue_from_shape_and_onnx_type( shape, element_type, - C.OrtDevice( - get_ort_device_type(device_type, device_id), - C.OrtDevice.default_memory(), - device_id, - ), + device, ) ) @@ -958,11 +1034,7 @@ def ortvalue_from_shape_and_type( C.OrtValue.ortvalue_from_shape_and_type( shape, element_type, - C.OrtDevice( - get_ort_device_type(device_type, device_id), - C.OrtDevice.default_memory(), - device_id, - ), + device, ) ) @@ -1085,14 +1157,27 @@ def _get_c_device(self): return self._ort_device @staticmethod - def make(ort_device_name, device_id): - return OrtDevice( - C.OrtDevice( - get_ort_device_type(ort_device_name, device_id), - C.OrtDevice.default_memory(), - device_id, + def make(ort_device_name, device_id, vendor_id=-1): + if vendor_id < 0: + # backwards compatibility with predefined OrtDevice names + return OrtDevice( + C.OrtDevice( + get_ort_device_type(ort_device_name), + C.OrtDevice.default_memory(), + device_id, + ) + ) + else: + # generic. use GPU or NPU for ort_device_name and provide a vendor id. + # vendor id of 0 is valid in some cases (e.g. webgpu is generic and does not have a vendor id) + return OrtDevice( + C.OrtDevice( + get_ort_device_type(ort_device_name), + C.OrtDevice.default_memory(), + vendor_id, + device_id, + ) ) - ) def device_id(self): return self._ort_device.device_id() @@ -1100,6 +1185,9 @@ def device_id(self): def device_type(self): return self._ort_device.device_type() + def device_vendor_id(self): + return self._ort_device.vendor_id() + class SparseTensor: """ @@ -1282,3 +1370,14 @@ def device_name(self) -> str: Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda """ return self._tensor.device_name().lower() + + +# Type hint for user-specified function that allows the user to specify initializer locations when compiling a model. +GetInitializerLocationFunc = Callable[ + [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] + +# Type hint that adheres to the signature expected by ORT. +GetInitializerLocationWrapperFunc = Callable[ + [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 8f3b97c8c7786..6b3062205b52e 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "EPFail"); pybind11::register_exception(m, "ModelLoadCanceled"); pybind11::register_exception(m, "ModelRequiresCompilation"); + pybind11::register_exception(m, "NotFound"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw ModelLoadCanceled(std::move(msg)); case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: throw ModelRequiresCompilation(std::move(msg)); + case onnxruntime::common::StatusCode::NOT_FOUND: + throw NotFound(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index 86bc4a5da8d46..7680c06c59d79 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error { struct ModelRequiresCompilation : std::runtime_error { explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} }; +struct NotFound : std::runtime_error { + explicit NotFound(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 958c9fc46bcd8..1934e0eda7956 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -99,6 +99,44 @@ TensorShape GetShape(const py::array& arr) { return shape; } +AllocatorPtr GetSharedAllocator(const OrtDevice& device) { + auto& env = GetOrtEnv()->GetEnvironment(); + + OrtMemoryInfo mem_info("ignored", OrtDeviceAllocator, device); + return env.GetRegisteredSharedAllocator(mem_info); +} + +MemCpyFunc CreateDataTransferMemCpy([[maybe_unused]] const OrtDevice& src_device, + [[maybe_unused]] const OrtDevice& dst_device) { +#if defined(ORT_MINIMAL_BUILD) + // plugin EPs are not supported in a minimal build so there won't be any data transfers registered + return nullptr; +#else + + auto& env = GetOrtEnv()->GetEnvironment(); + const DataTransferManager& data_transfer_manager = env.GetDataTransferManager(); + const IDataTransfer* data_transfer = data_transfer_manager.GetDataTransfer(src_device, dst_device); + if (!data_transfer) { + return nullptr; + } + + const auto copy_func = [src_device, dst_device, data_transfer](void* dst, const void* src, size_t bytes) { + OrtMemoryInfo src_memory_info("ignored", OrtDeviceAllocator, src_device); + OrtMemoryInfo dst_memory_info("ignored", OrtDeviceAllocator, dst_device); + + // real shape doesn't matter as the Tensor instances here are temporary in order to be able to call CopyTensor. + // we set the shape to `bytes` and the data type to uint8_t to copy the correct number of bytes. + TensorShape shape = {narrow(bytes)}; + Tensor src_tensor{DataTypeImpl::GetType(), shape, const_cast(src), src_memory_info}; + Tensor dst_tensor{DataTypeImpl::GetType(), shape, dst, dst_memory_info}; + + ORT_THROW_IF_ERROR(data_transfer->CopyTensor(src_tensor, dst_tensor)); + }; + + return copy_func; +#endif +} + void CpuToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { memcpy(dst, src, num_bytes); } @@ -158,9 +196,10 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetCudaToHostMemCpyFunction() { - static std::unordered_map map{ - {OrtDevice::GPU, CudaToCpuMemCpy}}; +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice& device) { + static std::unordered_map map{ + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy}, + }; return ↦ } @@ -207,6 +246,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } @@ -215,9 +255,10 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetMIGraphXToHostMemCpyFunction() { - static std::unordered_map map{ - {OrtDevice::GPU, MIGraphXToCpuMemCpy}}; +const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) { + static std::unordered_map map{ + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy}, + }; return ↦ } @@ -230,7 +271,11 @@ AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) { // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method - id_to_allocator_map->insert({id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(id, gpu_mem_limit, arena_extend_strategy, migx_external_allocator_info, nullptr)}); + id_to_allocator_map->insert( + {id, GetProviderInfo_MIGraphX().CreateMIGraphXAllocator( + id, gpu_mem_limit, arena_extend_strategy, + migraphx::external::alloc_fn, migraphx::external::free_fn, migraphx::external::empty_cache_fn, + nullptr)}); } return (*id_to_allocator_map)[id]; @@ -334,9 +379,10 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { D3D12_RESOURCE_STATE_UNORDERED_ACCESS); } -const std::unordered_map* GetDmlToHostMemCpyFunction() { - static std::unordered_map map{ - {OrtDevice::GPU, DmlToCpuMemCpy}}; +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice& device) { + static std::unordered_map map{ + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy}, + }; return ↦ } @@ -352,9 +398,10 @@ void CannToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_CANN().cannMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetCannToHostMemCpyFunction() { - static std::unordered_map map{ - {OrtDevice::NPU, CannToCpuMemCpy}}; +const std::unordered_map* GetCannToHostMemCpyFunction() { + static std::unordered_map map{ + {OrtDevice{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::HUAWEI, 0}, CannToCpuMemCpy}, + }; return ↦ } @@ -402,9 +449,10 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes); } -const std::unordered_map* GetRocmToHostMemCpyFunction() { - static std::unordered_map map{ - {OrtDevice::GPU, RocmToCpuMemCpy}}; +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice& device) { + static std::unordered_map map{ + {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy}, + }; return ↦ } @@ -581,7 +629,7 @@ using OrtPybindSingleUseAllocatorPtr = std::shared_ptr& p_tensor, - MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) { + const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) { CopyDataToTensor(darray, npy_type, *p_tensor, mem_cpy_to_device); } -void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device) { +void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, const MemCpyFunc& mem_cpy_to_device) { CopyDataToTensor(reinterpret_cast(py_array.ptr()), npy_type, tensor, mem_cpy_to_device); } @@ -656,7 +704,7 @@ void CopyDataToTensor(const py::array& py_array, int npy_type, Tensor& tensor, M // The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope static std::unique_ptr CreateTensor(const AllocatorPtr& alloc, const std::string& name_input, PyArrayObject* pyObject, bool use_numpy_data_memory = true, - MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) { + const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) { PyArrayObject* darray = PyArray_GETCONTIGUOUS(pyObject); ORT_ENFORCE(darray != nullptr, "The object must be a contiguous array for input '", name_input, "'."); @@ -746,7 +794,8 @@ static void CreateSequenceOfTensors(AllocatorPtr alloc, const std::string& name_ // as the backing data buffer for the ORT Tensor where applicable (for numeric tensors) // The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope static void CreateTensorMLValue(const AllocatorPtr& alloc, const std::string& name_input, PyArrayObject* pyObject, - OrtValue* p_mlvalue, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy) { + OrtValue* p_mlvalue, bool use_numpy_data_memory = true, + const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy) { auto p_tensor = CreateTensor(alloc, name_input, pyObject, use_numpy_data_memory, mem_cpy_to_device); auto ml_tensor = DataTypeImpl::GetType(); @@ -994,9 +1043,10 @@ static void CreateGenericIterableMLValue(PyObject* iterator, AllocatorPtr alloc, // Setting `use_numpy_data_memory` to `true` will ensure that the underlying numpy array buffer is directly used // as the backing data buffer for the ORT Tensor where applicable (for numeric tensors) // The numpy object owns the memory and needs to be alive until the corresponding OrtValue is in scope -void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc, const std::string& name_input, - const py::object& value, OrtValue* p_mlvalue, bool accept_only_numpy_array, - bool use_numpy_data_memory, MemCpyFunc mem_cpy_to_device) { +void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc, + const std::string& name_input, const py::object& value, OrtValue* p_mlvalue, + bool accept_only_numpy_array, bool use_numpy_data_memory, + const MemCpyFunc& mem_cpy_to_device) { onnx::TypeProto type_proto; if (PyObjectCheck_NumpyArray(value.ptr())) { // The most frequent case: input comes as an array. diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index e9bafea2ed1b5..eba783d826212 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -42,22 +42,27 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type); MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type); -using MemCpyFunc = void (*)(void*, const void*, size_t); - +using MemCpyFunc = std::function; using DataTransferAlternative = std::variant; +// helpers to get allocator and IDataTransfer from Environment for plugin EP +AllocatorPtr GetSharedAllocator(const OrtDevice& device); +MemCpyFunc CreateDataTransferMemCpy(const OrtDevice& src_device, const OrtDevice& dst_device); + void CpuToCpuMemCpy(void*, const void*, size_t); -void CopyDataToTensor(const pybind11::array& py_array, int npy_type, Tensor& tensor, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy); +void CopyDataToTensor(const pybind11::array& py_array, int npy_type, Tensor& tensor, + const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy); pybind11::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions); + const std::unordered_map* mem_cpy_to_host_functions); -pybind11::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, const DataTransferManager* data_transfer_manager); +pybind11::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, + const DataTransferManager* data_transfer_manager); pybind11::object AddNonTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions); + const std::unordered_map* mem_cpy_to_host_functions); OrtMemoryInfo GetMemoryInfoPerDeviceType(const OrtDevice& ort_device); @@ -69,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes); void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetCudaToHostMemCpyFunction(); +const std::unordered_map* GetCudaToHostMemCpyFunction(const OrtDevice&); bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); @@ -87,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes); void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetDmlToHostMemCpyFunction(); +const std::unordered_map* GetDmlToHostMemCpyFunction(const OrtDevice&); #endif @@ -97,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes); void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetMIGraphXToHostMemCpyFunction(); +const std::unordered_map* GetMIGraphXToHostMemCpyFunction(const OrtDevice&); AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id); @@ -109,7 +114,7 @@ void CpuToCannMemCpy(void* dst, const void* src, size_t num_bytes); void CannToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetCannToHostMemCpyFunction(); +const std::unordered_map* GetCannToHostMemCpyFunction(); bool IsCannDeviceIdValid(const onnxruntime::logging::Logger& logger, int id); @@ -127,17 +132,18 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes); void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes); -const std::unordered_map* GetRocmToHostMemCpyFunction(); +const std::unordered_map* GetRocmToHostMemCpyFunction(const OrtDevice&); #endif void CreateGenericMLValue(const onnxruntime::InputDefList* input_def_list, const AllocatorPtr& alloc, const std::string& name_input, const pybind11::object& value, OrtValue* p_mlvalue, - bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, MemCpyFunc mem_cpy_to_device = CpuToCpuMemCpy); + bool accept_only_numpy_array = false, bool use_numpy_data_memory = true, + const MemCpyFunc& mem_cpy_to_device = CpuToCpuMemCpy); pybind11::object GetPyObjFromTensor(const OrtValue& rtensor, const DataTransferManager* data_transfer_manager = nullptr, - const std::unordered_map* mem_cpy_to_host_functions = nullptr); + const std::unordered_map* mem_cpy_to_host_functions = nullptr); // The below two functions are used to convert OrtValue to numpy arrays diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index e2b069b01f95b..6ff252b5d1353 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" #include @@ -8,11 +9,56 @@ #include #include "core/common/common.h" #include "core/framework/error_code_helper.h" +#include "core/graph/abi_graph_types.h" #include "core/session/utils.h" namespace onnxruntime { namespace python { +/// +/// This function is called by ORT to allow the user to handle where every initializer is stored +/// (i.e., externally or internally). This function wraps (and calls) the actual Python function +/// provided by the user. +/// +/// Opaque state that holds a pointer to the user's Python function. +/// The name of the initializer to handle. +/// The OrtValue with the initializer's data, type, and shape. +/// The original external location of the initializer, if any. May be null. +/// Output parameter set to the initializer's new external location. Function may +/// return NULL if the initializer should be stored within the compiled ONNX model. +/// A status indicating success or an error. +static OrtStatus* ORT_API_CALL PyGetInitializerLocationFuncWrapper( + void* state, + const char* initializer_name, + const OrtValue* initializer_value, + const OrtExternalInitializerInfo* external_info, + /*out*/ OrtExternalInitializerInfo** new_external_info) { + PyGetInitializerLocationFunc* py_func = reinterpret_cast(state); + OrtStatus* status = nullptr; + std::shared_ptr py_new_external_info = nullptr; + + // Call the Python function and convert any exceptions to a status. + ORT_TRY { + py_new_external_info = (*py_func)(initializer_name, *initializer_value, external_info); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (py_new_external_info) { + // ORT expects to take ownership of the new external info, so make a copy because other Python code + // may be holding a reference to the `py_new_external_info`. + auto py_result_copy = std::make_unique(*py_new_external_info.get()); + *new_external_info = py_result_copy.release(); + } else { + *new_external_info = nullptr; + } + + return status; +} + onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, const PySessionOptions& sess_options, @@ -20,8 +66,11 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); + uint32_t flags, + GraphOptimizationLevel graph_optimization_level, + const PyGetInitializerLocationFunc& py_get_initializer_location_func) { + auto model_compiler = std::make_unique(env, sess_options, py_get_initializer_location_func, + PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; if (input_model_is_path) { @@ -43,6 +92,14 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrpy_get_initializer_location_func_) { + compile_options.SetOutputModelGetInitializerLocationFunc( + PyGetInitializerLocationFuncWrapper, + reinterpret_cast(&model_compiler->py_get_initializer_location_func_)); + } + out = std::move(model_compiler); return Status::OK(); } @@ -77,9 +134,47 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) return Status::OK(); } +/// +/// Function called by ORT to allow the user to write out the compiled ONNX model bytes to a custom output stream. +/// This function wraps (and calls) the actual Python function provided by the user. +/// +/// Opaque state that holds a pointer to the user's Python function. +/// The buffer to write out. Contains a portion of the compiled ONNX model's bytes. +/// The number of bytes in the buffer. +/// A status indicating success or an error. +static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, + size_t buffer_num_bytes) { + PyOutStreamWriteFunc* py_write_func = reinterpret_cast(stream_state); + OrtStatus* status = nullptr; + + // Call the Python write function and convert any exceptions to a status. + ORT_TRY { + pybind11::bytes py_bytes(reinterpret_cast(buffer), buffer_num_bytes); + (*py_write_func)(py_bytes); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + return status; +} + +onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { + model_compile_options_.SetOutputModelWriteFunc(PyOutStreamWriteFuncWrapper, + reinterpret_cast(&write_func)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(env_, model_compile_options_)); + return Status::OK(); +} + PyModelCompiler::PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag) - : env_(env), model_compile_options_(env, sess_options) { + : env_(env), + model_compile_options_(env, sess_options), + py_get_initializer_location_func_(py_get_initializer_location_func) { } } // namespace python } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index e61ae4674210b..957350accdba2 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -3,7 +3,6 @@ // Licensed under the MIT License. #pragma once -#if !defined(ORT_MINIMAL_BUILD) #include #include #include "core/common/status.h" @@ -14,11 +13,24 @@ namespace onnxruntime { class Environment; namespace python { +// Type of the function provided by Python code that is called by ORT to write out the compiled model. +using PyOutStreamWriteFunc = std::function; + +// Type of the function provided by Python code that is called by ORT to handle every initializer. +using PyGetInitializerLocationFunc = std::function( + const std::string& initializer_name, + const OrtValue& initializer_value, + const OrtExternalInitializerInfo* external_info)>; + /// /// Class exposed to Python that enables compiling ONNX models. /// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. /// class PyModelCompiler { +#if defined(ORT_MINIMAL_BUILD) + public: + bool not_defined_in_this_build{}; // Prevent empty class warning. +#else private: // private tag to pass to constructor to ensure that constructor cannot be directly called externally struct PrivateConstructorTag {}; @@ -35,9 +47,12 @@ class PyModelCompiler { /// True to embed compiled binary data into EPContext nodes. /// The file into which to store initializers for non-compiled /// nodes. - /// Flags from OrtCompileApiFlags /// Ignored if 'external_initializers_file_path' is empty. /// Initializers with a size greater than this threshold are dumped into the external file. + /// Flags from OrtCompileApiFlags + /// Optimization level for graph transformations on the model. + /// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model. + /// User's function to handle saving of initializers. /// A Status indicating error or success. static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, @@ -46,11 +61,14 @@ class PyModelCompiler { bool embed_compiled_data_into_model = false, const std::string& external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, - size_t flags = 0); + uint32_t flags = 0, + GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr); // Note: Creation should be done via Create(). This constructor is public so that it can be called from // std::make_shared(). PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag); /// @@ -70,11 +88,19 @@ class PyModelCompiler { /// A Status indicating error or success. onnxruntime::Status CompileToBytes(std::string& output_buffer); + /// + /// Compiles the input model and writes the result into the provided output stream (write functor). + /// + /// Write functor that encapsulates the stream's state. + /// A Status indicating error or success. + onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func); + private: onnxruntime::Environment& env_; onnxruntime::ModelCompilationOptions model_compile_options_; std::string input_model_bytes_; + PyGetInitializerLocationFunc py_get_initializer_location_func_; +#endif // defined(ORT_MINIMAL_BUILD) }; } // namespace python } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index d1d4d6f3cdad5..1fe7ab0884f9c 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -23,42 +23,57 @@ std::unique_ptr OrtValueFromShapeAndType(const std::vector& s MLDataType element_type, const OrtDevice& device) { AllocatorPtr allocator; + if (strcmp(GetDeviceName(device), CPU) == 0) { allocator = GetAllocator(); - } else if (strcmp(GetDeviceName(device), CUDA) == 0) { + } else { +#if !defined(ORT_MINIMAL_BUILD) + // prefer a shared allocator from the environment. + // these are provided by plugin EPs or custom allocators explicitly registered by the user. + allocator = GetSharedAllocator(device); +#endif + + if (!allocator) { + if (strcmp(GetDeviceName(device), CUDA) == 0) { #ifdef USE_CUDA - if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - allocator = GetCudaAllocator(device.Id()); + if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } + + allocator = GetCudaAllocator(device.Id()); #else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); #endif - } else if (strcmp(GetDeviceName(device), HIP) == 0) { + } else if (strcmp(GetDeviceName(device), HIP) == 0) { #if USE_ROCM - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - allocator = GetRocmAllocator(device.Id()); + if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } + + allocator = GetRocmAllocator(device.Id()); #elif USE_MIGRAPHX - allocator = GetMIGraphXAllocator(device.Id()); + allocator = GetMIGraphXAllocator(device.Id()); #else - throw std::runtime_error( - "Can't allocate memory on the AMD device using this package of OnnxRuntime. " - "Please use the ROCm package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the AMD device using this package of OnnxRuntime. " + "Please use the ROCm package of OnnxRuntime to use this feature."); #endif - } else if (strcmp(GetDeviceName(device), DML) == 0) { + } else if (strcmp(GetDeviceName(device), DML) == 0) { #if USE_DML - allocator = GetDmlAllocator(device.Id()); + allocator = GetDmlAllocator(device.Id()); #else - throw std::runtime_error( - "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " - "Please use the DirectML package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " + "Please use the DirectML package of OnnxRuntime to use this feature."); #endif - } else { - throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); + } + } + + if (!allocator) { + throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); + } } auto ml_value = std::make_unique(); @@ -90,7 +105,8 @@ void addOrtValueMethods(pybind11::module& m) { if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) { // InputDeflist is null because OrtValue creation is not tied to a specific model // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors + // in DML CreateGenericMLValue( nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); } else @@ -103,8 +119,10 @@ void addOrtValueMethods(pybind11::module& m) { // InputDeflist is null because OrtValue creation is not tied to a specific model // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA - CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCudaMemCpy); + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors + // in CUDA + CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), + true, false, CpuToCudaMemCpy); } else #endif #ifdef USE_ROCM @@ -115,22 +133,34 @@ void addOrtValueMethods(pybind11::module& m) { // InputDeflist is null because OrtValue creation is not tied to a specific model // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA - CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors + // in ROCM + CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), + true, false, CpuToRocmMemCpy); } else #endif #if USE_MIGRAPHX if (device.Vendor() == OrtDevice::VendorIds::AMD) { // InputDeflist is null because OrtValue creation is not tied to a specific model // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in MIGraphX - CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToMIGraphXMemCpy); + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors + // in MIGraphX + CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(), + true, false, CpuToMIGraphXMemCpy); } else #endif { - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + // see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP + auto allocator = GetSharedAllocator(device); + auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr; + if (cpu_to_device_copy_fn) { + CreateGenericMLValue(nullptr, allocator, "", array_on_cpu, ml_value.get(), true, false, + cpu_to_device_copy_fn); + } else { + throw std::runtime_error( + "Can't allocate memory on the device using this package of OnnxRuntime. " + "Please use the appropriate package of OnnxRuntime for your hardware to use this feature."); + } } } else if (device.Type() == OrtDevice::NPU && device.Vendor() == OrtDevice::VendorIds::HUAWEI) { #ifdef USE_CANN @@ -214,8 +244,16 @@ void addOrtValueMethods(pybind11::module& m) { } else #endif { - throw std::runtime_error( - "Unsupported GPU device: Cannot find the supported GPU device."); + // see if we can do the copy with an allocator and IDataTransfer registered by a plugin EP + auto allocator = GetSharedAllocator(device); + auto cpu_to_device_copy_fn = allocator ? CreateDataTransferMemCpy(OrtDevice{}, device) : nullptr; + if (cpu_to_device_copy_fn) { + onnxruntime::python::CopyDataToTensor(py_values, values_type, *(ml_value->GetMutable()), + cpu_to_device_copy_fn); + } else { + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); + } } } else if (device.Type() == OrtDevice::DML) { #if USE_DML @@ -383,21 +421,39 @@ void addOrtValueMethods(pybind11::module& m) { // Converts Tensor into a numpy array .def("numpy", [](const OrtValue* ml_value) -> py::object { ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects"); - + [[maybe_unused]] const auto& device = ml_value->Get().Location().device; +#ifdef _MSC_VER +// The switch statement may only contain the 'default' label. In such a case, the MSVC compiler +// will warn about it, and since the warnings are treated as errors, the compilation will break. +// Below pragmas turn off warning generation for this switch only. +#pragma warning(push) +#pragma warning(disable : 4065) +#endif + switch (device.Vendor()) { #ifdef USE_CUDA - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction()); -#elif USE_ROCM - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction()); -#elif USE_CANN - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); -#elif USE_DML - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); -#elif USE_MIGRAPHX - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); -#else - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr); + case OrtDevice::VendorIds::NVIDIA: + return GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device)); #endif - return obj; }) +#ifdef USE_CANN + case OrtDevice::VendorIds::HUAWEI: + return GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); +#endif + +#ifdef USE_DML + case OrtDevice::VendorIds::MICROSOFT: + return GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device)); +#endif +#ifdef USE_MIGRAPHX + case OrtDevice::VendorIds::AMD: + return GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device)); +#endif + default: + return GetPyObjFromTensor(*ml_value, nullptr, nullptr); + } +#ifdef _MSC_VER +#pragma warning(pop) +#endif + }) #if defined(ENABLE_DLPACK) .def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal(ToDlpack(*ort_value)); }, "Returns a DLPack representing the tensor. This method does not copy the pointer shape, " diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ec4d8c6330c8d..e370518b1fffb 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -6,11 +6,8 @@ #include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" -#include "python/onnxruntime_pybind_state_common.h" - -#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" -#endif // !defined(ORT_MINIMAL_BUILD) +#include "python/onnxruntime_pybind_state_common.h" #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API @@ -45,8 +42,9 @@ #include "core/session/lora_adapters.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/abi_graph_types.h" #include "core/session/abi_devices.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_policy_context.h" #include "core/session/utils.h" #endif @@ -205,7 +203,7 @@ void AppendLoraParametersAsInputs(const RunOptions& run_options, template static py::object AddNonTensor(const OrtValue& val, const DataTransferManager* /*data_transfer_manager*/, - const std::unordered_map* /*mem_cpy_to_host_functions*/) { + const std::unordered_map* /*mem_cpy_to_host_functions*/) { return py::cast(val.Get()); } @@ -265,39 +263,65 @@ pybind11::array PrimitiveTensorToNumpyFromDevice(const OrtValue& ort_value, cons // pretty much does what a DataTransferManager does - copy data from device(s) to the host py::object GetPyObjFromTensor(const OrtValue& ort_value, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions) { + const std::unordered_map* mem_cpy_to_host_functions) { ORT_ENFORCE(ort_value.IsTensor(), "This function only supports tensors"); const auto& tensor = ort_value.Get(); + const auto& device = tensor.Location().device; + if (tensor.IsDataTypeString()) { - ORT_ENFORCE(tensor.Location().device.Type() == OrtDevice::CPU, "Strings can only be on CPU"); + ORT_ENFORCE(device.Type() == OrtDevice::CPU, "Strings can only be on CPU"); // Create a numpy array of strings (python objects) by copy/converting them py::array result = StringTensorToNumpyArray(tensor); return py::cast(result); } - const auto device_type = tensor.Location().device.Type(); + const auto device_type = device.Type(); // Create an numpy array on top of the OrtValue memory, no copy if (device_type == OrtDevice::CPU) { py::array result = PrimitiveTensorToNumpyOverOrtValue(ort_value); return py::cast(result); } - if (!data_transfer_manager && !mem_cpy_to_host_functions) { - throw std::runtime_error( - "GetPyObjFromTensor: Either data transfer manager or a " - "function to copy data to the host is needed to convert non-CPU tensor to numpy array"); - } - py::array result; if (data_transfer_manager != nullptr) { result = PrimitiveTensorToNumpyFromDevice(ort_value, data_transfer_manager); } else { - auto mem_cpy_to_host = mem_cpy_to_host_functions->find(device_type); - ORT_ENFORCE(mem_cpy_to_host != mem_cpy_to_host_functions->end(), - "Unable to locate a function that can copy data to the host from the device"); - result = PrimitiveTensorToNumpyFromDevice(ort_value, mem_cpy_to_host->second); + bool copied = false; + if (mem_cpy_to_host_functions) { + auto it = std::find_if(mem_cpy_to_host_functions->begin(), mem_cpy_to_host_functions->end(), + [&device](const auto& entry) { + const auto& copy_device = entry.first; + // We're ignoring OrtDevice.Id() currently for historical reasons. + // The key to mem_cpy_to_host_functions was previously the device type (CPU/GPU/NPU). + // This changed to be OrtDevice to get the vendor id. + // Assumably it would be better to also match on device id, but that was not possible + // previously and to preserve existing behavior we keep the old logic and expect the + // copy function to handle the device id correctly. + return device.Type() == copy_device.Type() && + device.MemType() == copy_device.MemType() && + device.Vendor() == copy_device.Vendor(); + }); + + if (it != mem_cpy_to_host_functions->end()) { + result = PrimitiveTensorToNumpyFromDevice(ort_value, it->second); + copied = true; + } + } + + if (!copied) { + // see if we have a shared data transfer function from a plugin EP + auto device_to_cpu_copy_func = CreateDataTransferMemCpy(device, OrtDevice{}); + if (device_to_cpu_copy_func) { + result = PrimitiveTensorToNumpyFromDevice(ort_value, device_to_cpu_copy_func); + } else { + throw std::runtime_error( + "GetPyObjFromTensor: Either data transfer manager or a " + "function to copy data to the host is needed to convert non-CPU tensor to numpy array"); + } + } } + return py::cast(result); } @@ -373,7 +397,7 @@ py::object GetPyObjectFromSparseTensor(size_t pos, const OrtValue& ort_value, co template <> py::object AddNonTensor(const OrtValue& val, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions) { + const std::unordered_map* mem_cpy_to_host_functions) { const auto& seq_tensors = val.Get(); py::list py_list; for (const auto& ort_value : seq_tensors) { @@ -389,7 +413,7 @@ py::object AddNonTensor(const OrtValue& val, py::object AddNonTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions) { + const std::unordered_map* mem_cpy_to_host_functions) { // Should be in sync with core/framework/datatypes.h auto val_type = val.Type(); if (val_type->IsTensorSequenceType()) { @@ -429,7 +453,7 @@ py::object AddNonTensorAsPyObj(const OrtValue& val, } py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data_transfer_manager, - const std::unordered_map* mem_cpy_to_host_functions) { + const std::unordered_map* mem_cpy_to_host_functions) { return GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions); } @@ -953,135 +977,10 @@ static std::shared_ptr CreateExecutionProviderFactory #endif } else if (type == kMIGraphXExecutionProvider) { #if defined(USE_MIGRAPHX) || defined(USE_MIGRAPHX_PROVIDER_INTERFACE) - std::string calibration_table; - std::string save_model_path; - std::string load_model_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - params.device_id = std::stoi(option.second); - } else { - ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); - } - } else if (option.first == "migraphx_fp16_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_fp8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_enable") { - if (option.second == "True" || option.second == "true") { - params.migraphx_int8_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_int8_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_int8_calibration_table_name") { - if (!option.second.empty()) { - calibration_table = option.second; - params.migraphx_int8_calibration_table_name = calibration_table.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " - "file name i.e. 'cal_table'.\n"); - } - } else if (option.first == "migraphx_use_native_calibration_table") { - if (option.second == "True" || option.second == "true") { - params.migraphx_use_native_calibration_table = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_use_native_calibration_table = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_model_path") { - if (!option.second.empty()) { - save_model_path = option.second; - params.migraphx_save_model_path = save_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_load_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_load_model_path") { - if (!option.second.empty()) { - load_model_path = option.second; - params.migraphx_load_model_path = load_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_exhaustive_tune") { - if (option.second == "True" || option.second == "true") { - params.migraphx_exhaustive_tune = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_exhaustive_tune = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else { - ORT_THROW("Invalid MIGraphX EP option: ", option.first); - } - } if (std::shared_ptr migraphx_provider_factory = - onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { + onnxruntime::MIGraphXProviderFactoryCreator::Create(it->second)) { return migraphx_provider_factory; } } else { @@ -1674,6 +1573,17 @@ void addGlobalMethods(py::module& m) { R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", py::return_value_policy::reference); + m.def( + "get_model_compatibility_for_ep_devices", + [](const std::vector& ep_devices, + const std::string& compatibility_info) -> OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices( + ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status)); + return status; + }, + R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1858,6 +1768,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED) .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT); + py::enum_(m, "OrtCompiledModelCompatibility") + .value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) + .value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL) + .value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION) + .value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED); + py::enum_(m, "OrtAllocatorType") .value("INVALID", OrtInvalidAllocator) .value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator) @@ -1881,25 +1797,31 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra type = OrtDevice::GPU; vendor = OrtDevice::VendorIds::MICROSOFT; } else if (type == OrtDevice::GPU) { -#if USE_CUDA +#if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; +#endif + } else if (type == OrtDevice::NPU) { +#if USE_CANN + vendor = OrtDevice::VendorIds::HUAWEI; #endif } - return OrtDevice(type, mem_type, vendor, device_id); }), R"pbdoc(Constructor with vendor_id defaulted to 0 for backward compatibility.)pbdoc") .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") + // generic device types that are typically used with a vendor id. .def_static("cpu", []() { return OrtDevice::CPU; }) + .def_static("gpu", []() { return OrtDevice::GPU; }) + .def_static("npu", []() { return OrtDevice::NPU; }) + // EP specific device types for backward compatibility. .def_static("cuda", []() { return OrtDevice::GPU; }) .def_static("cann", []() { return OrtDevice::NPU; }) - .def_static("fpga", []() { return OrtDevice::FPGA; }) - .def_static("npu", []() { return OrtDevice::NPU; }) .def_static("dml", []() { return OrtDevice::DML; }) + .def_static("fpga", []() { return OrtDevice::FPGA; }) .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); @@ -2806,6 +2728,35 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + // Must use a std::shared_ptr to hold OrtExternalInitializerInfo because the same instances is passed + // between C++ and Python (and Python cannot transfer ownership to C++). + py::class_> ort_external_initializer_info_binding( + m, "OrtExternalInitializerInfo", + R"pbdoc(Location information for initializer data stored in an external file)pbdoc"); + ort_external_initializer_info_binding + .def(py::init([](const std::basic_string& filepath, int64_t file_offset, size_t byte_size) { +#if !defined(ORT_MINIMAL_BUILD) + return std::make_shared(filepath, file_offset, byte_size); +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_THROW("OrtExternalInitializerInfo creation is not supported in this build"); +#endif + })) + .def_property_readonly( + "filepath", + [](OrtExternalInitializerInfo* info) -> std::basic_string { return info->GetRelPath(); }, + R"pbdoc(The relative path to the file in which initializer data is stored.)pbdoc") + .def_property_readonly( + "file_offset", + [](OrtExternalInitializerInfo* info) -> int64_t { return info->GetOffset(); }, + R"pbdoc(The file byte offset where the initializer data is stored.)pbdoc") + .def_property_readonly( + "byte_size", + [](OrtExternalInitializerInfo* info) -> size_t { return info->GetLength(); }, + R"pbdoc(The byte size of the initializer data in the file.)pbdoc"); + py::enum_(m, "OrtCompileApiFlags", py::arithmetic()) .value("NONE", OrtCompileApiFlags_NONE) .value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) @@ -2819,7 +2770,9 @@ including arg name, arg type (contains both type and shape).)pbdoc") bool embed_compiled_data_into_model = false, std::string external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, - size_t flags = OrtCompileApiFlags_NONE) { + uint32_t flags = OrtCompileApiFlags_NONE, + GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr) { #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr result; OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, @@ -2827,7 +2780,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, - flags)); + flags, graph_optimization_level, + py_get_initializer_location_func)); return result; #else ORT_UNUSED_PARAMETER(sess_options); @@ -2837,6 +2791,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_UNUSED_PARAMETER(external_initializers_file_path); ORT_UNUSED_PARAMETER(external_initializers_size_threshold); ORT_UNUSED_PARAMETER(flags); + ORT_UNUSED_PARAMETER(graph_optimization_level); + ORT_UNUSED_PARAMETER(py_get_initializer_location_func); ORT_THROW("Compile API is not supported in this build."); #endif })) @@ -2864,7 +2820,19 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("Compile API is not supported in this build."); #endif }, - R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc") + .def( + "compile_to_stream", + [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(py_stream_write_func); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); } bool InitArray() { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.cc b/onnxruntime/python/onnxruntime_pybind_state_common.cc index 4b9e012764885..cccdb9d23900a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.cc +++ b/onnxruntime/python/onnxruntime_pybind_state_common.cc @@ -47,7 +47,11 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten #endif #ifdef USE_MIGRAPHX -onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info{}; +namespace migraphx::external { +void* alloc_fn{nullptr}; +void* free_fn{nullptr}; +void* empty_cache_fn{nullptr}; +} // namespace migraphx::external #endif #if defined(ENABLE_DLPACK) diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 706c151936192..b4a33e798f942 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -40,7 +40,7 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" @@ -226,9 +226,14 @@ extern onnxruntime::ArenaExtendStrategy arena_extend_strategy; namespace onnxruntime { ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); -namespace python { -extern onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo migx_external_allocator_info; -} // namespace python +namespace python::migraphx::external { +extern void* alloc_fn; +extern void* free_fn; +extern void* empty_cache_fn; +inline bool UseExternalAllocator() { + return alloc_fn != nullptr && free_fn != nullptr; +} +} // namespace python::migraphx::external } // namespace onnxruntime #endif diff --git a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp index ec30b8ba0985d..2550fde338cd5 100644 --- a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp +++ b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp @@ -10,21 +10,15 @@ void run_ort_trt2() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); const char* model_path = "squeezenet.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + Ort::TensorRTProviderOptions tensorrt_options; + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; @@ -127,7 +121,7 @@ void run_ort_trt2() { void run_ort_trt() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; + Ort::TensorRTProviderOptions tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); @@ -136,11 +130,7 @@ void run_ort_trt() { const char* model_path = "/data/ep-perf-models/onnx-zoo-models/squeezenet1.0-7/squeezenet/model.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a236c4da1738e..85b3632c516ca 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,7 +6,6 @@ import json import logging import os -import textwrap from pathlib import Path import numpy as np @@ -93,16 +92,10 @@ def save_processing( if separate_encoder_and_decoder_init: return - audio_processor_json = textwrap.dedent("""\ - { + audio_processor_cfg = { "feature_extraction": { "sequence": [ - { - "operation": { - "name": "audio_decoder", - "type": "AudioDecoder" - } - }, + {"operation": {"name": "audio_decoder", "type": "AudioDecoder"}}, { "operation": { "name": "STFT", @@ -511,27 +504,23 @@ def save_processing( 0.000986635684967041, 0.0005550682544708252, 0.0002467334270477295, - 0.0000616908073425293 - ] - } + 0.0000616908073425293, + ], + }, } }, { "operation": { "name": "log_mel_spectrogram", "type": "LogMelSpectrum", - "attrs": { - "chunk_size": 30, - "hop_length": 160, - "n_fft": 400, - "n_mel": 80 - } + "attrs": {"chunk_size": 30, "hop_length": 160, "n_fft": 400, "n_mel": config.num_mel_bins}, } - } + }, ] } } - """) + audio_processor_json = json.dumps(audio_processor_cfg, indent=4) + with open(os.path.join(output_dir, "audio_processor_config.json"), "w") as f: f.write(audio_processor_json) diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 287eba05a0595..e4265713d2d0a 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -33,95 +33,88 @@ struct MulKernel { return iter != float_initializers.end() ? &iter->second : nullptr; } - OrtStatus* GetInputDataAndShape(OrtKernelContext* kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - const OrtValue* input = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, index, &input)); - - OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); - - size_t num_elems = 0; - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); - - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); - - shape.resize(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, shape.data(), shape.size())); - - const void* raw_data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorData(input, &raw_data)); - - const float* float_data = static_cast(raw_data); + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); data = gsl::span(float_data, num_elems); - return nullptr; + shape = type_shape.GetShape(); } - OrtStatus* Compute(OrtKernelContext* kernel_context) { + OrtStatus* Compute(OrtKernelContext* kernel_ctx) { RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 1, input1, shape1)); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input1, shape1)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); + input0 = gsl::span(const_input0->data); + shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); + shape1 = const_input1->shape; + } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; shape1 = const_input1->shape; } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. - const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); - - input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); - shape0 = const_input0->shape; - shape1 = const_input1->shape; - } - RETURN_IF(shape0 != shape1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); - RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - OrtValue* output = nullptr; - float* output_data = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, shape0.data(), shape0.size(), &output)); - RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -183,178 +176,175 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept return ep->name_.c_str(); } -OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { - size_t num_initializers = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumInitializers(graph, &num_initializers)); - - std::vector initializers(num_initializers); - RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, initializers.data(), initializers.size())); - - for (const OrtValueInfo* initializer : initializers) { - bool is_constant = false; - RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); - - if (is_constant) { - const char* name = nullptr; - const OrtValue* value = nullptr; - OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - size_t num_elems = 0; +OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* ort_graph) { + Ort::ConstGraph graph{ort_graph}; - RETURN_IF_ERROR(ort_api.GetValueInfoName(initializer, &name)); - RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer, &value)); - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(value, &type_shape)); - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + try { + std::vector initializers = graph.GetInitializers(); - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 initializers"); + for (const auto& initializer : initializers) { + const bool is_constant = initializer.IsConstantInitializer(); - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + if (is_constant) { + auto name = initializer.GetName(); + Ort::ConstValue value; + auto status = initializer.GetInitializer(value); + if (!status.IsOK()) + return status.release(); - std::vector dims(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, dims.data(), dims.size())); + auto type_shape = value.GetTensorTypeAndShapeInfo(); + const size_t num_elems = type_shape.GetElementCount(); + const ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release(); - const float* data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(value), (void**)&data)); + std::vector dims = type_shape.GetShape(); + const float* data = value.GetTensorData(); - FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; - float_initializers_.emplace(name, std::move(ep_initializer)); + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + float_initializers_.emplace(std::move(name), std::move(ep_initializer)); + } } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, +OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, OrtEpGraphSupportInfo* graph_support_info) noexcept { - ExampleEp* ep = static_cast(this_ptr); + try { + ExampleEp* ep = static_cast(this_ptr); - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graph, &num_nodes)); + Ort::ConstGraph graph{ort_graph}; + std::vector nodes = graph.GetNodes(); + if (nodes.empty()) { + return nullptr; // No nodes to process + } - if (num_nodes == 0) { - return nullptr; // No nodes to process - } + std::vector supported_nodes; - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); - - std::vector supported_nodes; - - for (const OrtNode* node : nodes) { - const char* op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - - if (std::strncmp(op_type, "Mul", 4) == 0) { - // Check that Mul has inputs/output of type float - size_t num_inputs = 0; - size_t num_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.Node_GetNumInputs(node, &num_inputs)); - RETURN_IF_ERROR(ep->ort_api.Node_GetNumOutputs(node, &num_outputs)); - RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - - std::vector inputs(num_inputs); - std::vector outputs(num_outputs); - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - - std::array is_float = {false, false, false}; - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[0], is_float[0])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[1], is_float[1])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, outputs[0], is_float[2])); - if (!is_float[0] || !is_float[1] || !is_float[2]) { - continue; // Input or output is not of type float - } + for (const auto& node : nodes) { + auto op_type = node.GetOperatorType(); - supported_nodes.push_back(node); // Only support a single Mul for now. - break; - } - } + if (op_type != "Mul") { + // Check that Mul has inputs/output of type float + std::vector inputs = node.GetInputs(); + std::vector outputs = node.GetOutputs(); + + RETURN_IF(inputs.size() != 2 || outputs.size() != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; + std::array is_float = {false, false, false}; + IsFloatTensor(inputs[0], is_float[0]); + IsFloatTensor(inputs[1], is_float[1]); + IsFloatTensor(outputs[0], is_float[2]); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size(), &node_fusion_options)); + supported_nodes.push_back(node); // Only support a single Mul for now. + break; + } + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, +OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** ort_graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { - ExampleEp* ep = static_cast(this_ptr); - const OrtApi& ort_api = ep->ort_api; - - if (count != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); - } - - // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. - // So, this EP saves constant initializers so that they're available during inference, but an actual EP - // implementation could transfer the weights to device memory. - ep->SaveConstantInitializers(graphs[0]); - - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graphs[0], &num_nodes)); - - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], nodes.data(), nodes.size())); - - if (num_nodes != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + try { + if (count != 1) { + Ort::Status status("Expected to compile a single graph", ORT_EP_FAIL); + return status.release(); + } - const char* node_op_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &node_op_type)); + ExampleEp* ep = static_cast(this_ptr); - if (std::strncmp(node_op_type, "Mul", 4) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + Ort::ConstGraph graph{ort_graphs[0]}; - // Now we know we're compiling a single Mul node. Create a computation kernel. - std::array node_inputs = {}; - std::array node_input_names = {}; + // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. + // So, this EP saves constant initializers so that they're available during inference, but an actual EP + // implementation could transfer the weights to device memory. + ep->SaveConstantInitializers(graph); - RETURN_IF_ERROR(ort_api.Node_GetInputs(nodes[0], node_inputs.data(), node_inputs.size())); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + std::vector nodes = graph.GetNodes(); + if (nodes.size() != 1) { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - const char* ep_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); - if (std::strncmp(ep_name, "example_ep", 11) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); - } + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Mul") { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - // Associate the name of the fused node with our MulKernel. - const char* fused_node_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); + // Now we know we're compiling a single Mul node. Create a computation kernel. + std::vector node_inputs = nodes[0].GetInputs(); + std::array node_input_names; + node_input_names[0] = node_inputs[0].GetName(); + node_input_names[1] = node_inputs[1].GetName(); + + Ort::ConstNode fused_node{fused_nodes[0]}; + auto ep_name = fused_node.GetEpName(); + if (ep_name != "example_ep") { + Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); + return status.release(); + } - ep->kernels_.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + // Associate the name of the fused node with our MulKernel. + auto fused_node_name = fused_node.GetName(); + ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, ep->float_initializers_, node_input_names[0], node_input_names[1])); - // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); - node_compute_infos[0] = node_compute_info.release(); + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); - // Create EpContext nodes for the fused nodes we compiled. - if (ep->config_.enable_ep_context) { - assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), - gsl::span(ep_context_nodes, count))); + // Create EpContext nodes for the fused nodes we compiled. + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -375,69 +365,74 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // cannot currently run the EPContext model. OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { - assert(fused_nodes.size() == ep_context_nodes.size()); + try { + assert(fused_nodes.size() == ep_context_nodes.size()); - // Helper to collect input or output names from an array of OrtValueInfo instances. - auto collect_input_output_names = [&](gsl::span value_infos, - std::vector& result) -> OrtStatus* { - size_t num_values = value_infos.size(); - std::vector value_names(num_values); + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) { + std::vector value_names; + value_names.reserve(value_infos.size()); - for (size_t i = 0; i < num_values; ++i) { - const OrtValueInfo* value_info = value_infos[i]; - RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); - } + for (const auto vi : value_infos) { + value_names.push_back(vi.GetName()); + } - result = std::move(value_names); - return nullptr; - }; - - // Create an "EPContext" node for every fused node. - for (size_t i = 0; i < fused_nodes.size(); ++i) { - const OrtNode* fused_node = fused_nodes[i]; - const char* fused_node_name = nullptr; - - RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); - - size_t num_fused_node_inputs = 0; - size_t num_fused_node_outputs = 0; - RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_fused_node_inputs)); - RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_fused_node_outputs)); - - std::vector fused_node_inputs(num_fused_node_inputs); - std::vector fused_node_outputs(num_fused_node_outputs); - RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, fused_node_inputs.data(), fused_node_inputs.size())); - RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, fused_node_outputs.data(), fused_node_outputs.size())); - - std::vector input_names; - std::vector output_names; - - RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); - RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); - - int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; - - // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. - std::array attributes = {}; - DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - - std::string ep_ctx = "binary_data"; - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", ep_ctx.c_str(), static_cast(ep_ctx.length()), - ORT_OP_ATTR_STRING, &attributes[0])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, static_cast(strlen(fused_node_name)), - ORT_OP_ATTR_STRING, &attributes[4])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), static_cast(this->name_.length()), - ORT_OP_ATTR_STRING, &attributes[5])); - - RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, - input_names.data(), input_names.size(), - output_names.data(), output_names.size(), - attributes.data(), attributes.size(), - &ep_context_nodes[i])); + result = std::move(value_names); + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + Ort::ConstNode fused_node{fused_nodes[i]}; + auto fused_node_name = fused_node.GetName(); + + std::vector fused_node_inputs = fused_node.GetInputs(); + std::vector fused_node_outputs = fused_node.GetOutputs(); + + std::vector input_names; + std::vector output_names; + + collect_input_output_names(fused_node_inputs, /*out*/ input_names); + collect_input_output_names(fused_node_outputs, /*out*/ output_names); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes. + std::array attributes = {}; + std::string ep_ctx = "binary_data"; + attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), + ORT_OP_ATTR_STRING); + + attributes[1] = Ort::OpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT); + attributes[2] = Ort::OpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT); + attributes[3] = Ort::OpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING); + attributes[4] = Ort::OpAttr("partition_name", fused_node_name.data(), static_cast(fused_node_name.size()), + ORT_OP_ATTR_STRING); + + attributes[5] = Ort::OpAttr("source", this->name_.data(), static_cast(this->name_.size()), + ORT_OP_ATTR_STRING); + + std::vector c_input_names; + std::transform(input_names.begin(), input_names.end(), std::back_inserter(c_input_names), + [](const std::string& s) { return s.c_str(); }); + std::vector c_output_names; + std::transform(output_names.begin(), output_names.end(), std::back_inserter(c_output_names), + [](const std::string& s) { return s.c_str(); }); + + OrtOpAttr** op_attrs = reinterpret_cast(attributes.data()); + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name.c_str(), + c_input_names.data(), c_input_names.size(), + c_output_names.data(), c_output_names.size(), + op_attrs, attributes.size(), + &ep_context_nodes[i])); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index fa6eb24c5cc04..279925a7ec3e1 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -54,7 +54,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); - OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph); + OrtStatus* SaveConstantInitializers(const OrtGraph* graph); ExampleEpFactory& factory_; std::string name_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index b6f982a422b6a..c14bdc1b52093 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_factory.h" // To make symbols visible on macOS/iOS @@ -21,6 +25,9 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const const OrtEpApi* ep_api = ort_api->GetEpApi(); const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + // Manual init for the C++ API + Ort::InitApi(ort_api); + // Factory could use registration_name or define its own EP name. std::unique_ptr factory = std::make_unique(registration_name, ApiPtrs{*ort_api, *ep_api, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc index 549551931c647..263b4d208bd91 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -5,48 +5,33 @@ #include -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, +OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& /* ort_api */, const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val, /*out*/ std::string& config_val) { - int has_config = 0; - RETURN_IF_ERROR(ort_api.HasSessionConfigEntry(&session_options, config_key, &has_config)); - - if (has_config != 1) { - config_val = default_val; - return nullptr; + try { + Ort::ConstSessionOptions sess_opt{&session_options}; + config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); } - size_t size = 0; - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, nullptr, &size)); - - config_val.resize(size); - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, config_val.data(), &size)); - config_val.resize(size - 1); // remove the terminating '\0' - return nullptr; } -OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result) { +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { result = false; - const OrtTypeInfo* type_info = nullptr; - RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); - - ONNXType onnx_type = ONNX_TYPE_UNKNOWN; - RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(type_info, &onnx_type)); + auto type_info = value_info.TypeInfo(); + ONNXType onnx_type = type_info.GetONNXType(); if (onnx_type != ONNX_TYPE_TENSOR) { - return nullptr; + return; } - const OrtTensorTypeAndShapeInfo* type_shape = nullptr; - RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return nullptr; + return; } - result = true; - return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index 99ebee9ff64de..e8c086d38a7cb 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -107,4 +107,4 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio /*out*/ std::string& config_val); // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. -OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result); +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 77d2bb24b7d35..88b522eb10dca 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -60,66 +60,58 @@ struct DummyAllocator : OrtAllocator { // validate CreateSharedAllocator allows adding an arena to the shared allocator TEST(SharedAllocators, AddArenaToSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(example_ep.get(), OrtDeviceMemoryType_DEFAULT); + Ort::ConstEpDevice example_ep_device{example_ep.get()}; + + auto ep_memory_info = example_ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // validate there is a shared allocator - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(ep_memory_info); ASSERT_NE(allocator, nullptr); // call CreateSharedAllocator to replace with arena based allocator. arena is configured with kvps - OrtKeyValuePairs allocator_options; + Ort::KeyValuePairs allocator_options; auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); - ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT, - // allocator is internally added by EP. - // OrtArenaAllocator can only be used for the internal BFCArena - OrtDeviceAllocator, - &allocator_options, &allocator)); + allocator = ort_env->CreateSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT, + // allocator is internally added by EP. + // OrtArenaAllocator can only be used for the internal BFCArena + OrtDeviceAllocator, + allocator_options); // first allocation should init the arena to the initial chunk size - void* mem = allocator->Alloc(allocator, 16); - allocator->Free(allocator, mem); + void* mem = allocator.Alloc(16); + allocator.Free(mem); // stats should prove the arena was used - OrtKeyValuePairs* allocator_stats = nullptr; - ASSERT_ORTSTATUS_OK(allocator->GetStats(allocator, &allocator_stats)); + auto allocator_stats = allocator.GetStats(); using ::testing::Contains; using ::testing::Pair; - const auto& stats = allocator_stats->Entries(); + const auto& stats = static_cast(allocator_stats)->Entries(); EXPECT_THAT(stats, Contains(Pair("NumAllocs", "1"))); EXPECT_THAT(stats, Contains(Pair("NumArenaExtensions", "1"))); EXPECT_THAT(stats, Contains(Pair("TotalAllocated", initial_chunk_size))); // optional. ORT owns the allocator but we want to test the release implementation - ASSERT_ORTSTATUS_OK(c_api.ReleaseSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT)); + ort_env->ReleaseSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT); } TEST(SharedAllocators, GetSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); - // default CPU allocator should be available. // create a memory info with a different name to validate the shared allocator lookup ignores the name - OrtMemoryInfo* test_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.CreateMemoryInfo_V2("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, - OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator, - &test_cpu_memory_info)); + auto test_cpu_memory_info = Ort::MemoryInfo("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, + OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator); const auto get_allocator_and_check_name = [&](const std::string& expected_name) { - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, test_cpu_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(test_cpu_memory_info); ASSERT_NE(allocator, nullptr); - const OrtMemoryInfo* ort_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.AllocatorGetInfo(allocator, &ort_cpu_memory_info)); - const char* allocator_name; - ASSERT_ORTSTATUS_OK(c_api.MemoryInfoGetName(ort_cpu_memory_info, &allocator_name)); + auto ort_cpu_memory_info = allocator.GetInfo(); + auto allocator_name = ort_cpu_memory_info.GetAllocatorName(); ASSERT_EQ(expected_name, allocator_name); // Default ORT CPU allocator }; @@ -128,18 +120,16 @@ TEST(SharedAllocators, GetSharedAllocator) { // register custom allocator and make sure that is accessible by exact match DummyAllocator dummy_alloc{test_cpu_memory_info}; - c_api.RegisterAllocator(*ort_env, &dummy_alloc); + ort_env->RegisterAllocator(&dummy_alloc); // GetSharedAllocator should now match the custom allocator get_allocator_and_check_name("dummy"); // unregister custom allocator - ASSERT_ORTSTATUS_OK(c_api.UnregisterAllocator(*ort_env, test_cpu_memory_info)); + ort_env->UnregisterAllocator(test_cpu_memory_info); // there should always be a CPU allocator available get_allocator_and_check_name(onnxruntime::CPU); - - c_api.ReleaseMemoryInfo(test_cpu_memory_info); } } // namespace test diff --git a/onnxruntime/test/autoep/test_data_transfer.cc b/onnxruntime/test/autoep/test_data_transfer.cc index cc09699b754b6..71c69698ed386 100644 --- a/onnxruntime/test/autoep/test_data_transfer.cc +++ b/onnxruntime/test/autoep/test_data_transfer.cc @@ -23,16 +23,15 @@ namespace onnxruntime { namespace test { TEST(OrtEpLibrary, DataTransfer) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* ep_device = example_ep.get(); + Ort::ConstEpDevice ep_device(example_ep.get()); - const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + auto device_memory_info = ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // create a tensor using the default CPU allocator Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{2, 3, 4}; // shape doesn't matter + constexpr const std::array shape{2, 3, 4}; // shape doesn't matter const size_t num_elements = 2 * 3 * 4; RandomValueGenerator random{}; @@ -44,24 +43,21 @@ TEST(OrtEpLibrary, DataTransfer) { // create an on-device Tensor using the example EPs alternative CPU allocator. // it has a different vendor to the default ORT CPU allocator so we can copy between them even though both are // really CPU based. - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); ASSERT_NE(allocator, nullptr); Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); - std::vector src_tensor_ptrs{cpu_tensor}; - std::vector dst_tensor_ptrs{device_tensor}; + std::vector src_tensor; + src_tensor.push_back(std::move(cpu_tensor)); + std::vector dst_tensor; + dst_tensor.push_back(std::move(device_tensor)); - ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, - src_tensor_ptrs.size())); + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(src_tensor, dst_tensor, nullptr)); - const float* src_data = nullptr; - const float* dst_data = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(device_tensor, reinterpret_cast(&dst_data))); + const float* src_data = src_tensor[0].GetTensorData(); + const float* dst_data = dst_tensor[0].GetTensorData(); - size_t bytes; - ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + size_t bytes = src_tensor[0].GetTensorSizeInBytes(); ASSERT_EQ(bytes, num_elements * sizeof(float)); ASSERT_NE(src_data, dst_data) << "Should have copied between two different memory locations"; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index f1ef67e1f6ba4..0f4a654f116c4 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -54,12 +54,12 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) { TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); // Create session with example plugin EP Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); RunModelWithPluginEp(session_options); } @@ -83,7 +83,7 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); { const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); @@ -94,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); @@ -102,9 +102,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { compile_options.SetOutputModelPath(output_model_file); // Compile the model. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); - + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); // Make sure the compiled model was generated. ASSERT_TRUE(std::filesystem::exists(output_model_file)); } diff --git a/onnxruntime/test/common/string_utils_test.cc b/onnxruntime/test/common/string_utils_test.cc index 79f8ddff7b52a..983f7fa7a87f9 100644 --- a/onnxruntime/test/common/string_utils_test.cc +++ b/onnxruntime/test/common/string_utils_test.cc @@ -15,6 +15,8 @@ namespace test { namespace { template void TestSuccessfulParse(const std::string& input, const T& expected_value) { + SCOPED_TRACE(MakeString("Input: \"", input, "\", expected_value: ", expected_value)); + T value; ASSERT_TRUE(TryParseStringWithClassicLocale(input, value)); EXPECT_EQ(value, expected_value); @@ -22,6 +24,8 @@ void TestSuccessfulParse(const std::string& input, const T& expected_value) { template void TestFailedParse(const std::string& input) { + SCOPED_TRACE(MakeString("Input: \"", input, "\"")); + T value; EXPECT_FALSE(TryParseStringWithClassicLocale(input, value)); } @@ -31,6 +35,7 @@ TEST(StringUtilsTest, TryParseStringWithClassicLocale) { TestSuccessfulParse("-1", -1); TestSuccessfulParse("42", 42u); TestSuccessfulParse("2.5", 2.5f); + TestSuccessfulParse("0x100", uint32_t{0x100}); // out of range TestFailedParse("32768"); diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index c9a7116bf8052..2918e4baf86a4 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -82,21 +82,48 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con } } +struct TestDynamicQuantizeMatMulOptions { + bool is_matrix_b_constant = true; + + bool per_column = false; + + bool is_scale_constant = false; + + bool has_zp = true; + bool is_zp_constant = false; + bool is_zp_zero = false; + + bool has_bias = false; + bool is_bias_constant = false; + + bool empty_input = false; +}; + template -void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, - bool per_column = false, - bool has_zp = true, - bool has_bias = false, - bool empty_input = false) { +void TestDynamicQuantizeMatMul(const TestDynamicQuantizeMatMulOptions& opts) { + static_assert(std::is_same_v || std::is_same_v); + + SCOPED_TRACE(MakeString( + "b data type:", (std::is_same_v ? "uint8" : "int8"), + ", is_matrix_b_constant:", opts.is_matrix_b_constant, + ", per_column:", opts.per_column, + ", is_scale_constant:", opts.is_scale_constant, + ", has_zp:", opts.has_zp, + ", is_zp_constant:", opts.is_zp_constant, + ", is_zp_zero:", opts.is_zp_zero, + ", has_bias:", opts.has_bias, + ", is_bias_constant:", opts.is_bias_constant, + ", empty_input:", opts.empty_input)); + // create rand inputs RandomValueGenerator random{1668426375}; - int64_t M = empty_input ? 1 : 4; + int64_t M = opts.empty_input ? 1 : 4; int64_t N = 128; int64_t K = 128; - std::vector A_dims{empty_input ? 0 : M, K}; + std::vector A_dims{opts.empty_input ? 0 : M, K}; std::vector B_dims{K, N}; - std::vector Y_dims{empty_input ? 0 : M, K}; + std::vector Y_dims{opts.empty_input ? 0 : M, K}; std::vector A_data = random.Uniform(A_dims, -1.0f, 1.0f); std::vector B_data; std::vector tmp_B_data = random.Uniform(B_dims, @@ -106,101 +133,120 @@ void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, return static_cast(v); }); - int64_t b_scale_zp_size = per_column ? B_dims.back() : 1; + int64_t b_scale_zp_size = opts.per_column ? B_dims.back() : 1; std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); - std::for_each(B_zero_point.begin(), - B_zero_point.end(), - [&random](T& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::min(), - std::numeric_limits::max())[0]); - }); + if (!opts.is_zp_zero) { + std::for_each(B_zero_point.begin(), + B_zero_point.end(), + [&random](T& zp) { + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::min(), + std::numeric_limits::max())[0]); + }); + } std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); - test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("B", B_dims, B_data, opts.is_matrix_b_constant); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale, opts.is_scale_constant); - if (has_zp) { - test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); + if (opts.has_zp) { + test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point, opts.is_zp_constant); } else { test.AddOptionalInputEdge(); } - if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + if (opts.has_bias) { + test.AddInput("bias", {B_dims.back()}, Bias, opts.is_bias_constant); } else { test.AddOptionalInputEdge(); } std::vector Y_data(M * N); CalculateDynamicQuantizeMatMul(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data, - per_column, has_zp, has_bias); + opts.per_column, opts.has_zp, opts.has_bias); test.AddOutput("Y", Y_dims, Y_data); test.SetOutputRelErr("Y", 0.02f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } -template -void RunDynamicQuantizeMatMulTest() { - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); +template +void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, + bool per_column = false, + bool has_zp = true, + bool has_bias = false, + bool empty_input = false) { + TestDynamicQuantizeMatMulOptions opts{}; + opts.is_matrix_b_constant = is_matrix_b_constant; + opts.per_column = per_column; + opts.has_zp = has_zp; + opts.has_bias = has_bias; + opts.empty_input = empty_input; + + TestDynamicQuantizeMatMul(opts); } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +template +void RunDynamicQuantizeMatMulTest() { + for (bool is_matrix_b_constant : {false, true}) { + for (bool per_column : {false, true}) { + for (bool has_zp : {false, true}) { + for (bool has_bias : {false, true}) { + TestDynamicQuantizeMatMul(is_matrix_b_constant, + per_column, + has_zp, + has_bias); + } + } + } + } } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, Int8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, UInt8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} +TEST(DynamicQuantizeMatMul, WithConstantBInputs) { + TestDynamicQuantizeMatMulOptions base_opts{}; + base_opts.is_matrix_b_constant = true; + base_opts.is_scale_constant = true; + base_opts.is_zp_constant = true; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // no zp + auto opts = base_opts; + opts.has_zp = false; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // zp that is zero (symmetric quantization) + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = true; -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } + + { + // zp that is non-zero + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = false; + + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } } TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc old mode 100755 new mode 100644 index 334be3e03b483..4b586e24c9bd3 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -10,6 +10,7 @@ #include "core/common/common.h" #include "core/framework/execution_provider.h" +#include "test/common/cuda_op_test_utils.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -102,6 +103,7 @@ void RunGatherBlockQuantized(const std::vector& data, const std::vector& output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, bool touch_on_device_data = false) { + (void)touch_on_device_data; CheckDataAndShape(data, data_shape, "data in RunGatherBlockQuantized"); CheckDataAndShape(indices, indices_shape, "indices in RunGatherBlockQuantized"); CheckDataAndShape(scales, scales_shape, "scales in RunGatherBlockQuantized"); @@ -127,12 +129,15 @@ void RunGatherBlockQuantized(const std::vector& data, test.AddOutput("output", output_shape, output); - if (touch_on_device_data) { - // test would need to see data on device - test.Run(expect_result, "", {kWebGpuExecutionProvider}, nullptr); + bool enable_cuda = HasCudaEnvironment(0); + std::vector> eps; + if (enable_cuda) { + eps.push_back(DefaultCudaExecutionProvider()); } else { - test.Run(expect_result, ""); + eps.push_back(DefaultCpuExecutionProvider()); } + + test.Run(expect_result, "", {}, nullptr, &eps); }; run_test(false); @@ -275,6 +280,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); @@ -289,6 +295,7 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); } +#endif template void Test_Fail_WithoutZeroPoints(int64_t gather_axis, @@ -317,6 +324,7 @@ void Test_Fail_WithoutZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) { // Gather on axis other than 0 is not supported with uint8_t Test_Fail_WithoutZeroPoints(1, 2, 16); @@ -349,6 +357,7 @@ TEST(GatherBlockQuantizedOpTest, NotSupportedBits) { Test_Fail_WithZeroPoints(0, 2, 16, 6); Test_Fail_WithZeroPoints(0, 2, 16, 7); } +#endif template void Test_ShapeMismatch_WithZeroPoints() { @@ -377,11 +386,13 @@ void Test_ShapeMismatch_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, ShapeMismatch) { Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); } +#endif template void Test_InvalidIndices_WithZeroPoints() { @@ -410,11 +421,13 @@ void Test_InvalidIndices_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, InvalidIndices) { Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints(int bits = 4) { @@ -447,6 +460,7 @@ void Test_GatherAxis0_WithZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); @@ -457,6 +471,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { @@ -490,6 +505,7 @@ void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_4Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(); Test_GatherAxis0_WithZeroPoints_Uint8(); @@ -499,6 +515,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_8Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(8); Test_GatherAxis0_WithZeroPoints_Uint8(8); } +#endif template void Test_GatherAxis0_NoZeroPoints(int bits = 4) { @@ -533,6 +550,7 @@ void Test_GatherAxis0_NoZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) { Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); @@ -551,6 +569,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) { Test_GatherAxis0_NoZeroPoints(8); Test_GatherAxis0_NoZeroPoints(8); } +#endif template void Test_GatherAxis1_WithZeroPoints() { @@ -585,6 +604,7 @@ void Test_GatherAxis1_WithZeroPoints() { -2, -2, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); @@ -595,6 +615,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); } +#endif template void Test_GatherAxis2_WithZeroPoints() { @@ -629,6 +650,7 @@ void Test_GatherAxis2_WithZeroPoints() { -1, -3, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); @@ -639,6 +661,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 4c3f9e8dd4dbd..7213937d0ef11 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -345,7 +345,7 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, #if !defined(USE_OPENVINO) -TEST(MatMulNBits, Float32_Accuracy0) { +TEST(MatMulNBits, Float32_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -372,7 +372,7 @@ TEST(MatMulNBits, Float32_Accuracy0) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float32_Accuracy1) { +TEST(MatMulNBits, Float32_4b_Accuracy1) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -383,7 +383,7 @@ TEST(MatMulNBits, Float32_Accuracy1) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float32_Accuracy4) { +TEST(MatMulNBits, Float32_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -415,7 +415,7 @@ TEST(MatMulNBits, Float32_Accuracy4) { #if !defined(USE_DML) // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. -TEST(MatMulNBits, Float16_Accuracy2) { +TEST(MatMulNBits, Float16_4b_Accuracy2) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -442,7 +442,7 @@ TEST(MatMulNBits, Float16_Accuracy2) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy0) { +TEST(MatMulNBits, Float16_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -453,7 +453,7 @@ TEST(MatMulNBits, Float16_Accuracy0) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy4) { +TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -483,7 +483,7 @@ TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, LegacyShape) { +TEST(MatMulNBits, LegacyShape_4b) { constexpr bool legacy_shape = true; TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index a7df3b7bbec54..c60abbc278962 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -25,7 +25,9 @@ #include "core/session/ort_env.h" #include "core/util/qmath.h" -#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU) +#if ((defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64)) && \ + !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || \ + defined(USE_CUDA) || defined(USE_WEBGPU) extern std::unique_ptr ort_env; @@ -275,6 +277,7 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { GTEST_SKIP() << "Skipping test on Android x86_64 (emulator)."; #endif TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 42f62981cb52b..0690b8894eb7a 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -9,17 +9,19 @@ namespace onnxruntime { namespace test { +// Note: QMoE CPU implementation now always applies softmax normalization to top-k selected experts +// regardless of the normalize_routing_weights parameter value for mathematical correctness. + #ifndef ENABLE_TRAINING static void RunMoETest(const std::vector& input, const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, int inter_size, std::string activation_type, - int normalize_routing_weights = 0, int top_k = 1, bool use_float16 = false) { + int normalize_routing_weights = 1, int top_k = 1, bool use_float16 = false) { constexpr int min_cuda_arch = 700; - constexpr int max_cuda_arch = 900; - bool enable_cuda = HasCudaEnvironment(min_cuda_arch) && !NeedSkipIfCudaArchGreaterEqualThan(max_cuda_arch); + bool enable_cuda = HasCudaEnvironment(min_cuda_arch); if (enable_cuda) { OpTester tester("MoE", 1, onnxruntime::kMSDomain); tester.AddAttribute("k", static_cast(top_k)); @@ -28,8 +30,8 @@ static void RunMoETest(const std::vector& input, const std::vector std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_experts_bias_dims = {num_experts, inter_size}; std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; @@ -91,44 +93,97 @@ static void RunQMoETest(const std::vector& input, const std::vector& fc3_experts_weights, const std::vector& fc1_scales, const std::vector& fc2_scales, const std::vector& fc3_scales, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, - int inter_size, std::string activation_type, int normalize_routing_weights = 0, int top_k = 1) { + int inter_size, std::string activation_type, int normalize_routing_weights = 1, int top_k = 1, int expert_weight_bits = 4) { constexpr int min_cuda_arch = 700; - constexpr int max_cuda_arch = 900; - bool enable_cuda = HasCudaEnvironment(min_cuda_arch) && !NeedSkipIfCudaArchGreaterEqualThan(max_cuda_arch); + // Test CUDA execution provider + bool enable_cuda = HasCudaEnvironment(min_cuda_arch); if (enable_cuda) { - OpTester tester("QMoE", 1, onnxruntime::kMSDomain); - tester.AddAttribute("k", static_cast(top_k)); - tester.AddAttribute("activation_type", activation_type); - tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain); + cuda_tester.AddAttribute("k", static_cast(top_k)); + cuda_tester.AddAttribute("activation_type", activation_type); + cuda_tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_scales_dims = {num_experts, inter_size}; std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector fc3_scales_dims = fc1_scales_dims; std::vector output_dims = {num_rows, hidden_size}; - tester.AddInput("input", input_dims, ToFloat16(input)); - tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cuda_tester.AddInput("input", input_dims, ToFloat16(input)); + cuda_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); - tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); - tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); - tester.AddOptionalInputEdge(); // fc1_experts_bias - tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); - tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); - tester.AddOptionalInputEdge(); // fc2_experts_bias - tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); - tester.AddInput("fc3_scales", fc3_scales_dims, ToFloat16(fc3_scales)); - tester.AddOutput("output", output_dims, ToFloat16(output_data)); - tester.SetOutputTolerance(0.005f); + cuda_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cuda_tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + cuda_tester.AddOptionalInputEdge(); // fc1_experts_bias + cuda_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cuda_tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + cuda_tester.AddOptionalInputEdge(); // fc2_experts_bias - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + // Only add FC3 inputs if fc3_experts_weights is not empty + if (!fc3_experts_weights.empty()) { + cuda_tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + cuda_tester.AddInput("fc3_scales", fc3_scales_dims, ToFloat16(fc3_scales)); + } else { + cuda_tester.AddOptionalInputEdge(); // fc3_experts_weights + cuda_tester.AddOptionalInputEdge(); // fc3_scales + } + cuda_tester.AddOptionalInputEdge(); // fc3_experts_bias + cuda_tester.AddOutput("output", output_dims, ToFloat16(output_data)); + cuda_tester.SetOutputTolerance(0.005f); + + std::vector> cuda_execution_providers; + cuda_execution_providers.push_back(DefaultCudaExecutionProvider()); + cuda_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cuda_execution_providers); + } + + // Test CPU execution provider (always available) + // Skip CPU test if FC3 weights are provided since CPU doesn't support FC3 + if (fc3_experts_weights.empty()) { + // Ensure CPU EP is available before running CPU tests + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + return; // Skip CPU test if CPU EP is not available + } + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", static_cast(top_k)); + cpu_tester.AddAttribute("activation_type", activation_type); + cpu_tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + cpu_tester.AddAttribute("expert_weight_bits", static_cast(expert_weight_bits)); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; + std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + + // CPU doesn't support FC3, so always skip it + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU - not implemented) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float, not MLFloat16) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output_data)); + cpu_tester.SetOutputTolerance(0.01f); // Slightly higher tolerance for CPU vs CUDA differences + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); } } @@ -307,7 +362,7 @@ TEST(MoETest, MoETest_Gelu) { 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "gelu"); + output, num_rows, num_experts, hidden_size, inter_size, "gelu", 0); } TEST(MoETest, MoETest_Relu) { @@ -485,7 +540,7 @@ TEST(MoETest, MoETest_Relu) { 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "relu"); + output, num_rows, num_experts, hidden_size, inter_size, "relu", 0); } TEST(MoETest, MoETest_Mixtral) { @@ -1268,8 +1323,373 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { RunQMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, fc1_scales, fc2_scales, fc3_scales, output, num_rows, num_experts, hidden_size, inter_size, "silu", 1, /*normalize_routing_weights*/ - 2 /*top_k*/); + 2, /*top_k*/ + 4 /*expert_weight_bits*/); } + +// CPU-specific QMoE tests +TEST(MoETest, QMoETest_CPU_Int4_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + int num_rows = 2; + int num_experts = 2; + int hidden_size = 32; + int inter_size = 32; + + const std::vector input = { + -0.5f, 0.2f, 1.1f, -0.3f, 0.8f, -0.1f, 0.4f, -0.7f, 0.9f, -0.2f, 0.6f, 0.1f, -0.4f, 0.3f, -0.8f, 0.7f, + 0.2f, -0.5f, 0.1f, 0.9f, -0.3f, 0.6f, -0.1f, 0.4f, -0.7f, 0.8f, 0.3f, -0.2f, 0.5f, 0.1f, -0.6f, 0.9f, + 0.1f, 0.7f, -0.4f, 0.2f, 0.8f, -0.3f, 0.5f, -0.1f, 0.6f, 0.4f, -0.7f, 0.3f, 0.9f, -0.2f, 0.1f, 0.8f, + -0.5f, 0.6f, 0.3f, -0.1f, 0.4f, 0.7f, -0.8f, 0.2f, 0.9f, 0.1f, -0.3f, 0.5f, 0.6f, -0.4f, 0.8f, 0.2f}; + + const std::vector router_probs = {0.3f, 0.7f, 0.6f, 0.4f}; + + // Generate simple test weights for 4-bit symmetric quantization with SwiGLU + // Use 0x88 which unpacks to 8,8 -> 0,0 in signed form (8-8=0) for zero weights + // For SwiGLU: FC1 outputs 2*inter_size (gate + linear), FC2 takes inter_size input + std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 0x88); // 2*inter_size for SwiGLU, packed into /2 + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x88); // 8,8 values to produce zero output + std::vector fc3_experts_weights; // Empty for CPU (FC3 not supported) + + std::vector fc1_scales(num_experts * inter_size * 2, 0.01f); // 2x for SwiGLU (gate + linear) + std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor + std::vector fc3_scales; + + // With zero weights (0x88 -> 8,8 -> 0,0 signed), the implementation will produce all zero outputs + std::vector output(num_rows * hidden_size, 0.0f); + + // Test CPU execution provider ONLY (don't use RunQMoETest which tests both CUDA and CPU) + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; // SwiGLU: 2*inter_size output, 4-bit packed + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + + // When using 0x88 for 4-bit quantized weights with the current implementation, + // all dequantized values should be 0.0f (8-8=0), and thus output should be all zeros + std::vector expected_output(num_rows * hidden_size, 0.0f); + + cpu_tester.AddOutput("output", output_dims, ToFloat16(expected_output)); + cpu_tester.SetOutputTolerance(0.05f); // Small tolerance for numerical differences + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_Int8_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test CPU implementation with 8-bit quantization - CPU ONLY + int num_rows = 1; + int num_experts = 2; + int hidden_size = 16; + int inter_size = 16; + + const std::vector input = { + 0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f, 0.9f, -1.0f, 1.1f, -1.2f, 1.3f, -1.4f, 1.5f, -1.6f}; + + const std::vector router_probs = {0.4f, 0.6f}; + + // For 8-bit symmetric quantization with SwiGLU + // Use quantized weights at zero point for zero outputs (128 = 0 in signed) + std::vector fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128); // 2*inter_size for SwiGLU, no packing for 8-bit + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 128); // 128 = 0 in signed + std::vector fc3_experts_weights; // Empty for CPU + + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); // 2x for SwiGLU + std::vector fc2_scales(num_experts * hidden_size, 0.1f); + std::vector fc3_scales; + + // Expected output should be zero since we're using zero weights (128-128=0) + std::vector output(num_rows * hidden_size, 0.0f); + + // Test with different attributes for 8-bit + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 1); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // SwiGLU: 2*inter_size output, 8-bit no packing + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float, not MLFloat16) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.05f); // Small tolerance since we expect near-zero output + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_FC3_Error) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test that CPU throws error when FC3 gating is provided - CPU ONLY + int num_rows = 1; + int num_experts = 2; + int hidden_size = 8; + int inter_size = 8; + + const std::vector input = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f}; + const std::vector router_probs = {0.5f, 0.5f}; + + // Using new layout: fc1 has fused swiglu doubling (2*inter_size) and 4-bit pack_size=2 so hidden_size packed dimension is hidden_size/2 + const int pack_size = 2; // for 4-bit + const int fc1_inter_size = 2 * inter_size; // swiglu fused + std::vector fc1_experts_weights(num_experts * fc1_inter_size * (hidden_size / pack_size), 0x01); + std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0x10); + std::vector fc3_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0x21); // FC3 provided + + std::vector fc1_scales(num_experts * fc1_inter_size, 0.1f); + std::vector fc2_scales(num_experts * hidden_size, 0.05f); + std::vector fc3_scales(num_experts * inter_size, 0.08f); // FC3 scales provided + + // Test CPU execution provider ONLY (designed to test CPU-specific error handling) + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 1); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Use 1 for consistency, though this test focuses on FC3 error + cpu_tester.AddAttribute("expert_weight_bits", 4); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, fc1_inter_size, hidden_size / pack_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + std::vector fc3_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + std::vector fc1_scales_dims = {num_experts, fc1_inter_size}; + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector fc3_scales_dims = {num_experts, inter_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); // FC3 provided! + cpu_tester.AddInput("fc3_scales", fc3_scales_dims, fc3_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + + std::vector dummy_output(num_rows * hidden_size, 0.0f); + cpu_tester.AddOutput("output", output_dims, ToFloat16(dummy_output)); + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + + // Expect this to fail with FC3 not implemented error + cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test CPU implementation with 4-bit quantization and SwiGLU activation + int num_rows = 2; + int num_experts = 2; + int hidden_size = 16; + int inter_size = 16; + + const std::vector input = { + 0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f, 0.9f, -1.0f, 1.1f, -1.2f, 1.3f, -1.4f, 1.5f, -1.6f, + 0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f, 1.0f, -1.1f, 1.2f, -1.3f, 1.4f, -1.5f, 1.6f, -1.7f}; + + const std::vector router_probs = {0.6f, 0.4f, 0.3f, 0.7f}; + + // For SwiGLU, FC1 weights need to be 2x inter_size (concatenated linear + gate weights) + // 4-bit: each uint8 stores 2 weights, so we need (hidden_size * inter_size * 2) / 2 uint8s per expert + const int fc1_weight_size_per_expert = hidden_size * inter_size * 2 / 2; // For 4-bit SwiGLU + const int fc2_weight_size_per_expert = inter_size * hidden_size / 2; // For 4-bit FC2 + + // Generate test weights for symmetric quantization (zero point is 8 for 4-bit) + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc3_experts_weights; // Empty for SwiGLU (gate weights concatenated with FC1) + + // Scales: for SwiGLU, FC1 has 2*inter_size outputs (linear + gate) + std::vector fc1_scales(num_experts * inter_size * 2, 0.05f); // Small scale for reasonable outputs + std::vector fc2_scales(num_experts * hidden_size, 0.05f); + std::vector fc3_scales; + + // For SwiGLU with zero weights (0x88 -> 8,8 -> 0,0 signed): + // Gate output = 0, Linear output = 0 + // SwiGLU = gate * sigmoid(gate) * (linear + 1) = 0 * sigmoid(0) * (0 + 1) = 0 * 0.5 * 1 = 0 + // So output should be zero + std::vector output(num_rows * hidden_size, 0.0f); + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // Test SwiGLU activation + cpu_tester.AddAttribute("normalize_routing_weights", 1); + cpu_tester.AddAttribute("expert_weight_bits", 4); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU (linear + gate) + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (empty for SwiGLU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.02f); // Higher tolerance for SwiGLU nonlinearity + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 8-bit quantization and SwiGLU activation + int num_rows = 1; + int num_experts = 2; + int hidden_size = 8; + int inter_size = 8; + + const std::vector input = {0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f}; + const std::vector router_probs = {0.0f, 0.0f}; + + // For SwiGLU with 8-bit symmetric quantization: FC1 weights are 2x inter_size (concatenated linear + gate weights) + const int fc1_weight_size_per_expert = hidden_size * inter_size * 2; // For 8-bit SwiGLU + const int fc2_weight_size_per_expert = inter_size * hidden_size; // For 8-bit FC2 + + // Generate test weights at zero (for symmetric quantization storage format: uint8 with zero point 128) + // Fill with 128 so dequantized value (val - 128) == 0 => zero output + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); + std::vector fc3_experts_weights; // Empty for SwiGLU + + // Scales: for SwiGLU, FC1 has 2*inter_size outputs + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); + std::vector fc2_scales(num_experts * hidden_size, 0.1f); + std::vector fc3_scales; + + std::vector output(num_rows * hidden_size, 0.0f); + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // Test SwiGLU activation + cpu_tester.AddAttribute("normalize_routing_weights", 1); + cpu_tester.AddAttribute("expert_weight_bits", 8); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // 8-bit SwiGLU: explicit 2x + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights + cpu_tester.AddOptionalInputEdge(); // fc3_scales + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.02f); + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + #endif } // namespace test diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 4e8d1b9f016f0..df83815cc29ea 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -146,15 +146,10 @@ static void RunOneTest( execution_providers.push_back(DefaultRocmExecutionProvider()); } else { if (strict) { - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_skip_layer_norm_strict_mode"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); - execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(rel_cuda_options.get()))); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options = {{"enable_skip_layer_norm_strict_mode", "1"}}; + cuda_options.Update(options); + execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(cuda_options))); } else { execution_providers.push_back(DefaultCudaExecutionProvider()); } diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..7e6d157799d86 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,72 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, GetAttributeByName) { + // Load model with a single Conv that has no explicit attributes set. + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // + // Pre-check + // + + // Original Conv has no explicit attributes but Graph::Resolve() fills in default values for + // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not + // have statically computable default values, so will not be filled in by Graph::Resolve(). + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + Ort::ConstGraph graph{&ort_graph}; + + auto nodes = graph.GetNodes(); + ASSERT_EQ(nodes.size(), 1); + + auto conv_node = nodes[0]; + auto op_type = conv_node.GetOperatorType(); + ASSERT_EQ(op_type, "Conv"); + + auto attrs = conv_node.GetAttributes(); + ASSERT_EQ(attrs.size(), 2); + + for (const auto& attr : attrs) { + auto attr_name = attr.GetName(); + ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set + } + + // + // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. + // + { + Ort::ConstOpAttr attr; + auto status = conv_node.GetAttributeByName("dilations", attr); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. + // + { + Ort::ConstOpAttr attr; + Ort::Status status = conv_node.GetAttributeByName("_does_not_exist_", attr); + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 3: Get attribute that is known to be set. + // + { + Ort::ConstOpAttr attr; + ASSERT_ORTSTATUS_OK(conv_node.GetAttributeByName("auto_pad", attr)); + ASSERT_NE(attr, nullptr); + + OrtOpAttrType type = attr.GetType(); + ASSERT_EQ(ORT_OP_ATTR_STRING, type); + std::string auto_pad_val; + ASSERT_ORTSTATUS_OK(attr.GetValue(auto_pad_val)); + ASSERT_EQ(auto_pad_val, "NOTSET"); + } +} + // Check correctness of an OrtGraph that has external initializers. TEST(EpGraphTest, CheckModelExternalInitializers) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); @@ -143,14 +209,10 @@ TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { std::string ext_ini_file_path = "conv_qdq_ext_ini_serialized.bin"; std::filesystem::remove(ext_ini_file_path); std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* /* value_info */, const void* data, size_t bytes, bool& is_external, std::string& location, int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - if (bytes <= 127) { is_external = false; // Keep small initializers stored inside the TensorProto. return Ort::Status{nullptr}; @@ -220,6 +282,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& outpu output_data.assign(output_values, output_values + num_output_elems); } +static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {3}; + std::vector input_data = {2, 3, 4}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'x' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("x"); + + // Run session and get outputs + std::array output_names{"y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 24); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. // Checks that the outputs of the serialized and original models are identical. TEST(EpGraphTest, SerializeToProto_Mnist) { @@ -323,13 +418,13 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } for (size_t i = 0; i < api_num_initializers; ++i) { - const OrtValue* ort_value = nullptr; - const void* ort_value_data = nullptr; - const char* value_name = nullptr; + std::string value_name; + Ort::ConstValue ort_value; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_initializers[i], &value_name)); - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_initializers[i], &ort_value)); - ASSERT_ORTSTATUS_OK(ort_api.GetTensorData(ort_value, &ort_value_data)); + Ort::ConstValueInfo vi(api_initializers[i]); + value_name = vi.GetName(); + ASSERT_ORTSTATUS_OK(vi.GetInitializer(ort_value)); + const void* ort_value_data = ort_value.GetTensorRawData(); auto iter = tensor_proto_map.find(value_name); ASSERT_NE(iter, tensor_proto_map.end()); @@ -350,6 +445,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } } +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "constant_of_shape_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + static_cast(value_info); + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConstantOfShapeModel(original_model_path, output_original); + RunConstantOfShapeModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -545,25 +699,21 @@ static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtVa static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, const ONNX_NAMESPACE::TensorProto* tensor_proto, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(api_value_info); + std::string api_initializer_name = vi.GetName(); // Check external initializer info (if any). - OrtExternalInitializerInfo* api_ext_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetExternalInitializerInfo(api_value_info, &api_ext_info)); - DeferOrtRelease defer_release_info(&api_ext_info, ort_api.ReleaseExternalInitializerInfo); + Ort::ExternalInitializerInfo api_ext_info{nullptr}; + auto external_status = vi.GetExternalInitializerInfo(api_ext_info); std::unique_ptr ext_info = nullptr; bool has_ext_info = graph_viewer.GetGraph().GetExternalInitializerInfo(api_initializer_name, ext_info, true); if (has_ext_info) { ASSERT_NE(api_ext_info, nullptr); - const ORTCHAR_T* api_ext_file_path = ort_api.ExternalInitializerInfo_GetFilePath(api_ext_info); - int64_t api_ext_file_offset = ort_api.ExternalInitializerInfo_GetFileOffset(api_ext_info); - size_t api_ext_byte_size = ort_api.ExternalInitializerInfo_GetByteSize(api_ext_info); + const std::basic_string api_ext_file_path = api_ext_info.GetFilePath(); + int64_t api_ext_file_offset = api_ext_info.GetFileOffset(); + size_t api_ext_byte_size = api_ext_info.GetByteSize(); ASSERT_EQ(PathString(api_ext_file_path), ext_info->GetRelPath()); ASSERT_EQ(api_ext_file_offset, static_cast(ext_info->GetOffset())); @@ -573,61 +723,49 @@ static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, ASSERT_FALSE(utils::HasExternalDataInFile(*tensor_proto)); } - const OrtValue* api_initializer_value = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + Ort::ConstValue api_initializer_value; + ASSERT_ORTSTATUS_OK(vi.GetInitializer(api_initializer_value)); ASSERT_NE(api_initializer_value, nullptr); // Check initializer type. const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); auto type_info = OrtTypeInfo::FromTypeProto(type_proto); - const OrtTypeInfo* api_type_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(api_value_info, &api_type_info)); + Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, type_info.get()); } -static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, +static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, const InitializedTensorSet& initializer_tensor_protos, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - for (size_t i = 0; i < initializer_value_infos.size(); i++) { - const OrtValueInfo* api_value_info = initializer_value_infos[i]; - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(initializer_value_infos[i]); + std::string api_initializer_name = vi.GetName(); auto tensor_proto_iter = initializer_tensor_protos.find(api_initializer_name); ASSERT_NE(tensor_proto_iter, initializer_tensor_protos.end()); const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; ASSERT_NE(tensor_proto, nullptr); - - CheckInitializerValueInfo(api_value_info, tensor_proto, graph_viewer); + CheckInitializerValueInfo(vi, tensor_proto, graph_viewer); } } // Checks that the OrtValueInfos obtained from the public C API are "equivalent" to the NodeArgs // in the original graph. -static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, +static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, gsl::span node_args) { ASSERT_EQ(value_infos.size(), node_args.size()); - const OrtApi& ort_api = Ort::GetApi(); const auto& graph_viewer_inputs = graph_viewer.GetInputsIncludingInitializers(); const auto& graph_viewer_outputs = graph_viewer.GetOutputs(); for (size_t i = 0; i < value_infos.size(); i++) { const NodeArg* node_arg = node_args[i]; - const OrtValueInfo* value_info = value_infos[i]; + Ort::ConstValueInfo vi(value_infos[i]); if (node_arg->Exists()) { const auto& value_name = node_arg->Name(); - - ASSERT_NE(value_info, nullptr); - - const char* api_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(value_info, &api_name)); + std::string api_name = vi.GetName(); ASSERT_EQ(std::string(api_name), value_name); bool is_graph_input = std::any_of(graph_viewer_inputs.begin(), graph_viewer_inputs.end(), @@ -647,64 +785,52 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanName()); - bool api_is_outer_scope = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsFromOuterScope(value_info, &api_is_outer_scope)); + bool api_is_outer_scope = vi.IsFromOuterScope(); ASSERT_EQ(api_is_outer_scope, is_outer_scope); - bool api_is_const_initializer = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsConstantInitializer(value_info, &api_is_const_initializer)); + bool api_is_const_initializer = vi.IsConstantInitializer(); ASSERT_EQ(api_is_const_initializer, is_const_initializer); if (is_const_initializer || api_is_opt_graph_input) { - CheckInitializerValueInfo(value_info, initializer, graph_viewer); + CheckInitializerValueInfo(vi, initializer, graph_viewer); } else { auto node_arg_type_info = OrtTypeInfo::FromTypeProto(*node_arg->TypeAsProto()); - const OrtTypeInfo* api_type_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(value_info, &api_type_info)); + Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, node_arg_type_info.get()); } - CheckValueInfoProducer(graph_viewer, value_info, node_arg); - CheckValueInfoConsumers(graph_viewer, value_info, node_arg); + CheckValueInfoProducer(graph_viewer, vi, node_arg); + CheckValueInfoConsumers(graph_viewer, vi, node_arg); } else { - ASSERT_EQ(value_info, nullptr); // A missing optional input has a null OrtValueInfo. + ASSERT_EQ(vi, nullptr); // A missing optional input has a null OrtValueInfo. } } } // Checks the Graph_GetSubgraph C API static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + Ort::ConstGraph ort_graph{&api_graph}; // Get all the nodes - size_t num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + std::vector nodes = ort_graph.GetNodes(); // Select a half of nodes to create a OrtGraph size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); - std::vector selected_nodes(num_selected_nodes); + std::vector selected_nodes(num_selected_nodes); for (size_t i = 0; i < num_selected_nodes; i++) { selected_nodes[i] = nodes[i]; } - OrtGraph* sub_graph; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. @@ -714,68 +840,65 @@ static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - const char* graph_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + auto graph_name = ort_graph.GetName(); std::string name = graph_name; name += "_half.onnx"; // Dump the graph for debugging // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); // model_proto->SerializeToOstream(&dump); - - ort_api.ReleaseGraph(sub_graph); } // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + auto ort_cxx_graph = Ort::ConstGraph(&api_graph); // Check the path to model. const std::filesystem::path& model_path = graph_viewer.ModelPath(); - const ORTCHAR_T* api_model_path = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path)); + const auto api_model_path = ort_cxx_graph.GetModelPath(); ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str())); - + // Check the model metadata + Ort::AllocatorWithDefaultOptions default_allocator; + auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata(); + auto& model = graph_viewer.GetGraph().GetModel(); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphNameAllocated(default_allocator).get(), model.MainGraph().Name().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDomainAllocated(default_allocator).get(), model.Domain().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDescriptionAllocated(default_allocator).get(), model.DocString().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphDescriptionAllocated(default_allocator).get(), model.GraphDocString().c_str()), 0); + ASSERT_EQ(ort_cxx_model_metadat.GetVersion(), model.ModelVersion()); + auto model_meta_data = model.MetaData(); + for (auto& [k, v] : model_meta_data) { + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.LookupCustomMetadataMapAllocated(k.c_str(), default_allocator).get(), v.c_str()), 0) + << " key=" << k << "; value=" << v; + } // Check graph inputs. const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); - size_t api_num_graph_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInputs(&api_graph, &api_num_graph_inputs)); - ASSERT_EQ(api_num_graph_inputs, graph_input_node_args.size()); + std::vector api_graph_inputs = ort_cxx_graph.GetInputs(); + ASSERT_EQ(api_graph_inputs.size(), graph_input_node_args.size()); - std::vector api_graph_inputs(api_num_graph_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, api_graph_inputs.data(), api_graph_inputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_inputs, graph_input_node_args); // Check graph outputs. const auto& graph_output_node_args = graph_viewer.GetOutputs(); - size_t api_num_graph_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumOutputs(&api_graph, &api_num_graph_outputs)); - ASSERT_EQ(api_num_graph_outputs, graph_output_node_args.size()); + std::vector api_graph_outputs = ort_cxx_graph.GetOutputs(); + ASSERT_EQ(api_graph_outputs.size(), graph_output_node_args.size()); - std::vector api_graph_outputs(api_num_graph_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, api_graph_outputs.data(), api_graph_outputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_outputs, graph_output_node_args); // Check graph initializers const auto& graph_initializers = graph_viewer.GetAllInitializedTensors(); - size_t api_num_initializers = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&api_graph, &api_num_initializers)); - ASSERT_EQ(api_num_initializers, graph_initializers.size()); - - std::vector api_initializers(api_num_initializers); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); + std::vector api_initializers = ort_cxx_graph.GetInitializers(); + ASSERT_EQ(api_initializers.size(), graph_initializers.size()); CheckInitializerValueInfosCApi(api_initializers, graph_initializers, graph_viewer); // Check if it has a parent node. const Node* parent_node = graph_viewer.ParentNode(); const bool has_parent_node = parent_node != nullptr; - const OrtNode* api_parent_node = nullptr; - - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetParentNode(&api_graph, &api_parent_node)); + Ort::ConstNode api_parent_node = ort_cxx_graph.GetParentNode(); const bool api_has_parent_node = api_parent_node != nullptr; ASSERT_EQ(api_has_parent_node, has_parent_node); @@ -784,79 +907,56 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check all nodes. - size_t api_num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &api_num_nodes)); - ASSERT_EQ(api_num_nodes, graph_viewer.NumberOfNodes()); - - std::vector api_nodes(api_num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, api_nodes.data(), api_nodes.size())); + std::vector api_nodes = ort_cxx_graph.GetNodes(); + ASSERT_EQ(api_nodes.size(), graph_viewer.NumberOfNodes()); std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); - for (size_t node_idx = 0; node_idx < api_num_nodes; node_idx++) { + for (size_t node_idx = 0; node_idx < api_nodes.size(); node_idx++) { // Check basic node properties. const Node* node = graph_viewer.GetNode(node_indices[node_idx]); - const OrtNode* api_node = api_nodes[node_idx]; + Ort::ConstNode api_node = api_nodes[node_idx]; CheckNode(node, api_node); - int api_since_version = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSinceVersion(api_node, &api_since_version)); + const int api_since_version = api_node.GetSinceVersion(); ASSERT_EQ(api_since_version, node->SinceVersion()); // Check node inputs const auto input_node_args = node->InputDefs(); - size_t api_node_num_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumInputs(api_node, &api_node_num_inputs)); - ASSERT_EQ(api_node_num_inputs, input_node_args.size()); - - std::vector api_node_inputs(api_node_num_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, api_node_inputs.data(), api_node_inputs.size())); + std::vector api_node_inputs = api_node.GetInputs(); + ASSERT_EQ(api_node_inputs.size(), input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_inputs, input_node_args); // Check node outputs const auto output_node_args = node->OutputDefs(); - size_t api_node_num_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumOutputs(api_node, &api_node_num_outputs)); - ASSERT_EQ(api_node_num_outputs, output_node_args.size()); - - std::vector api_node_outputs(api_node_num_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, api_node_outputs.data(), api_node_outputs.size())); + std::vector api_node_outputs = api_node.GetOutputs(); + ASSERT_EQ(api_node_outputs.size(), output_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); // Check node attributes const auto& node_attrs = node->GetAttributes(); if (!node_attrs.empty()) { - size_t api_num_node_attributes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(api_node, &api_num_node_attributes)); - - std::vector api_node_attributes(api_num_node_attributes); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, api_node_attributes.data(), api_node_attributes.size())); + std::vector api_node_attributes = api_node.GetAttributes(); size_t attr_idx = 0; for (const auto& node_attr : node_attrs) { - const OrtOpAttr* api_node_attr = api_node_attributes[attr_idx]; + auto api_node_attr = api_node_attributes[attr_idx]; ASSERT_NE(api_node_attr, nullptr); - api_node_attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(api_node, node_attr.first.c_str(), &api_node_attr)); + auto status = api_node.GetAttributeByName(node_attr.first, api_node_attr); + ASSERT_TRUE(status.IsOK()); ASSERT_NE(api_node_attr, nullptr); - const char* api_node_attr_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(api_node_attr, &api_node_attr_name)); - ASSERT_STREQ(api_node_attr_name, node_attr.first.c_str()); - - OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + auto api_node_attr_name = api_node_attr.GetName(); + ASSERT_EQ(api_node_attr_name, node_attr.first); + // XXX: Investigate why not // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. // TODO: Once we add support for ORT_OP_ATTR_TENSOR, we should be able to just fail if OpAttr_GetType // returns an error. - OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); - if (status != nullptr) { - Ort::GetApi().ReleaseStatus(status); - continue; - } + OrtOpAttrType api_node_attr_type = api_node_attr.GetType(); ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type(); switch (node_attr_type) { @@ -892,9 +992,13 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. - ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); + FAIL() << "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."; } attr_idx++; } @@ -908,41 +1012,19 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); - size_t api_num_node_implicit_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumImplicitInputs(api_node, &api_num_node_implicit_inputs)); - ASSERT_EQ(api_num_node_implicit_inputs, implicit_input_node_args.size()); - - std::vector api_node_implicit_inputs(api_num_node_implicit_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, api_node_implicit_inputs.data(), - api_node_implicit_inputs.size())); - + std::vector api_node_implicit_inputs = api_node.GetImplicitInputs(); + ASSERT_EQ(api_node_implicit_inputs.size(), implicit_input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_implicit_inputs, implicit_input_node_args); // Recursively check subgraphs. - size_t api_num_node_subgraphs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); - ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); - - std::vector api_node_subgraphs(api_num_node_subgraphs); - std::vector api_subgraph_attr_names(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), - api_subgraph_attr_names.data())); - - for (const auto& [attr_name, subgraph] : node_subgraphs_map) { - // find index of this subgraph. - size_t api_subgraph_idx = api_num_node_subgraphs; - for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { - if (api_subgraph_attr_names[subgraph_idx] == attr_name) { - api_subgraph_idx = subgraph_idx; - break; - } - } - ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); - - // Recursively check the subgraph - auto subgraph_viewer = std::make_unique(*subgraph); - const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; - CheckGraphCApi(*subgraph_viewer, *api_subgraph); + std::vector api_node_subgraphs = api_node.GetSubgraphs(); + ASSERT_EQ(api_node_subgraphs.size(), node_subgraphs_map.size()); + + for (const auto& name_subgraph : api_node_subgraphs) { + auto hit = node_subgraphs_map.find(name_subgraph.attr_name); + ASSERT_NE(node_subgraphs_map.end(), hit); + auto subgraph_viewer = std::make_unique(*hit->second); + CheckGraphCApi(*subgraph_viewer, *name_subgraph.sub_graph); } } } diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc index 63652d8835e77..2e2bce97f0cb9 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -56,19 +56,19 @@ static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_ // Sum the number of inputs with a producer node. num_input_edges = 0; - for (const OrtValueInfo* input : inputs) { + for (const OrtValueInfo* ort_input : inputs) { + Ort::ConstValueInfo input{ort_input}; if (input == nullptr) continue; // Skip missing optional input - const OrtNode* producer_node = nullptr; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueProducer(input, &producer_node, /*output_index*/ nullptr)); - num_input_edges += static_cast(producer_node != nullptr); + auto producer_info = input.GetProducerNode(); + num_input_edges += static_cast(producer_info.node != nullptr); } return Ort::Status{nullptr}; } // Get all output nodes that consume an output from the given node. -static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { +static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { const OrtApi& ort_api = Ort::GetApi(); size_t num_outputs = 0; @@ -77,23 +77,17 @@ static Ort::Status GetOutputNodes(const OrtNode* node, std::vector outputs(num_outputs); RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - std::vector output_nodes; + std::vector output_nodes; output_nodes.reserve(num_outputs); // May have more than `num_outputs` // Gather the OrtNode consumers of every output. - for (const OrtValueInfo* output : outputs) { + for (const OrtValueInfo* ort_output : outputs) { + Ort::ConstValueInfo output{ort_output}; if (output == nullptr) continue; // Skip missing optional output - size_t num_consumers = 0; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueNumConsumers(output, &num_consumers)); - - std::vector node_consumers(num_consumers, nullptr); - std::vector input_indices(num_consumers, 0); - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueConsumers(output, node_consumers.data(), - input_indices.data(), num_consumers)); - - for (const OrtNode* consumer : node_consumers) { - output_nodes.push_back(consumer); + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); } } @@ -108,77 +102,85 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, const std::function& comp) { const OrtApi& ort_api = Ort::GetApi(); - // Get all nodes - size_t num_nodes = 0; - RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + try { + // Get all nodes + size_t num_nodes = 0; + RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); - if (num_nodes == 0) { - return Ort::Status{nullptr}; // Nothing to sort. - } + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } - std::vector nodes(num_nodes); - RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + std::vector nodes(num_nodes); + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); - // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. - size_t max_node_id = 0; - for (const OrtNode* node : nodes) { - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - max_node_id = std::max(max_node_id, node_id); - } + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. + size_t max_node_id = 0; + for (const OrtNode* node : nodes) { + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } - std::vector in_degree(max_node_id + 1, 0); - std::vector topo_order; - VisitorPriorityQueue to_visit(comp); + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); - topo_order.reserve(num_nodes); + topo_order.reserve(num_nodes); - // Initialize in_degree and initial nodes to visit first. - for (const OrtNode* node : nodes) { - size_t input_edge_count = 0; - RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + // Initialize in_degree and initial nodes to visit first. + for (const OrtNode* node : nodes) { + size_t input_edge_count = 0; + RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - in_degree[node_id] = input_edge_count; - if (input_edge_count == 0) { - to_visit.push(node); + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } } - } - while (!to_visit.empty()) { - const OrtNode* current_node = to_visit.top(); - to_visit.pop(); + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); - if (!current_node) continue; + if (!current_node) continue; - if (enter) { - enter(current_node); - } + if (enter) { + enter(current_node); + } - std::vector output_nodes; - RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); + std::vector output_nodes; + RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); - for (const OrtNode* output_node : output_nodes) { - size_t output_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + for (const auto& output_node : output_nodes) { + size_t output_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); - auto& node_in_degree = in_degree[output_node_id]; - node_in_degree--; + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; - if (node_in_degree == 0) { - to_visit.push(output_node); + if (node_in_degree == 0) { + to_visit.push(output_node); + } } - } - size_t current_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); - topo_order.push_back(current_node_id); - } + size_t current_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); + topo_order.push_back(current_node_id); + } - if (num_nodes != topo_order.size()) { - return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status; + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status; } return Ort::Status{nullptr}; diff --git a/onnxruntime/test/framework/TestAllocatorManager.cc b/onnxruntime/test/framework/TestAllocatorManager.cc index 30f2686cd62f5..6440a805cdc59 100644 --- a/onnxruntime/test/framework/TestAllocatorManager.cc +++ b/onnxruntime/test/framework/TestAllocatorManager.cc @@ -10,7 +10,7 @@ namespace test { class DummyArena : public IAllocator { public: explicit DummyArena(std::unique_ptr resource_allocator) - : IAllocator(OrtMemoryInfo(resource_allocator->Info().name, + : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(), OrtAllocatorType::OrtDeviceAllocator, resource_allocator->Info().device, resource_allocator->Info().mem_type)), diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 3efba6f1b6e52..445e023746aaa 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -13,7 +13,7 @@ namespace test { TEST(AllocatorTest, CPUAllocatorTest) { auto cpu_arena = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - ASSERT_STREQ(cpu_arena->Info().name, CPU); + ASSERT_STREQ(cpu_arena->Info().name.c_str(), CPU); EXPECT_EQ(cpu_arena->Info().device.Id(), 0); const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index b86f3efeefafd..fced72ce3246d 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -67,7 +67,7 @@ static common::Status LoadInferenceSessionFromModel(FenceCudaTestInferenceSessio tensor_proto.set_data_type(PROTO_DATATYPE); \ for (auto v : value) tensor_proto.PROTO_ADD_DATA(v); \ tensor_proto.set_name(name); \ - return graph_utils::AddInitializerWithExternalData(graph, tensor_proto); \ + return graph_utils::AddInitializer(graph, tensor_proto); \ } CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data) diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc new file mode 100644 index 0000000000000..a8a83fbe5ceb6 --- /dev/null +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -0,0 +1,530 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/framework/execution_provider.h" +#include "core/framework/compute_capability.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/abi_session_options_impl.h" +#include "core/framework/error_code_helper.h" +#include "dummy_provider.h" +#include "test_utils.h" +#include "test/test_environment.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime; +using namespace onnxruntime::test; + +namespace { + +// Test execution provider that extends IExecutionProvider with compatibility string functionality +class TestCompatibilityExecutionProvider : public IExecutionProvider { + public: + static constexpr const char* kTestCompatibilityExecutionProviderType = "TestCompatibilityExecutionProvider"; + + TestCompatibilityExecutionProvider() : IExecutionProvider(kTestCompatibilityExecutionProviderType) { + } + + std::shared_ptr GetKernelRegistry() const override { + return std::make_shared(); + } + + std::vector CreatePreferredAllocators() override { + return {}; + } + + // Configurable mock behavior + void SetMockCompatibilityString(const std::string& str) { + mock_compatibility_string_ = str; + } + + void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) { + mock_compatibility_status_ = status; + } + + void SetShouldFailValidation(bool should_fail) { + should_fail_validation_ = should_fail; + } + + // Override compatibility methods + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override { + ORT_UNUSED_PARAMETER(graph_viewer); + return mock_compatibility_string_; + } + + common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override { + if (should_fail_validation_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mock validation failure"); + } + + // Simple validation logic for testing + // If the mock status is explicitly set to NOT_APPLICABLE, always return that + if (mock_compatibility_status_ == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info.empty()) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info == mock_compatibility_string_) { + model_compatibility = mock_compatibility_status_; + } else { + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + } + + return Status::OK(); + } + + private: + std::string mock_compatibility_string_ = "default_test_compatibility_v1.0"; + OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + bool should_fail_validation_ = false; +}; + +// Helper class to create test models +class ModelBuilderWithCompatibility { + public: + static std::unique_ptr CreateSimpleTestModel() { + // Create a simple model with a single Add operation + std::unordered_map domain_to_version; + domain_to_version[onnxruntime::kOnnxDomain] = 7; + + auto p_model = std::make_unique("test_model", true, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), + DefaultLoggingManager().DefaultLogger()); + + onnxruntime::Graph& graph = p_model->MainGraph(); + + // Define tensor type + ONNX_NAMESPACE::TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + // Create input and output node args + auto& input_arg_a = graph.GetOrCreateNodeArg("A", &tensor_float); + auto& input_arg_b = graph.GetOrCreateNodeArg("B", &tensor_float); + auto& output_arg = graph.GetOrCreateNodeArg("C", &tensor_float); + + // Create Add node + std::vector input_defs = {&input_arg_a, &input_arg_b}; + std::vector output_defs = {&output_arg}; + graph.AddNode("add_node", "Add", "Add two tensors", input_defs, output_defs, nullptr, onnxruntime::kOnnxDomain); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + return p_model; + } + + static std::unique_ptr CreateModelWithCompatibilityMetadata( + const std::map& ep_compatibility_info) { + auto model = CreateSimpleTestModel(); + + // Add compatibility metadata + auto& metadata = model->MetaData(); + for (const auto& [ep_type, compatibility_string] : ep_compatibility_info) { + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + metadata[metadata_key] = compatibility_string; + } + + return model; + } +}; + +// Helper class to create test sessions +class SessionBuilderWithCompatibility { + public: + static std::unique_ptr CreateTestSession(std::unique_ptr model, bool fail_on_suboptimal = false) { + SessionOptions so; + so.session_logid = "EpCompatibilityTest"; + so.session_log_verbosity_level = 1; + + if (fail_on_suboptimal) { + EXPECT_TRUE(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1").IsOK()); + } + + // Convert Model to ModelProto and serialize + auto model_proto = model->ToProto(); + std::string model_data; + EXPECT_TRUE(model_proto.SerializeToString(&model_data)); + std::stringstream model_stream(model_data); + + // Create session with basic constructor + auto session = std::make_unique(so, GetEnvironment()); + + // Load the model from the stream and validate the status + auto load_status = session->Load(model_stream); + EXPECT_TRUE(load_status.IsOK()) << "Failed to load model: " << load_status.ErrorMessage(); + + return session; + } +}; + +// Helper function to initialize session using the proper validation pathway +Status InitializeSessionWithValidation(InferenceSession& session) { + // Create OrtSessionOptions from the session's SessionOptions to use the proper initialization path + OrtSessionOptions ort_session_options; + ort_session_options.value = session.GetSessionOptions(); + + // Call the InitializeSession function from utils.cc which includes validation + OrtStatus* ort_status = InitializeSession(&ort_session_options, session, nullptr); + + // Convert OrtStatus to Status using the proper helper function + return ToStatusAndRelease(ort_status); +} + +} // anonymous namespace + +class EpCompatibilityTest : public ::testing::Test { + protected: + void SetUp() override { + test_model_ = ModelBuilderWithCompatibility::CreateSimpleTestModel(); + } + + protected: + std::unique_ptr test_model_; +}; + +// Test basic compatibility string generation during compilation +TEST_F(EpCompatibilityTest, TestCompatibilityStringGeneration) { + const std::string expected_compatibility_string = "test_ep_v1.0_compatibility_data"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(expected_compatibility_string); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); + + // Note: In the actual implementation, we would need to trigger EP context model creation + // to see the compatibility strings stored. For now, this tests that the methods are called + // without error during session initialization. +} + +// Test compatibility string storage in model metadata +TEST_F(EpCompatibilityTest, TestCompatibilityStringStorage) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string expected_compatibility_string = "stored_compatibility_v2.0"; + + // Create model with pre-populated compatibility metadata + std::map compatibility_info = { + {ep_type, expected_compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + // Verify metadata was stored correctly + const auto& metadata = model_with_metadata->MetaData(); + std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + auto it = metadata.find(expected_key); + ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key; + EXPECT_EQ(it->second, expected_compatibility_string); +} + +// Test multiple EPs generating different compatibility strings +TEST_F(EpCompatibilityTest, TestMultipleEpCompatibilityStrings) { + std::map compatibility_info = { + {"EP_A", "ep_a_compatibility_v1.0"}, + {"EP_B", "ep_b_compatibility_v2.1"}, + {"EP_C", "ep_c_compatibility_v1.5"}}; + + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + // Verify all compatibility strings are stored + const auto& metadata = model_with_metadata->MetaData(); + for (const auto& [ep_type, expected_string] : compatibility_info) { + std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + auto it = metadata.find(expected_key); + ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key; + EXPECT_EQ(it->second, expected_string); + } +} + +// Test empty compatibility string handling +TEST_F(EpCompatibilityTest, TestEmptyCompatibilityString) { + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(""); // Empty string + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed even with empty compatibility string +} + +// Test compatibility validation with optimal status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Optimal) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "optimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL); + + // Create model with matching compatibility metadata + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed with optimal compatibility +} + +// Test compatibility validation with suboptimal status (default session settings) +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_DefaultSettings) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "suboptimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION); + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false); // Don't fail on suboptimal + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed by default with suboptimal compatibility +} + +// Test compatibility validation with suboptimal status (fail on suboptimal enabled) +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_FailEnabled) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "suboptimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION); + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), true); // Fail on suboptimal + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should fail during initialization due to suboptimal compatibility + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("suboptimal")); +} + +// Test compatibility validation with unsupported status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Unsupported) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string stored_compatibility_string = "old_compatibility_v1.0"; + const std::string current_compatibility_string = "new_compatibility_v2.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(current_compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED); + + // Model has old compatibility string, EP has new one -> unsupported + std::map compatibility_info = {{ep_type, stored_compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false); // Even with fail_on_suboptimal=false + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should fail during initialization due to unsupported compatibility + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported")); +} + +// Test compatibility validation with not applicable status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_NotApplicable) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(""); // Empty compatibility string + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + + // Model has some compatibility string, but EP returns not applicable + std::map compatibility_info = {{ep_type, "some_compatibility_string"}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed with not applicable status +} + +// Test missing compatibility info in model metadata +TEST_F(EpCompatibilityTest, TestMissingCompatibilityInfo) { + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString("some_compatibility_string"); + + // Use model without any compatibility metadata + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed when no compatibility info is present +} + +// Test EP validation failure +TEST_F(EpCompatibilityTest, TestEpValidationFailure) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "test_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetShouldFailValidation(true); // Force validation failure + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should handle EP validation failure gracefully + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure")); +} + +// Test session option configuration for fail on suboptimal +TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { + SessionOptions so; + + // Test default value + std::string config_value; + bool has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_FALSE(has_config); // Should not be set by default + + // Test setting the option + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1")); + has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_TRUE(has_config); + EXPECT_EQ(config_value, "1"); + + // Test setting to disabled + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "0")); + has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_TRUE(has_config); + EXPECT_EQ(config_value, "0"); +} + +// ----------------------------- +// C API unit tests +// ----------------------------- + +namespace { + +// Helper to create an OrtEnv and fetch a CPU EP device pointer via the C API. +// Returns a pair of (env, cpu_device). Caller releases env via api->ReleaseEnv. +static std::pair CreateEnvAndGetCpuEpDevice(const OrtApi* api) { + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpCompatCapiTest", &env)); + EXPECT_NE(env, nullptr); + + const OrtEpDevice* const* devices = nullptr; + size_t num_devices = 0; + EXPECT_EQ(nullptr, api->GetEpDevices(env, &devices, &num_devices)); + EXPECT_GT(num_devices, 0u); + + const OrtEpDevice* cpu_device = nullptr; + for (size_t i = 0; i < num_devices; ++i) { + const char* name = api->EpDevice_EpName(devices[i]); + if (name && std::string(name) == "CPUExecutionProvider") { + cpu_device = devices[i]; + break; + } + } + + // Fallback: just pick the first device if CPU wasn't found (environment-dependent builds). + if (!cpu_device && num_devices > 0) { + cpu_device = devices[0]; + } + + EXPECT_NE(cpu_device, nullptr); + return {env, cpu_device}; +} + +} // namespace + +TEST(EpCompatibilityCapiTest, InvalidArguments) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtCompiledModelCompatibility out_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + + // ep_devices == nullptr + OrtStatus* st = api->GetModelCompatibilityForEpDevices(nullptr, 0, "info", &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // Prepare a valid device + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + // compatibility_info == nullptr + const OrtEpDevice* devices1[] = {device}; + st = api->GetModelCompatibilityForEpDevices(devices1, 1, nullptr, &out_status); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // out_status == nullptr + st = api->GetModelCompatibilityForEpDevices(devices1, 1, "some-info", nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + auto [env, device] = CreateEnvAndGetCpuEpDevice(api); + ASSERT_NE(env, nullptr); + ASSERT_NE(device, nullptr); + + OrtCompiledModelCompatibility out_status = static_cast(-1); + const OrtEpDevice* devices2[] = {device}; + OrtStatus* st = api->GetModelCompatibilityForEpDevices(devices2, 1, "arbitrary-compat-string", &out_status); + ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : ""); + + // For providers that don't implement validation, API should return EP_NOT_APPLICABLE. + EXPECT_EQ(out_status, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +// ----------------------------- +// C++ API unit tests +// ----------------------------- + +TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"}; + auto devices = env.GetEpDevices(); + ASSERT_FALSE(devices.empty()); + + std::vector selected; + for (const auto& d : devices) { + if (std::string{d.EpName()} == "CPUExecutionProvider") { + selected.push_back(d); + break; + } + } + + ASSERT_FALSE(selected.empty()); + + // Pick a status that the CPU EP would never return to ensure the value is set correctly. + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + ASSERT_NO_FATAL_FAILURE({ + status = Ort::GetModelCompatibilityForEpDevices(selected, "arbitrary-compat-string"); + }); + + ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); +} \ No newline at end of file diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..35f7d06fb0912 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "gsl/gsl" #include "gtest/gtest.h" diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index bc01135fbbf1e..6131eff92ac78 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -588,93 +588,6 @@ TEST(InferenceSessionTests, RequestLoadCancellation) { } } -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { - if (f_arg.size() != s_arg.size()) { - std::cout << "Sizes differ: f_arg size: " << f_arg.size() << " s_arg size: " << s_arg.size() << std::endl; - return false; - } - - for (size_t i = 0; i < f_arg.size(); ++i) { - const onnxruntime::NodeArg* x = f_arg[i]; - const onnxruntime::NodeArg* y = s_arg[i]; - if ((x->Shape() == nullptr) ^ (y->Shape() == nullptr)) { - return false; - } - if (!x->Shape()) { - continue; - } - auto x_shape = utils::GetTensorShapeFromTensorShapeProto(*x->Shape()); - auto y_shape = utils::GetTensorShapeFromTensorShapeProto(*y->Shape()); - if (x->Name() == y->Name() && x_shape == y_shape && *x->Type() == *y->Type()) { - continue; - } - return false; - } - - return true; -} - -TEST(InferenceSessionTests, ModelMetadata) { - SessionOptions so; - - so.session_logid = "InferenceSessionTests.ModelMetadata"; - InferenceSession session_object{so, GetEnvironment()}; - auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); - ASSERT_STATUS_OK(session_object.Load(model_uri)); - - std::shared_ptr p_model; - ASSERT_STATUS_OK(onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger())); - const onnxruntime::Graph& graph = p_model->MainGraph(); - - // 1. first test the model meta - { - auto retval = session_object.GetModelMetadata(); - ASSERT_TRUE(retval.first.IsOK()); - const ModelMetadata* m = retval.second; - ASSERT_TRUE(m->custom_metadata_map == p_model->MetaData() && - m->description == p_model->DocString() && - m->domain == p_model->Domain() && - m->graph_name == graph.Name() && - m->producer_name == p_model->ProducerName() && - m->version == p_model->ModelVersion()); - } - - { - // 2. test inputs - auto& inputs = graph.GetInputs(); - auto weights = graph.GetAllInitializedTensors(); - - // skip the weights - InputDefList inputs_no_weights; - for (auto& elem : inputs) { - if (weights.find(elem->Name()) != weights.end()) { - continue; - } else { - inputs_no_weights.push_back(elem); - } - } - - auto retval = session_object.GetModelInputs(); - std::cout << "weights size: " << weights.size() - << " inputs.size(): " << inputs.size() - << " from session: " << retval.second->size() << std::endl; - ASSERT_TRUE(retval.first.IsOK()); - ASSERT_TRUE(Compare(inputs_no_weights, *retval.second)); - } - - // 3. test outputs - { - auto retval = session_object.GetModelOutputs(); - ASSERT_TRUE(retval.first.IsOK()); - - auto& outputs = graph.GetOutputs(); - retval = session_object.GetModelOutputs(); - ASSERT_TRUE(retval.first.IsOK()); - ASSERT_TRUE(Compare(outputs, *retval.second)); - } -} -#endif TEST(InferenceSessionTests, CheckRunLogger) { if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { GTEST_SKIP() << "Skipping the test"; diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index 98874874d50e9..e70d870ef6988 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -84,7 +84,7 @@ Status LoadSaveAndCompareModel(const std::filesystem::path& input_onnx, size_t tensor_offset; std::stringstream stream(entry.value()); stream >> tensor_offset; - ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.allocation_granularity == 0, + ORT_RETURN_IF_NOT(tensor_offset % model_saving_options.on_disk_alignment == 0, "tensor offset not align"); } } diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 6ad21fa9f5cf5..a9d6273ae2f20 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -11,6 +11,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/bfc_arena.h" +#include "core/framework/ep_context_options.h" #include "core/framework/session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -504,7 +505,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, ASSERT_STATUS_OK( partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal, - EpContextModelGenerationOptions{}, + epctx::ModelGenOptions{}, debug_graph_fn)); verifier_fn(graph); diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index 2ac1a93013932..f08675271de21 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -29,7 +29,7 @@ void CPUTensorTest(std::vector dims, const int offset_elements = 0) { EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims()); EXPECT_EQ(t.DataType(), DataTypeImpl::GetType()); auto& location = t.Location(); - EXPECT_STREQ(location.name, CPU); + EXPECT_STREQ(location.name.c_str(), CPU); EXPECT_EQ(location.device.Id(), 0); const T* t_data = t.Data(); @@ -47,7 +47,7 @@ void CPUTensorTest(std::vector dims, const int offset_elements = 0) { EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims()); EXPECT_EQ(new_t.DataType(), DataTypeImpl::GetType()); auto& new_location = new_t.Location(); - ASSERT_STREQ(new_location.name, CPU); + ASSERT_STREQ(new_location.name.c_str(), CPU); EXPECT_EQ(new_location.device.Id(), 0); } } @@ -135,7 +135,7 @@ TEST(TensorTest, EmptyTensorTest) { EXPECT_TRUE(!data); auto& location = t.Location(); - ASSERT_STREQ(location.name, CPU); + ASSERT_STREQ(location.name.c_str(), CPU); EXPECT_EQ(location.device.Id(), 0); const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() @@ -160,7 +160,7 @@ TEST(TensorTest, StringTensorTest) { EXPECT_EQ(shape, tensor_shape); EXPECT_EQ(t.DataType(), DataTypeImpl::GetType()); auto& location = t.Location(); - ASSERT_STREQ(location.name, CPU); + ASSERT_EQ(location.name, CPU); EXPECT_EQ(location.device.Id(), 0); std::string* new_data = t.MutableData(); diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index c6d958536f488..324394798863c 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,8 +74,7 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index e2b54950e7b24..ca1166e19037c 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1894,14 +1894,21 @@ TEST_F(GraphTest, AddRemoveInitializerHandling) { ASSERT_EQ(graph_proto_from_graph.initializer_size(), 2); auto validate_proto = [&](const GraphProto& proto) { + // Due to changes in a way we generate ToGraphProto() const, we can not guarantee the order of initializers + // in the generated GraphProto. auto initializers = proto.initializer(); - // we expect '2' to be before '1' due to the remove moving the last initializer into the slot of the one being - // removed in order to free memory and only move one entry - EXPECT_EQ(initializers[0].name(), init2.name()); - EXPECT_EQ(initializers[0].int32_data()[0], 2); - - EXPECT_EQ(initializers[1].name(), init.name()); - EXPECT_EQ(initializers[1].int32_data()[0], 1); + auto hit = std::find_if(initializers.begin(), initializers.end(), + [&init](const ONNX_NAMESPACE::TensorProto& t) { return t.name() == init.name(); }); + EXPECT_NE(hit, initializers.end()) + << "Initializer with name '" << init.name() << "' not found in the proto."; + EXPECT_EQ(hit->int32_data()[0], 1); + + hit = std::find_if(initializers.begin(), initializers.end(), + [&init2](const ONNX_NAMESPACE::TensorProto& t) { return t.name() == init2.name(); }); + EXPECT_NE(hit, initializers.end()) + << "Initializer with name '" << init2.name() << "' not found in the proto."; + + EXPECT_EQ(hit->int32_data()[0], 2); }; validate_proto(graph_proto_from_const_graph); diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 9327d86966981..55fc4f42bec64 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -26,44 +26,6 @@ class ONNXModelsTest : public ::testing::Test { std::unique_ptr logger_; }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -// Tests that Resolve() properly clears the state of topological sorted nodes, -// inputs, outputs and valueInfo. -// Assumes the graph passed in has been previously resolved. -static void TestResolve(onnxruntime::Graph& graph) { - GraphViewer graph_viewer(graph); - auto& nodes_before = graph_viewer.GetNodesInTopologicalOrder(); - auto& inputs_before = graph.GetInputs(); - auto& outputs_before = graph.GetOutputs(); - auto& value_info_before = graph.GetValueInfo(); - - // Touch the graph to force Resolve() to recompute. - graph.SetGraphResolveNeeded(); - graph.SetGraphProtoSyncNeeded(); - ASSERT_STATUS_OK(graph.Resolve()); - - GraphViewer graph_viewer_2(graph); - auto& nodes_after = graph_viewer_2.GetNodesInTopologicalOrder(); - auto& inputs_after = graph.GetInputs(); - auto& outputs_after = graph.GetOutputs(); - auto& value_info_after = graph.GetValueInfo(); - - // Multiple calls to Resolve() should not alter the sorted nodes, - // inputs, outputs and valueInfo. The internal state should be - // cleared. - EXPECT_EQ(nodes_before, nodes_after); - EXPECT_EQ(inputs_before, inputs_after); - EXPECT_EQ(outputs_before, outputs_after); - EXPECT_EQ(value_info_before, value_info_after); -} - -TEST_F(ONNXModelsTest, squeeze_net) { - // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, *logger_)); - TestResolve(model->MainGraph()); -} -#endif TEST_F(ONNXModelsTest, non_existing_model) { // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located @@ -96,76 +58,6 @@ class ONNXModelsTest1 : public ::testing::TestWithParam { return oss.str(); } }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -TEST_F(ONNXModelsTest, bvlc_alexnet_1) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - // Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB. - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - EXPECT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr, - *logger_)); - - // Check the graph input/output/value_info should have the same size as specified in the model file. - EXPECT_EQ(static_cast(model_proto.graph().value_info_size()), model->MainGraph().GetValueInfo().size()); - EXPECT_EQ(static_cast(model_proto.graph().input_size()), model->MainGraph().GetInputs().size() + model->MainGraph().GetAllInitializedTensors().size()); - EXPECT_EQ(static_cast(model_proto.graph().output_size()), model->MainGraph().GetOutputs().size()); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromFile) { - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(GetModelFileName(), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromProtobuf) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(GetModelFileName(), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - ASSERT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -#ifndef DISABLE_CONTRIB_OPS -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#else -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#endif - -#endif // test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs. // a NodeArg should be created for all initializers in this case. diff --git a/onnxruntime/test/lora/lora_test.cc b/onnxruntime/test/lora/lora_test.cc index e8291a36447ca..ecfaf34c8a076 100644 --- a/onnxruntime/test/lora/lora_test.cc +++ b/onnxruntime/test/lora/lora_test.cc @@ -216,7 +216,7 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) { for (; begin != end; ++begin) { const auto& [_, param] = *begin; const auto& tensor_device = param.GetDeviceOrMapped().Get(); - ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::CUDA)); + ASSERT_EQ(0, strcmp(tensor_device.Location().name.c_str(), onnxruntime::CUDA)); const auto& tensor_cpu = param.GetMapped().Get(); ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size()); diff --git a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp index fad804f3ce305..1be05d88849cd 100644 --- a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp @@ -31,10 +31,156 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { std::uniform_real_distribution distrib_f32_; MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; MatrixGuardBuffer inputScale_, refScale_; - MatrixGuardBuffer inputBlkSum_, refBlkSum_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_, refBlkUnsignedQuantAZeroPointCorrection_; +#ifdef MLAS_TARGET_ARM64 template - void PrepackB(const uint8_t* src, uint8_t* dst) { + void PrepackB(const uint8_t* src, uint8_t* dst, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n * ldb + k; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n * BlkCount + k; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n * ldb + k; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n * BlkCount + k; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + } +#else // not MLAS_TARGET_ARM64 + template + void PrepackB(const uint8_t* src, uint8_t* dst, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); size_t n = 0; for (; n + 4 <= N; n += 4) { @@ -65,7 +211,9 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } template - void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum) { + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1; @@ -174,10 +322,15 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } } } +#endif // MLAS_TARGET_ARM64 template void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { - size_t BlkCount = (K + BlkLen - 1) / BlkLen; + if (refBlkSum == nullptr) { + return; + } + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; for (size_t n = 0; n < N; ++n) { for (size_t k = 0; k < BlkCount; ++k) { @@ -198,6 +351,7 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { constexpr size_t PackBCount = N * Ldb; constexpr size_t ScaleCount = BlkCount * N; const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { for (size_t i = 0; i < t; i++) { @@ -222,25 +376,36 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { auto* refB = refB_.GetBuffer(PackBCount, true); auto* refScale = refScale_.GetBuffer(ScaleCount, true); auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + auto* refBlkUnsignedQuantAZeroPointCorrection = isQuantAUnsigned ? refBlkUnsignedQuantAZeroPointCorrection_.GetBuffer(((N + 15) & (~15)) * BlkCount, true) : nullptr; + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); + + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, hasZp, nullptr, nullptr); + inputScale, hasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, hasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + PrepackB(inputB, refB, refBlkUnsignedQuantAZeroPointCorrection); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum, refBlkUnsignedQuantAZeroPointCorrection); - PrepackB(inputB, refB); - PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum); - - CheckB(refB, reinterpret_cast(packedQuantB.PackedQuantBData)); - CheckScale(refScale, packedQuantB.PackedQuantBScale); - CheckBlkSum(refBlkSum, packedQuantB.QuantBBlkSum); + CheckB(reinterpret_cast(packedQuantB.PackedQuantBData), refB); + CheckScale(packedQuantB.PackedQuantBScale, refScale); + CheckBlkSum(packedQuantB.QuantBBlkSum, refBlkSum); + CheckBlkSum(packedQuantB.BlkUnsignedQuantAZeroPointCorrection, refBlkUnsignedQuantAZeroPointCorrection); } public: @@ -298,31 +463,203 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { Execute<1, 1, 256, 64>(); Execute<16, 4, 16, 64>(); - Execute<32, 4, 16, 64>(); - Execute<64, 4, 16, 64>(); - Execute<128, 4, 16, 64>(); + Execute<32, 8, 16, 64>(); + Execute<64, 12, 32, 64>(); + Execute<128, 16, 64, 64>(); - Execute<15, 5, 16, 64>(); - Execute<15, 5, 32, 64>(); + Execute<15, 3, 16, 64>(); + Execute<15, 4, 32, 64>(); Execute<15, 5, 64, 64>(); - Execute<15, 5, 128, 64>(); - Execute<15, 5, 256, 64>(); - + Execute<15, 6, 128, 64>(); + Execute<15, 7, 256, 64>(); + Execute<15, 8, 16, 64>(); + Execute<15, 9, 16, 64>(); + + Execute<17, 3, 16, 64>(); + Execute<17, 4, 32, 64>(); + Execute<17, 5, 64, 64>(); + Execute<17, 6, 128, 64>(); + Execute<17, 7, 256, 64>(); Execute<17, 8, 16, 64>(); - Execute<17, 8, 32, 64>(); - Execute<17, 8, 64, 64>(); - Execute<17, 8, 128, 64>(); - Execute<17, 8, 256, 64>(); + Execute<17, 9, 16, 64>(); Execute<159, 16, 16, 64>(); Execute<160, 17, 32, 64>(); Execute<161, 15, 64, 64>(); Execute<160, 17, 128, 64>(); Execute<159, 16, 256, 64>(); + Execute<3072, 128, 16, 64>(); } } }; +class MlasSQ8BitQuantAKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer workspace_, refQuantA_; + MatrixGuardBuffer inputA_, refScale_, refBlkSum_; + + template + void QuantA(const float* inputA, uint8_t* quantA, float* scalePtr, float* blkSum, bool quantAUnsigned) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t input_lda = K; + + constexpr size_t Bits = 8; + constexpr size_t output_lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + float vAbsMax = 0.f; + for (size_t k = 0; k < std::min(BlkLen, K - j * BlkLen); ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + vAbsMax = std::max(vAbsMax, fabsf(inputA[input_idx])); + } + + float scale = vAbsMax / 127.f; + float invScale = vAbsMax == 0.f ? 0.f : 127.f / vAbsMax; + scalePtr[i * BlkCount + j] = scale; + + float vSum = 0.f; + for (size_t k = 0; k < BlkLen; ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + size_t output_idx = i * output_lda + j * BlkLen + k; + if (k < std::min(BlkLen, K - j * BlkLen)) { + const auto input_val = inputA[input_idx]; + // Round to nearest, ties away from zero + // float v = std::clamp(std::roundf(input_val * invScale), -128.f, 127.f); + + // Round to nearest, ties to even + float v = std::clamp(std::nearbyint(input_val * invScale), -128.f, 127.f); + + if (quantAUnsigned) { + quantA[output_idx] = static_cast(v + 128.f); + vSum += v + 128.f; + } else { + reinterpret_cast(quantA)[output_idx] = static_cast(v); + vSum += v; + } + } else { + quantA[output_idx] = 0; + } + } + blkSum[i * BlkCount + j] = vSum * scale; + } + } + } + + template + void CheckQuantA(const uint8_t* quantA, const uint8_t* refQuantA) { + constexpr size_t lda = (K + BlkLen - 1) & (~(BlkLen - 1)); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < lda; ++j) { + size_t idx = i * lda + j; + ASSERT_EQ(quantA[idx], refQuantA[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void CheckScale(const float* scale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + size_t idx = i * BlkCount + j; + ASSERT_EQ(scale[idx], refScale[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void TestQuantA() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + const auto* dispatch = GetMlasPlatform().QNBitGemmDispatch; + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackACount = M * Lda; + constexpr size_t ScaleCount = M * BlkCount; + const size_t BufferSize = MlasQNBitGemmBatchWorkspaceSize(M, 1, K, 1, Bits, BlkLen, true, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + + const auto* inputA = inputA_.GetFilledBuffer(M * K, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* workspace = workspace_.GetBuffer(BufferSize, true); + auto* refQuantA = refQuantA_.GetBuffer(PackACount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(ScaleCount, true); + + const size_t Alignment = dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, SQNBIT_CompInt8); + const uintptr_t WorkspaceAddress = reinterpret_cast(workspace); + auto* quantAPtr = reinterpret_cast((WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + auto* scaleAPtr = reinterpret_cast(quantAPtr + PackACount); + auto* blkSumAPtr = scaleAPtr + ScaleCount; + + for (size_t i = 0; i < M; ++i) { + dispatch->QuantizeARowComputeBlkSum_CompInt8(BlkLen, inputA + i * K, K, quantAPtr + i * Lda, scaleAPtr + i * BlkCount, blkSumAPtr + i * BlkCount); + } + + QuantA(inputA, refQuantA, refScale, refBlkSum, isQuantAUnsigned); + CheckQuantA(reinterpret_cast(quantAPtr), refQuantA); + CheckScale(scaleAPtr, refScale); + CheckScale(blkSumAPtr, refBlkSum); + } + + public: + MlasSQ8BitQuantAKernelTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitQuantA"; + } + + void ExecuteShort(void) override { + TestQuantA<1, 16, 16>(); + TestQuantA<1, 1, 32>(); + TestQuantA<1, 1, 64>(); + TestQuantA<1, 1, 128>(); + TestQuantA<1, 1, 256>(); + + TestQuantA<4, 16, 16>(); + TestQuantA<8, 32, 16>(); + TestQuantA<12, 64, 32>(); + TestQuantA<16, 128, 64>(); + + TestQuantA<3, 15, 16>(); + TestQuantA<4, 15, 32>(); + TestQuantA<5, 15, 64>(); + TestQuantA<6, 15, 128>(); + TestQuantA<7, 15, 256>(); + TestQuantA<8, 15, 16>(); + TestQuantA<9, 15, 16>(); + + TestQuantA<3, 17, 16>(); + TestQuantA<4, 17, 32>(); + TestQuantA<5, 17, 64>(); + TestQuantA<6, 17, 128>(); + TestQuantA<7, 17, 256>(); + TestQuantA<8, 17, 16>(); + TestQuantA<9, 17, 16>(); + + TestQuantA<2, 159, 16>(); + TestQuantA<3, 159, 16>(); + TestQuantA<17, 160, 32>(); + TestQuantA<15, 161, 64>(); + TestQuantA<17, 160, 128>(); + TestQuantA<16, 159, 256>(); + + TestQuantA<1, 3072, 16>(); + } +}; + class MlasSQ8BitGemmKernelTest : public MlasTestBase { private: unsigned int seed_; @@ -383,9 +720,6 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { } }); - int q_rows, q_cols; - MlasBlockwiseQuantizedShape((int)BlkLen, true, (int)K, (int)N, q_rows, q_cols); - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); @@ -420,24 +754,34 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8); auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true); + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, HasZp, nullptr, nullptr); + inputScale, HasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, HasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); auto* C = C_.GetBuffer(M * ldc, true); auto* ref = ref_.GetBuffer(M * ldc, true); - auto* bias = HasBias ? bias_.GetFilledBuffer(N, [this](float* p, size_t t) { + auto* bias = HasBias ? bias_.GetFilledBuffer(N, [](float* p, size_t t) { for (size_t i = 0; i < t; i++) { - p[i] = this->distrib_f32_(this->gen_); + p[i] = (float)(5 + i); } }) : nullptr; @@ -473,14 +817,17 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { template void Execute(void) { - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); + + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); } void ExecuteShort(void) override { + Execute<1, 16, 1, 16>(); Execute<1, 1, 1, 16>(); + Execute<7, 2, 4, 16>(); Execute<7, 128, 4, 16>(); Execute<8, 497, 5, 16>(); Execute<1, 3072, 128, 16>(); @@ -515,6 +862,7 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); count += MlasDirectShortExecuteTests::RegisterShortExecute(); } return count; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 98640bb2f6b4c..f626a1704f7a1 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -5370,8 +5370,59 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) { #endif } +// Tests that the WeightBiasQuantization optimizer still processes nodes that contain a type-preserving no +// branch ReLU op to QuantizeLinear e.g., Q -> DQ -> Conv (w/ float weight initializer) -> ReLU -> Q -> DQ +TEST(QDQTransformerTests, WeightBiasQuantization_ConvWithReLU) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_fp32 = builder.MakeInput({1, 1, 4, 4}, -1.0f, 1.0f); + NodeArg* weight_fp32 = builder.MakeInitializer({2, 1, 3, 3}, -1.0f, 1.0f); + NodeArg* input_q = builder.MakeIntermediate(); + NodeArg* input_dq = builder.MakeIntermediate(); + NodeArg* conv_fp32 = builder.MakeIntermediate(); + NodeArg* relu_fp32 = builder.MakeIntermediate(); + NodeArg* relu_q = builder.MakeIntermediate(); + NodeArg* relu_dq = builder.MakeOutput(); + builder.AddQuantizeLinearNode(input_fp32, 0.18f, static_cast(127), input_q, use_contrib_qdq); + builder.AddDequantizeLinearNode(input_q, 0.18f, static_cast(127), input_dq, use_contrib_qdq); + auto& conv_node = builder.AddNode("Conv", {input_dq, weight_fp32}, {conv_fp32}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + conv_node.AddAttribute("strides", std::vector{1, 1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + builder.AddNode("Relu", {conv_fp32}, {relu_fp32}); + builder.AddQuantizeLinearNode(relu_fp32, 0.69f, static_cast(127), relu_q, use_contrib_qdq); + builder.AddDequantizeLinearNode(relu_q, 0.69f, static_cast(127), relu_dq, use_contrib_qdq); + }; + + // Conv's weights should be quantized and folded, one additional Q/DQ pair inserted for weight + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 2 + 1); + EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["Relu"], 1); + }; + + TransformerTester(build_test_case, + check_transformed_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version=*/20, + /*per_sample_tolerance=*/0.01, + /*relative_per_sample_tolerance=*/0.01, + /*transformer=*/std::make_unique()); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + // Tests that the WeightBiasQuantization optimizer does not process nodes that do not -// already have an output that is consumed by a single QuantizeLinear node. +// already have an output that is consumed by a valid path to QuantizeLinear node. TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) { auto test_case = [](bool add_final_reshape) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/optimizer/resnet50_fusion_test.cc b/onnxruntime/test/optimizer/resnet50_fusion_test.cc index 5cb0206156a84..7e6677c8e1ddf 100644 --- a/onnxruntime/test/optimizer/resnet50_fusion_test.cc +++ b/onnxruntime/test/optimizer/resnet50_fusion_test.cc @@ -16,7 +16,6 @@ namespace onnxruntime { namespace test { -// #define ORT_RUN_EXTERNAL_ONNX_TESTS // #define MLAS_F16VEC_INTRINSICS_SUPPORTED #define MODEL_FOLDER ORT_TSTR("testdata/transform/") @@ -28,54 +27,7 @@ class ResNet50FusionTests : public ::testing::Test { } std::unique_ptr logger; }; -#if defined(ORT_RUN_EXTERNAL_ONNX_TESTS) -TEST_F(ResNet50FusionTests, FuseConvIntegrationTest) { - std::basic_string fp32_model_path = ORT_TSTR("../models/opset10/Resnet50_Fusion_Testing/resnet50.onnx"); - std::shared_ptr fp32_model; - std::basic_string fp16_model_path = ORT_TSTR("../models/opset10/Resnet50_Fusion_Testing_fp16/resnet50.fp16.onnx"); - std::shared_ptr fp16_model; - if (Model::Load(fp32_model_path, fp32_model, nullptr, *logger) != Status::OK()) { - GTEST_SKIP() << "Failed to load model: " << fp32_model_path; - } - if (Model::Load(fp16_model_path, fp16_model, nullptr, *logger) != Status::OK()) { - GTEST_SKIP() << "Failed to load model: " << fp16_model_path; - } - // ASSERT_STATUS_OK(Model::Load(fp32_model_path, fp32_model, nullptr, *logger)); - Graph& fp32_graph = fp32_model->MainGraph(); - for (auto& node : fp32_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCpuExecutionProvider); - } - Graph& fp16_graph = fp16_model->MainGraph(); - for (auto& node : fp16_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCpuExecutionProvider); - } - // std::cout << "-------Op Counts Before Fusion---------" << std::endl; - std::map fp32_op_count = CountOpsInGraph(fp32_graph); - std::map fp16_op_count = CountOpsInGraph(fp16_graph); - for (auto& op : fp32_op_count) { - // std::cout << op.first << " " << op.second << std::endl; - ASSERT_EQ(op.second, fp16_op_count[op.first]); - } - onnxruntime::GraphTransformerManager graph_transformation_mgr_32{5}; - ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_32.ApplyTransformers(fp32_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp32_model, ORT_TSTR("resnet50_fused.onnx"))); - - onnxruntime::GraphTransformerManager graph_transformation_mgr_16{5}; - ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_16.ApplyTransformers(fp16_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp16_model, ORT_TSTR("resnet50_fp16_fused.onnx"))); - // std::cout << "-------Op Counts After Fusion---------" << std::endl; - fp32_op_count = CountOpsInGraph(fp32_graph); - fp16_op_count = CountOpsInGraph(fp16_graph); - // for (auto& op : fp32_op_count) { - // ASSERT_EQ(op.second, fp16_op_count[op.first]); - // } -} -#endif // defined(ORT_RUN_EXTERNAL_ONNX_TESTS) TEST_F(ResNet50FusionTests, FuseConvAddReluUnitTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 843875a881f0a..46958843872d7 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -4,6 +4,7 @@ // Licensed under the MIT License. #include "command_args_parser.h" +#include "utils.h" #include #include @@ -11,14 +12,6 @@ #include #include -// Windows Specific -#ifdef _WIN32 -#include "getopt.h" -#include "windows.h" -#else -#include -#endif - #include #include #include @@ -26,407 +19,483 @@ #include "test_configuration.h" #include "strings_helper.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" +#include "absl/flags/usage_config.h" +#include "absl/flags/reflection.h" + +static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTestConfig() { + static onnxruntime::perftest::PerformanceTestConfig default_config{}; + return default_config; +} + +ABSL_FLAG(std::string, f, "", + "Specifies a free dimension by name to override to a specific value for performance optimization.\n" + "[Usage]: -f \"dimension_name1:override_value1\" -f \"dimension_name2:override_value2\" ... or" + " -f \"dimension_name1:override_value1 dimension_name2:override_value2 ... \". Override value must > 0."); +ABSL_FLAG(std::string, F, "", + "Specifies a free dimension by denotation to override to a specific value for performance optimization.\n" + "[Usage]: -f \"dimension_denotation1:override_value1\" -f \"dimension_denotation2:override_value2\" ... or" + " -f \"dimension_denotation1:override_value1 dimension_denotation2 : override_value2... \". Override value must > 0."); +ABSL_FLAG(std::string, m, "duration", "Specifies the test mode. Value could be 'duration' or 'times'."); +ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'."); +ABSL_FLAG(size_t, r, DefaultPerformanceTestConfig().run_config.repeated_times, "Specifies the repeated times if running in 'times' test mode."); +ABSL_FLAG(size_t, t, DefaultPerformanceTestConfig().run_config.duration_in_seconds, "Specifies the seconds to run for 'duration' mode."); +ABSL_FLAG(std::string, p, "", "Specifies the profile name to enable profiling and dump the profile data to the file."); +ABSL_FLAG(int, x, DefaultPerformanceTestConfig().run_config.intra_op_num_threads, "Sets the number of threads used to parallelize the execution within nodes, A value of 0 means ORT will pick a default. Must >=0."); +ABSL_FLAG(int, y, DefaultPerformanceTestConfig().run_config.inter_op_num_threads, "Sets the number of threads used to parallelize the execution of the graph (across nodes), A value of 0 means ORT will pick a default. Must >=0."); +ABSL_FLAG(size_t, c, DefaultPerformanceTestConfig().run_config.concurrent_session_runs, "Specifies the (max) number of runs to invoke simultaneously."); +ABSL_FLAG(int, d, DefaultPerformanceTestConfig().run_config.cudnn_conv_algo, "Specifies CUDNN convolution algorithms: 0(benchmark), 1(heuristic), 2(default)."); +ABSL_FLAG(int, o, DefaultPerformanceTestConfig().run_config.optimization_level, "Specifies graph optimization level. Default is 99 (all). Valid values are 0 (disable), 1 (basic), 2 (extended), 3 (layout), 99 (all)."); +ABSL_FLAG(std::string, u, "", "Specifies the optimized model path for saving."); +ABSL_FLAG(std::string, i, "", + "Specifies EP specific runtime options as key-value pairs.\n Different runtime options available are: \n" + " [Usage]: -e -i '| |'\n" + "\n" + " [ACL only] [enable_fast_math]: Options: 'true', 'false', default: 'false', \n" + "\n" + " [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" + " [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" + " [DML only] [disable_metacommands]: Options: 'true', 'false', \n" + " [DML only] [enable_graph_capture]: Options: 'true', 'false', \n" + " [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" + "\n" + " [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" + " [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" + " [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" + " [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" + " [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" + " [OpenVINO only] [reshape_input]: Sets model input shapes with support for bounded dynamic dimensions using 'min..max' syntax (e.g., [1..10,3,224,224]) \n" + " [OpenVINO only] [layout]: Specifies the layout for inputs/outputs to interpret tensor dimensions correctly. \n" + " [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true reshape_input|[1,3,60,60..100] layout|[NCHW] cache_dir|\"\"\"\n" + "\n" + " [QNN only] [backend_type]: QNN backend type. E.g., 'cpu', 'htp'. Mutually exclusive with 'backend_path'.\n" + " [QNN only] [backend_path]: QNN backend path. E.g., '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. Mutually exclusive with 'backend_type'.\n" + " [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" + " [QNN only] [profiling_file_path] : QNN profiling file path if ETW not enabled.\n" + " [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" + " [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + " [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" + " 'high_power_saver', 'low_balanced', 'extreme_power_saver', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + " [QNN only] [op_packages]: QNN UDO package, allowed format: \n" + " op_packages|::[:],::[:]. \n" + " [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" + " [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" + " [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" + " '0', '1', '2', '3', default is '0'.\n" + " [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + " [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + " Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + " [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" + " [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" + " Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" + " [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + " Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + " [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n" + " [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs. Requires libcdsprpc.so/dll to be available.\n" + " Defaults to '0' (disabled).\n" + " [Example] [For QNN EP] -e qnn -i \"backend_type|cpu\" \n" + "\n" + " [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" + " [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" + " [TensorRT only] [trt_max_workspace_size]: Set TensorRT maximum workspace size in byte.\n" + " [TensorRT only] [trt_fp16_enable]: Enable TensorRT FP16 precision.\n" + " [TensorRT only] [trt_int8_enable]: Enable TensorRT INT8 precision.\n" + " [TensorRT only] [trt_int8_calibration_table_name]: Specify INT8 calibration table name.\n" + " [TensorRT only] [trt_int8_use_native_calibration_table]: Use Native TensorRT calibration table.\n" + " [TensorRT only] [trt_dla_enable]: Enable DLA in Jetson device.\n" + " [TensorRT only] [trt_dla_core]: DLA core number.\n" + " [TensorRT only] [trt_dump_subgraphs]: Dump TRT subgraph to onnx model.\n" + " [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n" + " [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n" + " [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n" + " [TensorRT only] [trt_engine_hw_compatible]: Enable hardware compatibility. Engines ending with '_sm80+' can be re-used across all Ampere+ GPU (a hardware-compatible engine may have lower throughput and/or higher latency than its non-hardware-compatible counterpart).\n" + " [TensorRT only] [trt_weight_stripped_engine_enable]: Enable weight-stripped engine build.\n" + " [TensorRT only] [trt_onnx_model_folder_path]: Folder path for the ONNX model with weights.\n" + " [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" + " [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" + " [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n" + " [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" + "\n" + " [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n" + " [NNAPI only] [NNAPI_FLAG_USE_NCHW]: Use the NCHW layout in NNAPI EP.\n" + " [NNAPI only] [NNAPI_FLAG_CPU_DISABLED]: Prevent NNAPI from using CPU devices.\n" + " [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" + " [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" + "\n" + " [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network. Default is NeuralNetwork.\n" + " [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n" + " [CoreML only] [AllowStaticInputShapes]:[0 1].\n" + " [CoreML only] [EnableOnSubgraphs]:[0 1].\n" + " [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n" + " [CoreML only] [ProfileComputePlan]:[0 1].\n" + " [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n" + " [CoreML only] [ModelCacheDirectory]:[path../a/b/c].\n" + " [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n" + "\n" + " [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" + " [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" + " [SNPE only] [buffer_type]: options: 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. default: ITENSOR'. \n" + " [SNPE only] [enable_init_cache]: enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. \n" + " [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n"); +ABSL_FLAG(int, S, DefaultPerformanceTestConfig().run_config.random_seed_for_input_data, "Given random seed, to produce the same input data. This defaults to -1(no initialize)."); +ABSL_FLAG(std::string, T, "", "Specifies intra op thread affinity string."); +ABSL_FLAG(std::string, C, "", + "Specifies session configuration entries as key-value pairs:\n -C \"| |\" \n" + "Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" + "[Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n"); +ABSL_FLAG(std::string, R, "", "Allows user to register custom op by .so or .dll file."); +ABSL_FLAG(bool, A, DefaultPerformanceTestConfig().run_config.enable_cpu_mem_arena, "Disables memory arena."); +ABSL_FLAG(bool, M, DefaultPerformanceTestConfig().run_config.enable_memory_pattern, "Disables memory pattern."); +ABSL_FLAG(bool, s, DefaultPerformanceTestConfig().run_config.f_dump_statistics, "Shows statistics result, like P75, P90. If no result_file provided this defaults to on."); +ABSL_FLAG(bool, v, DefaultPerformanceTestConfig().run_config.f_verbose, "Shows verbose information."); +ABSL_FLAG(bool, I, DefaultPerformanceTestConfig().run_config.generate_model_input_binding, "Generates tensor input binding. Free dimensions are treated as 1 unless overridden using -f."); +ABSL_FLAG(bool, P, false, "Uses parallel executor instead of sequential executor."); +ABSL_FLAG(bool, q, DefaultPerformanceTestConfig().run_config.do_cuda_copy_in_separate_stream, "[CUDA only] Uses separate stream for copy."); +ABSL_FLAG(bool, z, DefaultPerformanceTestConfig().run_config.set_denormal_as_zero, "Sets denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals."); +ABSL_FLAG(bool, D, DefaultPerformanceTestConfig().run_config.disable_spinning, "Disables spinning entirely for thread owned by onnxruntime intra-op thread pool."); +ABSL_FLAG(bool, Z, DefaultPerformanceTestConfig().run_config.disable_spinning_between_run, "Disallows thread from spinning during runs to reduce cpu usage."); +ABSL_FLAG(bool, n, DefaultPerformanceTestConfig().run_config.exit_after_session_creation, "Allows user to measure session creation time to measure impact of enabling any initialization optimizations."); +ABSL_FLAG(bool, l, DefaultPerformanceTestConfig().model_info.load_via_path, "Provides file as binary in memory by using fopen before session creation."); +ABSL_FLAG(bool, g, DefaultPerformanceTestConfig().run_config.enable_cuda_io_binding, "[TensorRT RTX | TensorRT | CUDA] Enables tensor input and output bindings on CUDA before session run."); +ABSL_FLAG(bool, X, DefaultPerformanceTestConfig().run_config.use_extensions, "Registers custom ops from onnxruntime-extensions."); +ABSL_FLAG(std::string, plugin_ep_libs, "", + "Specifies a list of plugin execution provider (EP) registration names and their corresponding shared libraries to register.\n" + "[Usage]: --plugin_ep_libs \"plugin_ep_name_1|plugin_ep_1.dll plugin_ep_name_2|plugin_ep_2.dll ... \""); +ABSL_FLAG(std::string, plugin_eps, "", "Specifies a semicolon-separated list of plugin execution providers (EPs) to use."); +ABSL_FLAG(std::string, plugin_ep_options, "", + "Specifies provider options for each EP listed in --plugin_eps. Options (key-value pairs) for each EP are separated by space and EPs are separated by semicolons.\n" + "[Usage]: --plugin_ep_options \"ep_1_option_1_key|ep_1_option_1_value ...;ep_2_option_1_key|ep_2_option_1_value ...;... \" or \n" + "--plugin_ep_options \";ep_2_option_1_key|ep_2_option_1_value ...;... \" or \n" + "--plugin_ep_options \"ep_1_option_1_key|ep_1_option_1_value ...;;ep_3_option_1_key|ep_3_option_1_value ...;... \""); +ABSL_FLAG(bool, list_ep_devices, false, "Prints all available device indices and their properties (including metadata). This option makes the program exit early without performing inference.\n"); +ABSL_FLAG(std::string, select_ep_devices, "", "Specifies a semicolon-separated list of device indices to add to the session and run with."); +ABSL_FLAG(bool, h, false, "Print program usage."); + namespace onnxruntime { namespace perftest { -/*static*/ void CommandLineParser::ShowUsage() { - printf( - "perf_test [options...] model_path [result_file]\n" - "Options:\n" - "\t-m [test_mode]: Specifies the test mode. Value could be 'duration' or 'times'.\n" - "\t\tProvide 'duration' to run the test for a fix duration, and 'times' to repeated for a certain times. \n" - "\t-M: Disable memory pattern.\n" - "\t-A: Disable memory arena\n" - "\t-I: Generate tensor input binding. Free dimensions are treated as 1 unless overridden using -f.\n" - "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" - "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai|webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " - "'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " - "Default:'cpu'.\n" - "\t-b [tf|ort]: backend to use. Default:ort\n" - "\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n" - "\t-t [seconds_to_run]: Specifies the seconds to run for 'duration' mode. Default:600.\n" - "\t-p [profile_file]: Specifies the profile name to enable profiling and dump the profile data to the file.\n" - "\t-s: Show statistics result, like P75, P90. If no result_file provided this defaults to on.\n" - "\t-S: Given random seed, to produce the same input data. This defaults to -1(no initialize).\n" - "\t-v: Show verbose information.\n" - "\t-x [intra_op_num_threads]: Sets the number of threads used to parallelize the execution within nodes, A value of 0 means ORT will pick a default. Must >=0.\n" - "\t-y [inter_op_num_threads]: Sets the number of threads used to parallelize the execution of the graph (across nodes), A value of 0 means ORT will pick a default. Must >=0.\n" - "\t-f [free_dimension_override]: Specifies a free dimension by name to override to a specific value for performance optimization. " - "Syntax is [dimension_name:override_value]. override_value must > 0\n" - "\t-F [free_dimension_override]: Specifies a free dimension by denotation to override to a specific value for performance optimization. " - "Syntax is [dimension_denotation:override_value]. override_value must > 0\n" - "\t-P: Use parallel executor instead of sequential executor.\n" - "\t-o [optimization level]: Default is 99 (all). Valid values are 0 (disable), 1 (basic), 2 (extended), 3 (layout), 99 (all).\n" - "\t\tPlease see onnxruntime_c_api.h (enum GraphOptimizationLevel) for the full list of all optimization levels.\n" - "\t-u [optimized_model_path]: Specify the optimized model path for saving.\n" - "\t-d [CUDA only][cudnn_conv_algorithm]: Specify CUDNN convolution algorithms: 0(benchmark), 1(heuristic), 2(default). \n" - "\t-q [CUDA only] use separate stream for copy. \n" - "\t-g [TensorRT RTX | TensorRT | CUDA] Enable tensor input and output bindings on CUDA before session run \n" - "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n" - "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" - "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" - "\t [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n" - "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" - "\t [Usage]: -e -i '| |'\n" - "\n" - "\t [ACL only] [enable_fast_math]: Options: 'true', 'false', default: 'false', \n" - "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" - "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" - "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" - "\t [DML only] [enable_graph_capture]: Options: 'true', 'false', \n" - "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" - "\n" - "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" - "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" - "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" - "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" - "\n" - "\t [QNN only] [backend_type]: QNN backend type. E.g., 'cpu', 'htp'. Mutually exclusive with 'backend_path'.\n" - "\t [QNN only] [backend_path]: QNN backend path. E.g., '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. Mutually exclusive with 'backend_type'.\n" - "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" - "\t [QNN only] [profiling_file_path] : QNN profiling file path if ETW not enabled.\n" - "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" - "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" - "\t 'high_power_saver', 'low_balanced', 'extreme_power_saver', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" - "\t [QNN only] [op_packages]: QNN UDO package, allowed format: \n" - "\t op_packages|::[:],::[:]. \n" - "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" - "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" - "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" - "\t '0', '1', '2', '3', default is '0'.\n" - "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" - "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" - "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" - "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n" - "\t [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs. Requires libcdsprpc.so/dll to be available.\n" - "\t Defaults to '0' (disabled).\n" - "\t [Example] [For QNN EP] -e qnn -i \"backend_type|cpu\" \n" - "\n" - "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" - "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" - "\t [TensorRT only] [trt_max_workspace_size]: Set TensorRT maximum workspace size in byte.\n" - "\t [TensorRT only] [trt_fp16_enable]: Enable TensorRT FP16 precision.\n" - "\t [TensorRT only] [trt_int8_enable]: Enable TensorRT INT8 precision.\n" - "\t [TensorRT only] [trt_int8_calibration_table_name]: Specify INT8 calibration table name.\n" - "\t [TensorRT only] [trt_int8_use_native_calibration_table]: Use Native TensorRT calibration table.\n" - "\t [TensorRT only] [trt_dla_enable]: Enable DLA in Jetson device.\n" - "\t [TensorRT only] [trt_dla_core]: DLA core number.\n" - "\t [TensorRT only] [trt_dump_subgraphs]: Dump TRT subgraph to onnx model.\n" - "\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n" - "\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n" - "\t [TensorRT only] [trt_engine_cache_prefix]: Customize engine cache prefix when trt_engine_cache_enable is true.\n" - "\t [TensorRT only] [trt_engine_hw_compatible]: Enable hardware compatibility. Engines ending with '_sm80+' can be re-used across all Ampere+ GPU (a hardware-compatible engine may have lower throughput and/or higher latency than its non-hardware-compatible counterpart).\n" - "\t [TensorRT only] [trt_weight_stripped_engine_enable]: Enable weight-stripped engine build.\n" - "\t [TensorRT only] [trt_onnx_model_folder_path]: Folder path for the ONNX model with weights.\n" - "\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" - "\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" - "\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n" - "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" - "\n" - "\t [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n" - "\t [NNAPI only] [NNAPI_FLAG_USE_NCHW]: Use the NCHW layout in NNAPI EP.\n" - "\t [NNAPI only] [NNAPI_FLAG_CPU_DISABLED]: Prevent NNAPI from using CPU devices.\n" - "\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" - "\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" - "\n" - "\t [CoreML only] [ModelFormat]:[MLProgram, NeuralNetwork] Create an ML Program model or Neural Network. Default is NeuralNetwork.\n" - "\t [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n" - "\t [CoreML only] [AllowStaticInputShapes]:[0 1].\n" - "\t [CoreML only] [EnableOnSubgraphs]:[0 1].\n" - "\t [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n" - "\t [CoreML only] [ProfileComputePlan]:[0 1].\n" - "\t [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n" - "\t [CoreML only] [ModelCacheDirectory]:[path../a/b/c].\n" - "\t [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n" - "\n" - "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" - "\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" - "\t [SNPE only] [buffer_type]: options: 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. default: ITENSOR'. \n" - "\t [SNPE only] [enable_init_cache]: enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. \n" - "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" - "\n" - "\t-T [Set intra op thread affinities]: Specify intra op thread affinity string\n" - "\t [Example]: -T 1,2;3,4;5,6 or -T 1-2;3-4;5-6 \n" - "\t\t Use semicolon to separate configuration between threads.\n" - "\t\t E.g. 1,2;3,4;5,6 specifies affinities for three threads, the first thread will be attached to the first and second logical processor.\n" - "\t\t The number of affinities must be equal to intra_op_num_threads - 1\n\n" - "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" - "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" - "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" - "\t-l Provide file as binary in memory by using fopen before session creation.\n" - "\t-R [Register custom op]: allow user to register custom op by .so or .dll file.\n" - "\t-X [Enable onnxruntime-extensions custom ops]: Registers custom ops from onnxruntime-extensions. " - "onnxruntime-extensions must have been built in to onnxruntime. This can be done with the build.py " - "'--use_extensions' option.\n" - "\t-h: help\n"); +std::string CustomUsageMessage() { + std::ostringstream oss; + oss << "onnxruntime_perf_test [options...] model_path [result_file]\n\n"; + oss << "Note: Options may be specified with either a single dash(-option) or a double dash(--option). Both forms are accepted and treated identically.\n\n"; + oss << "Options:"; + + return oss.str(); } -#ifdef _WIN32 -static const ORTCHAR_T* overrideDelimiter = L":"; -#else -static const ORTCHAR_T* overrideDelimiter = ":"; -#endif -static bool ParseDimensionOverride(std::basic_string& dim_identifier, int64_t& override_val) { - std::basic_string free_dim_str(optarg); - size_t delimiter_location = free_dim_str.find(overrideDelimiter); - if (delimiter_location >= free_dim_str.size() - 1) { - return false; + +bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { + // Following callback is to make sure all the ABSL flags defined above will be showed up when running with "--help". + // Note: By default abseil only wants flags in binary's main. It expects the main routine to reside in .cc or -main.cc or + // _main.cc, where the is the name of the binary (without .exe on Windows). See usage_config.cc in abseil for more details. + absl::FlagsUsageConfig config; + config.contains_help_flags = [](absl::string_view filename) { + return std::filesystem::path(filename).filename() == std::filesystem::path(__FILE__).filename(); + }; + + config.normalize_filename = [](absl::string_view f) { + return std::string(f); + }; + absl::SetFlagsUsageConfig(config); + absl::SetProgramUsageMessage(CustomUsageMessage()); + + auto utf8_argv_strings = utils::ConvertArgvToUtf8Strings(argc, argv); + auto utf8_argv = utils::CStringsFromStrings(utf8_argv_strings); + auto positional = absl::ParseCommandLine(static_cast(utf8_argv.size()), utf8_argv.data()); + + // -f + { + const auto& dim_override_str = absl::GetFlag(FLAGS_f); + if (!dim_override_str.empty()) { + // Abseil doesn't support the same option being provided multiple times - only the last occurrence is applied. + // To preserve the previous usage of '-f', where users may specify it multiple times to override different dimension names, + // we need to manually parse argv. + std::string option = "f"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_name_overrides)) { + return false; + } + } } - dim_identifier = free_dim_str.substr(0, delimiter_location); - std::basic_string override_val_str = free_dim_str.substr(delimiter_location + 1, std::wstring::npos); - ORT_TRY { - override_val = std::stoll(override_val_str.c_str()); - if (override_val <= 0) { - return false; + + // -F + { + const auto& dim_override_str = absl::GetFlag(FLAGS_F); + if (!dim_override_str.empty()) { + // Same reason as '-f' above to manully parse argv. + std::string option = "F"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_denotation_overrides)) { + return false; + } } } - ORT_CATCH(...) { - return false; + + // -m + { + const auto& test_mode_str = absl::GetFlag(FLAGS_m); + if (!test_mode_str.empty()) { + if (test_mode_str == "duration") { + test_config.run_config.test_mode = TestMode::kFixDurationMode; + } else if (test_mode_str == "times") { + test_config.run_config.test_mode = TestMode::KFixRepeatedTimesMode; + } else { + return false; + } + } } - return true; -} -/*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { - int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlgR:X"))) != -1) { - switch (ch) { - case 'f': { - std::basic_string dim_name; - int64_t override_val; - if (!ParseDimensionOverride(dim_name, override_val)) { - return false; - } - test_config.run_config.free_dim_name_overrides[dim_name] = override_val; - break; + // -p + { + const auto& profile_file = absl::GetFlag(FLAGS_p); + if (!profile_file.empty()) test_config.run_config.profile_file = ToPathString(profile_file); + } + + // -M + test_config.run_config.enable_memory_pattern = absl::GetFlag(FLAGS_M); + + // -A + test_config.run_config.enable_cpu_mem_arena = absl::GetFlag(FLAGS_A); + + // -e + { + auto const& ep = absl::GetFlag(FLAGS_e); + if (!ep.empty()) { + if (ep == "cpu") { + test_config.machine_config.provider_type_name = onnxruntime::kCpuExecutionProvider; + } else if (ep == "cuda") { + test_config.machine_config.provider_type_name = onnxruntime::kCudaExecutionProvider; + } else if (ep == "dnnl") { + test_config.machine_config.provider_type_name = onnxruntime::kDnnlExecutionProvider; + } else if (ep == "openvino") { + test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; + } else if (ep == "tensorrt") { + test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; + } else if (ep == "qnn") { + test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; + } else if (ep == "snpe") { + test_config.machine_config.provider_type_name = onnxruntime::kSnpeExecutionProvider; + } else if (ep == "nnapi") { + test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider; + } else if (ep == "vsinpu") { + test_config.machine_config.provider_type_name = onnxruntime::kVSINPUExecutionProvider; + } else if (ep == "coreml") { + test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider; + } else if (ep == "dml") { + test_config.machine_config.provider_type_name = onnxruntime::kDmlExecutionProvider; + } else if (ep == "acl") { + test_config.machine_config.provider_type_name = onnxruntime::kAclExecutionProvider; + } else if (ep == "armnn") { + test_config.machine_config.provider_type_name = onnxruntime::kArmNNExecutionProvider; + } else if (ep == "rocm") { + test_config.machine_config.provider_type_name = onnxruntime::kRocmExecutionProvider; + } else if (ep == "migraphx") { + test_config.machine_config.provider_type_name = onnxruntime::kMIGraphXExecutionProvider; + } else if (ep == "xnnpack") { + test_config.machine_config.provider_type_name = onnxruntime::kXnnpackExecutionProvider; + } else if (ep == "vitisai") { + test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else if (ep == "webgpu") { + test_config.machine_config.provider_type_name = onnxruntime::kWebGpuExecutionProvider; + } else if (ep == "nvtensorrtrtx") { + test_config.machine_config.provider_type_name = onnxruntime::kNvTensorRTRTXExecutionProvider; + } else { + return false; } - case 'F': { - std::basic_string dim_denotation; - int64_t override_val; - if (!ParseDimensionOverride(dim_denotation, override_val)) { - return false; - } - test_config.run_config.free_dim_denotation_overrides[dim_denotation] = override_val; - break; + } + } + + // Helper function to check if the option is explicitly specified. + // Abseil Flags does not provide this capability by default. + // It cannot distinguish between cases where: + // - The user typed `-r 1000` (explicitly passing the default value), and + // - The user omitted `-r` entirely. + // To determine this accurately, we must inspect argv directly. + auto is_option_specified = [&](std::string option) { + for (int i = 1; i < argc; ++i) { + auto utf8_arg = ToUTF8String(argv[i]); + if (utf8_arg == ("-" + option) || utf8_arg == ("--" + option)) { + return true; } - case 'm': - if (!CompareCString(optarg, ORT_TSTR("duration"))) { - test_config.run_config.test_mode = TestMode::kFixDurationMode; - } else if (!CompareCString(optarg, ORT_TSTR("times"))) { - test_config.run_config.test_mode = TestMode::KFixRepeatedTimesMode; - } else { - return false; - } - break; - case 'p': - test_config.run_config.profile_file = optarg; - break; - case 'M': - test_config.run_config.enable_memory_pattern = false; - break; - case 'A': - test_config.run_config.enable_cpu_mem_arena = false; - break; - case 'e': - if (!CompareCString(optarg, ORT_TSTR("cpu"))) { - test_config.machine_config.provider_type_name = onnxruntime::kCpuExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("cuda"))) { - test_config.machine_config.provider_type_name = onnxruntime::kCudaExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("dnnl"))) { - test_config.machine_config.provider_type_name = onnxruntime::kDnnlExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { - test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { - test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("qnn"))) { - test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("snpe"))) { - test_config.machine_config.provider_type_name = onnxruntime::kSnpeExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("nnapi"))) { - test_config.machine_config.provider_type_name = onnxruntime::kNnapiExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("vsinpu"))) { - test_config.machine_config.provider_type_name = onnxruntime::kVSINPUExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("coreml"))) { - test_config.machine_config.provider_type_name = onnxruntime::kCoreMLExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("dml"))) { - test_config.machine_config.provider_type_name = onnxruntime::kDmlExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("acl"))) { - test_config.machine_config.provider_type_name = onnxruntime::kAclExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("armnn"))) { - test_config.machine_config.provider_type_name = onnxruntime::kArmNNExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("rocm"))) { - test_config.machine_config.provider_type_name = onnxruntime::kRocmExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { - test_config.machine_config.provider_type_name = onnxruntime::kMIGraphXExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("xnnpack"))) { - test_config.machine_config.provider_type_name = onnxruntime::kXnnpackExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { - test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { - test_config.machine_config.provider_type_name = onnxruntime::kWebGpuExecutionProvider; - } else if (!CompareCString(optarg, ORT_TSTR("nvtensorrtrtx"))) { - test_config.machine_config.provider_type_name = onnxruntime::kNvTensorRTRTXExecutionProvider; - } else { - return false; - } - break; - case 'r': - test_config.run_config.repeated_times = static_cast(OrtStrtol(optarg, nullptr)); - if (test_config.run_config.repeated_times <= 0) { - return false; - } - test_config.run_config.test_mode = TestMode::KFixRepeatedTimesMode; - break; - case 't': - test_config.run_config.duration_in_seconds = static_cast(OrtStrtol(optarg, nullptr)); - if (test_config.run_config.repeated_times <= 0) { - return false; - } - test_config.run_config.test_mode = TestMode::kFixDurationMode; - break; - case 's': - test_config.run_config.f_dump_statistics = true; - break; - case 'S': - test_config.run_config.random_seed_for_input_data = static_cast( - OrtStrtol(optarg, nullptr)); - break; - case 'v': - test_config.run_config.f_verbose = true; - break; - case 'x': - test_config.run_config.intra_op_num_threads = static_cast(OrtStrtol(optarg, nullptr)); - if (test_config.run_config.intra_op_num_threads < 0) { - return false; - } - break; - case 'y': - test_config.run_config.inter_op_num_threads = static_cast(OrtStrtol(optarg, nullptr)); - if (test_config.run_config.inter_op_num_threads < 0) { - return false; - } - break; - case 'P': - test_config.run_config.execution_mode = ExecutionMode::ORT_PARALLEL; - break; - case 'c': - test_config.run_config.concurrent_session_runs = - static_cast(OrtStrtol(optarg, nullptr)); - if (test_config.run_config.concurrent_session_runs <= 0) { - return false; - } - break; - case 'o': { - int tmp = static_cast(OrtStrtol(optarg, nullptr)); - switch (tmp) { - case ORT_DISABLE_ALL: - test_config.run_config.optimization_level = ORT_DISABLE_ALL; - break; - case ORT_ENABLE_BASIC: - test_config.run_config.optimization_level = ORT_ENABLE_BASIC; - break; - case ORT_ENABLE_EXTENDED: - test_config.run_config.optimization_level = ORT_ENABLE_EXTENDED; - break; - case ORT_ENABLE_LAYOUT: - test_config.run_config.optimization_level = ORT_ENABLE_LAYOUT; - break; - case ORT_ENABLE_ALL: + } + return false; + }; + + // -r + if (is_option_specified("r")) { + if (absl::GetFlag(FLAGS_r) == static_cast(0)) return false; + test_config.run_config.repeated_times = absl::GetFlag(FLAGS_r); + test_config.run_config.test_mode = TestMode::KFixRepeatedTimesMode; + } + + // -t + if (is_option_specified("t")) { + if (absl::GetFlag(FLAGS_t) <= static_cast(0)) return false; + test_config.run_config.duration_in_seconds = absl::GetFlag(FLAGS_t); + test_config.run_config.test_mode = TestMode::kFixDurationMode; + } + + // -s + test_config.run_config.f_dump_statistics = absl::GetFlag(FLAGS_s); + + // -S + test_config.run_config.random_seed_for_input_data = absl::GetFlag(FLAGS_S); + + // -v + test_config.run_config.f_verbose = absl::GetFlag(FLAGS_v); + + // -x + if (absl::GetFlag(FLAGS_x) < 0) return false; + test_config.run_config.intra_op_num_threads = absl::GetFlag(FLAGS_x); + + // -y + if (absl::GetFlag(FLAGS_y) < 0) return false; + test_config.run_config.inter_op_num_threads = absl::GetFlag(FLAGS_y); + + // -P + if (absl::GetFlag(FLAGS_P)) test_config.run_config.execution_mode = ExecutionMode::ORT_PARALLEL; + + // -c + if (absl::GetFlag(FLAGS_c) <= static_cast(0)) return false; + test_config.run_config.concurrent_session_runs = absl::GetFlag(FLAGS_c); + + // -o + { + const auto optimization_level = absl::GetFlag(FLAGS_o); + if (optimization_level != test_config.run_config.optimization_level) { + switch (optimization_level) { + case ORT_DISABLE_ALL: + test_config.run_config.optimization_level = ORT_DISABLE_ALL; + break; + case ORT_ENABLE_BASIC: + test_config.run_config.optimization_level = ORT_ENABLE_BASIC; + break; + case ORT_ENABLE_EXTENDED: + test_config.run_config.optimization_level = ORT_ENABLE_EXTENDED; + break; + case ORT_ENABLE_LAYOUT: + test_config.run_config.optimization_level = ORT_ENABLE_LAYOUT; + break; + case ORT_ENABLE_ALL: + test_config.run_config.optimization_level = ORT_ENABLE_ALL; + break; + default: { + if (optimization_level > ORT_ENABLE_ALL) { // relax constraint test_config.run_config.optimization_level = ORT_ENABLE_ALL; - break; - default: { - if (tmp > ORT_ENABLE_ALL) { // relax constraint - test_config.run_config.optimization_level = ORT_ENABLE_ALL; - } else { - return false; - } + } else { + return false; } } - break; } - case 'u': - test_config.run_config.optimized_model_path = optarg; - break; - case 'I': - test_config.run_config.generate_model_input_binding = true; - break; - case 'd': - test_config.run_config.cudnn_conv_algo = static_cast(OrtStrtol(optarg, nullptr)); - break; - case 'q': - test_config.run_config.do_cuda_copy_in_separate_stream = true; - break; - case 'z': - test_config.run_config.set_denormal_as_zero = true; - break; - case 'i': - test_config.run_config.ep_runtime_config_string = optarg; - break; - case 'T': - test_config.run_config.intra_op_thread_affinities = ToUTF8String(optarg); - break; - case 'C': { - ORT_TRY { - ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error parsing session configuration entries: %s\n", ex.what()); - }); - return false; - } - break; + } + } + + // -u + { + const auto& optimized_model_path = absl::GetFlag(FLAGS_u); + if (!optimized_model_path.empty()) test_config.run_config.optimized_model_path = ToPathString(optimized_model_path); + } + + // -I + test_config.run_config.generate_model_input_binding = absl::GetFlag(FLAGS_I); + + // -d + if (absl::GetFlag(FLAGS_d) < 0) return false; + test_config.run_config.cudnn_conv_algo = absl::GetFlag(FLAGS_d); + + // -q + test_config.run_config.do_cuda_copy_in_separate_stream = absl::GetFlag(FLAGS_q); + + // -z + test_config.run_config.set_denormal_as_zero = absl::GetFlag(FLAGS_z); + + // -i + { + const auto& ep_options = absl::GetFlag(FLAGS_i); + if (!ep_options.empty()) test_config.run_config.ep_runtime_config_string = ToPathString(ep_options); + } + + // -T + if (!absl::GetFlag(FLAGS_T).empty()) test_config.run_config.intra_op_thread_affinities = absl::GetFlag(FLAGS_T); + + // -C + { + const auto& session_configs = absl::GetFlag(FLAGS_C); + if (!session_configs.empty()) { + ORT_TRY { + ParseSessionConfigs(session_configs, test_config.run_config.session_config_entries); } - case 'D': - test_config.run_config.disable_spinning = true; - break; - case 'Z': - test_config.run_config.disable_spinning_between_run = true; - break; - case 'n': - test_config.run_config.exit_after_session_creation = true; - break; - case 'l': - test_config.model_info.load_via_path = true; - break; - case 'R': - test_config.run_config.register_custom_op_path = optarg; - break; - case 'g': - test_config.run_config.enable_cuda_io_binding = true; - break; - case 'X': - test_config.run_config.use_extensions = true; - break; - case '?': - case 'h': - default: + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "Error parsing session configuration entries: %s\n", ex.what()); + }); return false; + } } } - // parse model_path and result_file_path - argc -= optind; - argv += optind; - - switch (argc) { - case 2: - test_config.model_info.result_file_path = argv[1]; - break; - case 1: - test_config.run_config.f_dump_statistics = true; - break; - default: - return false; + // -D + test_config.run_config.disable_spinning = absl::GetFlag(FLAGS_D); + + // -Z + test_config.run_config.disable_spinning_between_run = absl::GetFlag(FLAGS_Z); + + // -n + test_config.run_config.exit_after_session_creation = absl::GetFlag(FLAGS_n); + + // -l + test_config.model_info.load_via_path = absl::GetFlag(FLAGS_l); + + // -R + { + const auto& register_custom_op_path = absl::GetFlag(FLAGS_R); + if (!register_custom_op_path.empty()) test_config.run_config.register_custom_op_path = ToPathString(register_custom_op_path); } - test_config.model_info.model_file_path = argv[0]; + // -g + test_config.run_config.enable_cuda_io_binding = absl::GetFlag(FLAGS_g); + + // -X + test_config.run_config.use_extensions = absl::GetFlag(FLAGS_X); + + // --plugin_ep_libs + { + const auto& plugin_ep_names_and_libs = absl::GetFlag(FLAGS_plugin_ep_libs); + if (!plugin_ep_names_and_libs.empty()) test_config.plugin_ep_names_and_libs = ToPathString(plugin_ep_names_and_libs); + } + + // --plugin_eps + { + const auto& plugin_eps = absl::GetFlag(FLAGS_plugin_eps); + if (!plugin_eps.empty()) ParseEpList(plugin_eps, test_config.machine_config.plugin_provider_type_list); + } + + // --plugin_ep_options + { + const auto& plugin_ep_options = absl::GetFlag(FLAGS_plugin_ep_options); + if (!plugin_ep_options.empty()) test_config.run_config.ep_runtime_config_string = ToPathString(plugin_ep_options); + } + + // --list_ep_devices + if (absl::GetFlag(FLAGS_list_ep_devices)) { + test_config.list_available_ep_devices = true; + return true; + } + + // --select_ep_devices + { + const auto& select_ep_devices = absl::GetFlag(FLAGS_select_ep_devices); + if (!select_ep_devices.empty()) test_config.selected_ep_device_indices = select_ep_devices; + } + + if (positional.size() == 2) { + test_config.model_info.model_file_path = ToPathString(positional[1]); + test_config.run_config.f_dump_statistics = true; + } else if (positional.size() == 3) { + test_config.model_info.model_file_path = ToPathString(positional[1]); + test_config.model_info.result_file_path = ToPathString(positional[2]); + } else { + return false; + } return true; } diff --git a/onnxruntime/test/perftest/command_args_parser.h b/onnxruntime/test/perftest/command_args_parser.h index 86c81072233c0..5a94f99874797 100644 --- a/onnxruntime/test/perftest/command_args_parser.h +++ b/onnxruntime/test/perftest/command_args_parser.h @@ -11,7 +11,6 @@ struct PerformanceTestConfig; class CommandLineParser { public: - static void ShowUsage(); static bool ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]); }; diff --git a/onnxruntime/test/perftest/common_utils.cc b/onnxruntime/test/perftest/common_utils.cc new file mode 100644 index 0000000000000..5cc6c240e25f0 --- /dev/null +++ b/onnxruntime/test/perftest/common_utils.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/perftest/utils.h" +#include "test/perftest/strings_helper.h" +#include + +#include + +#include + +namespace onnxruntime { +namespace perftest { +namespace utils { + +void ListEpDevices(const Ort::Env& env) { + std::vector ep_devices = env.GetEpDevices(); + + for (size_t i = 0; i < ep_devices.size(); ++i) { + auto device = ep_devices[i]; + std::string device_info_msg = "===== device id " + std::to_string(i) + " ======\n"; + device_info_msg += "name: " + std::string(device.EpName()) + "\n"; + device_info_msg += "vendor: " + std::string(device.EpVendor()) + "\n"; + + auto metadata = device.EpMetadata(); + std::unordered_map metadata_entries = metadata.GetKeyValuePairs(); + if (!metadata_entries.empty()) { + device_info_msg += "metadata:\n"; + } + + for (auto& entry : metadata_entries) { + device_info_msg += " " + entry.first + ": " + entry.second + "\n"; + } + device_info_msg += "\n"; + fprintf(stdout, "%s", device_info_msg.c_str()); + } +} + +void RegisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& test_config) { + if (!test_config.plugin_ep_names_and_libs.empty()) { + std::unordered_map ep_names_to_libs; + ParseSessionConfigs(ToUTF8String(test_config.plugin_ep_names_and_libs), ep_names_to_libs); + if (ep_names_to_libs.size() > 0) { + for (auto& pair : ep_names_to_libs) { + const std::filesystem::path library_path = pair.second; + const std::string registration_name = pair.first; + Ort::Status status(Ort::GetApi().RegisterExecutionProviderLibrary(env, registration_name.c_str(), ToPathString(library_path.string()).c_str())); + if (status.IsOK()) { + test_config.registered_plugin_eps.push_back(registration_name); + } else { + fprintf(stderr, "Can't register %s plugin library: %s\n", registration_name.c_str(), status.GetErrorMessage().c_str()); + } + } + } + } +} + +void UnregisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& test_config) { + for (auto& registration_name : test_config.registered_plugin_eps) { + Ort::Status status(Ort::GetApi().UnregisterExecutionProviderLibrary(env, registration_name.c_str())); + if (!status.IsOK()) { + fprintf(stderr, "%s", status.GetErrorMessage().c_str()); + } + } +} + +std::vector ConvertArgvToUtf8Strings(int argc, ORTCHAR_T* argv[]) { + std::vector utf8_args; + utf8_args.reserve(argc); + for (int i = 0; i < argc; ++i) { + std::string utf8_string = ToUTF8String(argv[i]); + + // Abseil flags doens't natively alias "-h" to "--help". + // We make "-h" alias to "--help" here. + if (utf8_string == "-h" || utf8_string == "--h") { + utf8_args.push_back("--help"); + } else { + utf8_args.push_back(utf8_string); + } + } + return utf8_args; +} + +std::vector CStringsFromStrings(std::vector& utf8_args) { + std::vector utf8_argv; + utf8_argv.reserve(utf8_args.size()); + for (auto& str : utf8_args) { + utf8_argv.push_back(&str[0]); + } + return utf8_argv; +} + +} // namespace utils +} // namespace perftest +} // namespace onnxruntime diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 43bf54963cabb..513122609bb01 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -6,6 +6,8 @@ #include #include "command_args_parser.h" #include "performance_runner.h" +#include "utils.h" +#include "strings_helper.h" #include using namespace onnxruntime; @@ -19,7 +21,7 @@ int real_main(int argc, char* argv[]) { g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); perftest::PerformanceTestConfig test_config; if (!perftest::CommandLineParser::ParseArguments(test_config, argc, argv)) { - perftest::CommandLineParser::ShowUsage(); + fprintf(stderr, "%s", "See 'onnxruntime_perf_test --help'."); return -1; } Ort::Env env{nullptr}; @@ -33,7 +35,7 @@ int real_main(int argc, char* argv[]) { } ORT_CATCH(const Ort::Exception& e) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment: %s \n", e.what()); + std::cerr << "Error creating environment: " << e.what() << std::endl; failed = true; }); } @@ -41,6 +43,30 @@ int real_main(int argc, char* argv[]) { if (failed) return -1; } + + if (!test_config.plugin_ep_names_and_libs.empty()) { + perftest::utils::RegisterExecutionProviderLibrary(env, test_config); + } + + // Unregister all registered plugin EP libraries before program exits. + // This is necessary because unregistering the plugin EP also unregisters any associated shared allocators. + // If we don't do this and program returns, the factories stored inside the environment will be destroyed when the environment goes out of scope. + // Later, when the shared allocator's deleter runs, it may cause a segmentation fault because it attempts to use the already-destroyed factory to call ReleaseAllocator. + // See "ep_device.ep_factory->ReleaseAllocator" in Environment::CreateSharedAllocatorImpl. + auto unregister_plugin_eps_at_scope_exit = gsl::finally([&]() { + if (!test_config.registered_plugin_eps.empty()) { + perftest::utils::UnregisterExecutionProviderLibrary(env, test_config); // this won't throw + } + }); + + if (test_config.list_available_ep_devices) { + perftest::utils::ListEpDevices(env); + if (test_config.registered_plugin_eps.empty()) { + fprintf(stdout, "No plugin execution provider libraries are registered. Please specify them using \"--plugin_ep_libs\"; otherwise, only CPU may be available.\n"); + } + return 0; + } + std::random_device rd; perftest::PerformanceRunner perf_runner(env, test_config, rd); @@ -72,7 +98,7 @@ int main(int argc, char* argv[]) { } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); + std::cerr << ex.what() << std::endl; retval = -1; }); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 7a210ca8482a4..1ba3078efdb1a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -62,6 +62,84 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device : rand_engine_(rd()), input_names_(m.GetInputCount()), input_names_str_(m.GetInputCount()), input_length_(m.GetInputCount()) { Ort::SessionOptions session_options; + // Add EP devices if any (created by plugin EP) + if (!performance_test_config.registered_plugin_eps.empty()) { + std::vector ep_devices = env.GetEpDevices(); + // EP -> associated EP devices (All OrtEpDevice instances must be from the same execution provider) + std::unordered_map> added_ep_devices; + std::unordered_set added_ep_device_index_set; + + auto& ep_list = performance_test_config.machine_config.plugin_provider_type_list; + std::unordered_set ep_set(ep_list.begin(), ep_list.end()); + + // Select EP devices by provided device index + if (!performance_test_config.selected_ep_device_indices.empty()) { + std::vector device_list; + device_list.reserve(performance_test_config.selected_ep_device_indices.size()); + ParseEpDeviceIndexList(performance_test_config.selected_ep_device_indices, device_list); + for (auto index : device_list) { + if (static_cast(index) > (ep_devices.size() - 1)) { + fprintf(stderr, "%s", "The device index provided is not correct. Will skip this device id."); + continue; + } + + Ort::ConstEpDevice& device = ep_devices[index]; + if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { + if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) { + added_ep_devices[device.EpName()].push_back(device); + added_ep_device_index_set.insert(index); + fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s] has been added to session.\n", index, device.EpName()); + } + } else { + std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " + + performance_test_config.machine_config.provider_type_name + ". Will skip adding this device.\n"; + fprintf(stderr, "%s", err_msg.c_str()); + } + } + } else { + // Find and select the OrtEpDevice associated with the EP in "--plugin_eps". + for (size_t index = 0; index < ep_devices.size(); ++index) { + Ort::ConstEpDevice& device = ep_devices[index]; + if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { + added_ep_devices[device.EpName()].push_back(device); + fprintf(stdout, "EP Device [Index: %d, Name: %s] has been added to session.\n", static_cast(index), device.EpName()); + } + } + } + + if (added_ep_devices.empty()) { + ORT_THROW("[ERROR] [Plugin EP]: No matching EP devices found."); + } + + std::string ep_option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); + + // EP's associated provider option lists + std::vector> ep_options_list; + ParseEpOptions(ep_option_string, ep_options_list); + + // If user only provide the EPs' provider option lists for the first several EPs, + // add empty provider option lists for the rest EPs. + if (ep_options_list.size() < ep_list.size()) { + for (size_t i = ep_options_list.size(); i < ep_list.size(); ++i) { + ep_options_list.emplace_back(); // Adds a new empty map + } + } else if (ep_options_list.size() > ep_list.size()) { + ORT_THROW("[ERROR] [Plugin EP]: Too many EP provider option lists provided."); + } + + // EP -> associated provider options + std::unordered_map> ep_options_map; + for (size_t i = 0; i < ep_list.size(); ++i) { + ep_options_map.emplace(ep_list[i], ep_options_list[i]); + } + + for (auto& ep_and_devices : added_ep_devices) { + auto& ep = ep_and_devices.first; + auto& devices = ep_and_devices.second; + session_options.AppendExecutionProvider_V2(env, devices, ep_options_map[ep]); + } + } + provider_name_ = performance_test_config.machine_config.provider_type_name; std::unordered_map provider_options; if (provider_name_ == onnxruntime::kDnnlExecutionProvider) { @@ -101,53 +179,29 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options; - Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - std::vector option_keys, option_values; - // used to keep all option keys and value strings alive - std::list buffer; - buffer.emplace_back("cudnn_conv_algo_search"); - option_keys.push_back(buffer.back().c_str()); + Ort::CUDAProviderOptions cuda_options; + + const char* config_val = nullptr; switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: - buffer.emplace_back("EXHAUSTIVE"); + config_val = "EXHAUSTIVE"; break; case 1: - buffer.emplace_back("HEURISTIC"); + config_val = "HEURISTIC"; break; default: - buffer.emplace_back("DEFAULT"); + config_val = "DEFAULT"; break; } - option_values.push_back(buffer.back().c_str()); - - buffer.emplace_back("do_copy_in_default_stream"); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); - option_values.push_back(buffer.back().c_str()); + provider_options.emplace("cudnn_conv_algo_search", config_val); + provider_options.emplace("do_copy_in_default_stream", + (!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0")); -#ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); -#else - std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; -#endif + ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - option_values.push_back(provider_option.second.c_str()); - } + cuda_options.Update(provider_options); - Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetCUDAProviderOptionsAsString(cuda_options, allocator, &options)); - ORT_THROW("[ERROR] [CUDA] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); if (performance_test_config.run_config.enable_cuda_io_binding) { device_memory_name_ = CUDA; @@ -157,12 +211,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - std::vector option_keys, option_values; + Ort::TensorRTProviderOptions tensorrt_options; // used to keep all option keys and value strings alive std::list buffer; @@ -172,25 +221,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - option_values.push_back(provider_option.second.c_str()); - } - Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetTensorRTProviderOptionsAsString(tensorrt_options, allocator, &options)); - ORT_THROW("[ERROR] [TensorRT] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } - + tensorrt_options.Update(provider_options); session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = tensorrt_options->device_id; + cuda_options.device_id = static_cast(tensorrt_options)->device_id; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test @@ -828,12 +863,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ov_options[key] = value; } else if (key == "reshape_input") { ov_options[key] = value; + } else if (key == "layout") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer'," - " 'enable_causallm', 'model_priority'] \n"); + " 'enable_causallm', 'reshape_input', 'layout', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc index 9fd49da1d0486..5743346f8edf1 100644 --- a/onnxruntime/test/perftest/strings_helper.cc +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -8,6 +8,8 @@ #include "strings_helper.h" #include "core/common/common.h" +#include "core/common/parse_string.h" +#include "core/common/string_utils.h" namespace onnxruntime { namespace perftest { @@ -53,5 +55,87 @@ void ParseSessionConfigs(const std::string& configs_string, session_configs.insert(std::make_pair(std::move(key), std::move(value))); } } + +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map) { + std::stringstream ss(input); + std::string free_dim_str; + + while (std::getline(ss, free_dim_str, ' ')) { + if (!free_dim_str.empty()) { + size_t delimiter_location = free_dim_str.find(":"); + if (delimiter_location >= free_dim_str.size() - 1) { + return false; + } + std::string dim_identifier = free_dim_str.substr(0, delimiter_location); + std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos); + ORT_TRY { + int64_t override_val = std::stoll(override_val_str.c_str()); + if (override_val <= 0) { + return false; + } + free_dim_override_map[dim_identifier] = override_val; + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Error parsing free dimension override value: " << override_val_str.c_str() << ", " << ex.what() << std::endl; + }); + return false; + } + } + } + + return true; +} + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map) { + for (int i = 1; i < argc; ++i) { + auto utf8_arg = argv[i]; + if (utf8_arg == ("-" + option) || utf8_arg == ("--" + option)) { + auto value_idx = i + 1; + if (value_idx >= argc || argv[value_idx][0] == '-') { + std::cerr << utf8_arg << " should be followed by a key-value pair." << std::endl; + return false; + } + + if (!ParseDimensionOverride(argv[value_idx], free_dim_override_map)) return false; + } + } + return true; +} + +void ParseEpOptions(const std::string& input, std::vector>& result) { + auto tokens = utils::SplitString(input, ";", true); + + for (const auto& token : tokens) { + result.emplace_back(); // Adds a new empty map + if (!token.empty()) { + ParseSessionConfigs(std::string(token), result.back()); // only parse non-empty + } + // if token is empty, we still get an empty map in `result` + } +} + +void ParseEpList(const std::string& input, std::vector& result) { + std::stringstream ss(input); + std::string token; + + while (std::getline(ss, token, ';')) { + if (!token.empty()) { + result.push_back(token); + } + } +} + +void ParseEpDeviceIndexList(const std::string& input, std::vector& result) { + std::stringstream ss(input); + std::string item; + + while (std::getline(ss, item, ';')) { + if (!item.empty()) { + int value = ParseStringWithClassicLocale(item); + result.push_back(value); + } + } +} } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h index 0d6c56709fde6..a33b3d5089c9b 100644 --- a/onnxruntime/test/perftest/strings_helper.h +++ b/onnxruntime/test/perftest/strings_helper.h @@ -3,8 +3,10 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include +#include #include #include +#include namespace onnxruntime { namespace perftest { @@ -12,5 +14,15 @@ namespace perftest { void ParseSessionConfigs(const std::string& configs_string, std::unordered_map& session_configs, const std::unordered_set& available_keys = {}); + +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map); + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map); + +void ParseEpList(const std::string& input, std::vector& result); + +void ParseEpOptions(const std::string& input, std::vector>& result); + +void ParseEpDeviceIndexList(const std::string& input, std::vector& result); } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 8145f5f35c3b3..29ee84dd40dac 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -35,6 +35,7 @@ struct ModelInfo { struct MachineConfig { Platform platform{Platform::kWindows}; std::string provider_type_name{onnxruntime::kCpuExecutionProvider}; + std::vector plugin_provider_type_list; }; struct RunConfig { @@ -59,8 +60,8 @@ struct RunConfig { bool set_denormal_as_zero{false}; std::basic_string ep_runtime_config_string; std::unordered_map session_config_entries; - std::map, int64_t> free_dim_name_overrides; - std::map, int64_t> free_dim_denotation_overrides; + std::map free_dim_name_overrides; + std::map free_dim_denotation_overrides; std::string intra_op_thread_affinities; bool disable_spinning = false; bool disable_spinning_between_run = false; @@ -74,6 +75,10 @@ struct PerformanceTestConfig { ModelInfo model_info; MachineConfig machine_config; RunConfig run_config; + std::basic_string plugin_ep_names_and_libs; + std::vector registered_plugin_eps; + std::string selected_ep_device_indices; + bool list_available_ep_devices = false; }; } // namespace perftest diff --git a/onnxruntime/test/perftest/utils.h b/onnxruntime/test/perftest/utils.h index f22abc04fa99e..9f180e2c8d942 100644 --- a/onnxruntime/test/perftest/utils.h +++ b/onnxruntime/test/perftest/utils.h @@ -2,7 +2,8 @@ // Licensed under the MIT License. #pragma once - +#include "test/perftest/test_configuration.h" +#include #include namespace onnxruntime { @@ -22,6 +23,16 @@ class ICPUUsage { std::unique_ptr CreateICPUUsage(); +std::vector ConvertArgvToUtf8Strings(int argc, ORTCHAR_T* argv[]); + +std::vector CStringsFromStrings(std::vector& utf8_args); + +void RegisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& test_config); + +void UnregisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& test_config); + +void ListEpDevices(const Ort::Env& env); + } // namespace utils } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template index 9abec2242502f..b6b8b8aa02a51 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template +++ b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template @@ -15,7 +15,7 @@ if ENV['SKIP_MACOS_TEST'] != 'true' # Comment the next line if you don't want to use dynamic frameworks use_frameworks! - platform :osx, '13.3' + platform :osx, '13.4' target 'macos_package_testUITests' do inherit! :search_paths diff --git a/onnxruntime/test/platform/device_discovery_test.cc b/onnxruntime/test/platform/device_discovery_test.cc new file mode 100644 index 0000000000000..21ddf9a5b1cd7 --- /dev/null +++ b/onnxruntime/test/platform/device_discovery_test.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/device_discovery.h" + +#include "gtest/gtest.h" + +namespace onnxruntime::test { + +namespace { + +std::vector GetDevicesByType(OrtHardwareDeviceType device_type) { + std::vector result{}; + const auto& devices = DeviceDiscovery::GetDevices(); + std::copy_if(devices.begin(), devices.end(), std::back_inserter(result), + [device_type](const OrtHardwareDevice& device) { + return device.type == device_type; + }); + return result; +} + +} // namespace + +TEST(DeviceDiscoveryTest, HasCpuDevice) { + const auto cpu_devices = GetDevicesByType(OrtHardwareDeviceType_CPU); + ASSERT_GT(cpu_devices.size(), 0); + +#if !defined(__wasm__) + ASSERT_NE(cpu_devices[0].vendor_id, 0); +#endif // !defined(__WASM__) +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/platform/file_io_test.cc b/onnxruntime/test/platform/file_io_test.cc index ccc703716844f..e6f3c4dd8b89e 100644 --- a/onnxruntime/test/platform/file_io_test.cc +++ b/onnxruntime/test/platform/file_io_test.cc @@ -17,7 +17,7 @@ #include #include "gtest/gtest.h" - +#include "asserts.h" #include "core/common/span_utils.h" #include "test/util/include/file_util.h" @@ -157,7 +157,6 @@ TEST(FileIoTest, MapFileIntoMemory) { SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); static const auto page_size = sysinfo.dwPageSize; - static const auto allocation_granularity = sysinfo.dwAllocationGranularity; ASSERT_GT(page_size, static_cast(0)); TempFilePath tmp(ORT_TSTR("map_file_test_")); @@ -167,21 +166,10 @@ TEST(FileIoTest, MapFileIntoMemory) { const auto offsets_and_lengths = GenerateValidOffsetLengthPairs( 0, expected_data.size(), page_size / 10); - for (const auto& offset_and_length : offsets_and_lengths) { - const auto offset = offset_and_length.first; - const auto length = offset_and_length.second; - - // The offset must be a multiple of the allocation granularity - if (offset % allocation_granularity != 0) { - continue; - } - + for (const auto& [offset, length] : offsets_and_lengths) { Env::MappedMemoryPtr mapped_memory{}; - auto status = Env::Default().MapFileIntoMemory( - tmp.path.c_str(), offset, length, mapped_memory); - ASSERT_TRUE(status.IsOK()) - << "MapFileIntoMemory failed for offset " << offset << " and length " << length - << " with error: " << status.ErrorMessage(); + ASSERT_STATUS_OK(Env::Default().MapFileIntoMemory( + tmp.path.c_str(), offset, length, mapped_memory)); auto mapped_span = gsl::make_span(mapped_memory.get(), length); @@ -190,20 +178,11 @@ TEST(FileIoTest, MapFileIntoMemory) { ASSERT_TRUE(SpanEq(mapped_span, expected_data_span)); } - { - Env::MappedMemoryPtr mapped_memory{}; - - // invalid - offset is not a multiple of the allocation granularity - ASSERT_FALSE(Env::Default().MapFileIntoMemory( - tmp.path.c_str(), allocation_granularity * 3 / 2, page_size / 10, mapped_memory) - .IsOK()); - } - { Env::MappedMemoryPtr mapped_memory{}; // invalid - negative offset - ASSERT_FALSE(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory).IsOK()); + ASSERT_STATUS_NOT_OK(Env::Default().MapFileIntoMemory(tmp.path.c_str(), -1, 0, mapped_memory)); } } #endif diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index a5fd37361a255..dc50a75873034 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -688,7 +688,7 @@ TEST(Loop, SubgraphTypeOverride) { Graph::ResolveOptions options; options.override_types = true; test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider}, &session_run_options, nullptr, + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options, nullptr, ExecutionMode::ORT_SEQUENTIAL, options); } @@ -1162,7 +1162,7 @@ TEST(Loop, SequenceAsLoopCarriedDependency) { test.AddSeqOutput("loop_var_0_final", seq_output); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } #if !defined(DISABLE_OPTIONAL_TYPE) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index d5f6f1ddf700e..cf49601e6c671 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/session/onnxruntime_c_api.h" @@ -22,6 +23,7 @@ #include #include "default_providers.h" #include "test/onnx/TestCase.h" +#include "test/util/include/api_asserts.h" #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" @@ -59,21 +61,161 @@ extern std::unique_ptr ort_env; -// asserts that the OrtStatus* result of `status_expr` does not indicate an error -// note: this takes ownership of the OrtStatus* result -#define ASSERT_ORT_STATUS_OK(status_expr) \ - do { \ - if (OrtStatus* _status = (status_expr); _status != nullptr) { \ - std::unique_ptr _rel_status{ \ - _status, &OrtApis::ReleaseStatus}; \ - FAIL() << "OrtStatus error: " << OrtApis::GetErrorMessage(_rel_status.get()); \ - } \ - } while (false) - using namespace onnxruntime::common; namespace onnxruntime { namespace test { + +// Models verified to exist in both VM and Zoo with identical checksums +// These 20 unique models have been confirmed as public (33 instances across opsets) +static const std::unordered_set VERIFIED_PUBLIC_MODELS = { + "AlexNet", + "BERT-Squad", + "CaffeNet", + "DenseNet-121", + "Emotion FERPlus", + "Faster R-CNN R-50-FPN", + "GoogleNet", + "Inception-1", + "Inception-2", + "Mask R-CNN R-50-FPN", + "MNIST", + "MobileNet v2-7", + "R-CNN ILSVRC13", + "ShuffleNet-v1", + "SqueezeNet 1.0", + "SqueezeNet 1.1", + "SSD", + "VGG 19-caffe2", + "YOLOv3", + "ZFNet-512"}; + +// All ONNX Model Zoo models (always safe as they're public) +// Total: 158 models from https://github.com/onnx/models +static const std::unordered_set ONNX_ZOO_MODELS = { + // Verified models (20 unique) + "AlexNet", + "BERT-Squad", + "CaffeNet", + "DenseNet-121", + "Emotion FERPlus", + "Faster R-CNN R-50-FPN", + "GoogleNet", + "Inception-1", + "Inception-2", + "Mask R-CNN R-50-FPN", + "MNIST", + "MobileNet v2-7", + "R-CNN ILSVRC13", + "ShuffleNet-v1", + "SqueezeNet 1.0", + "SqueezeNet 1.1", + "SSD", + "VGG 19-caffe2", + "YOLOv3", + "ZFNet-512", + // Additional Zoo-only models (138) + "AlexNet-int8", + "BERT-Squad-int8", + "BiDAF", + "BiDAF-int8", + "CaffeNet-int8", + "CaffeNet-qdq", + "Candy", + "DenseNet-121-12", + "DenseNet-121-12-int8", + "EfficientNet-Lite4", + "EfficientNet-Lite4-int8", + "EfficientNet-Lite4-qdq", + "Emotion FERPlus int8", + "FCN ResNet-50", + "FCN ResNet-50-int8", + "FCN ResNet-50-qdq", + "FCN ResNet-101", + "Faster R-CNN R-50-FPN-fp32", + "Faster R-CNN R-50-FPN-int8", + "Faster R-CNN R-50-FPN-qdq", + "GoogleNet-int8", + "GoogleNet-qdq", + "GPT-2", + "GPT-2-LM-HEAD", + "Inception-1-int8", + "Inception-1-qdq", + "LResNet100E-IR", + "LResNet100E-IR-int8", + "Mask R-CNN R-50-FPN-fp32", + "Mask R-CNN R-50-FPN-int8", + "Mask R-CNN R-50-FPN-qdq", + "MNIST-12", + "MNIST-12-int8", + "MobileNet v2-1.0", + "MobileNet v2-1.0-fp32", + "MobileNet v2-1.0-int8", + "MobileNet v2-1.0-qdq", + "Mosaic", + "Pointilism", + "Rain Princess", + "ResNet18", + "ResNet18-v2", + "ResNet34", + "ResNet34-v2", + "ResNet50", + "ResNet50-caffe2", + "ResNet50-fp32", + "ResNet50-int8", + "ResNet50-qdq", + "ResNet50-v2", + "ResNet101", + "ResNet101-v2", + "ResNet101_DUC_HDC", + "ResNet101_DUC_HDC-12", + "ResNet101_DUC_HDC-12-int8", + "ResNet152", + "ResNet152-v2", + "ResNet-preproc", + "RetinaNet (ResNet101 backbone)", + "RoBERTa-BASE", + "RoBERTa-SequenceClassification", + "ShuffleNet-v2", + "ShuffleNet-v2-fp32", + "ShuffleNet-v2-int8", + "ShuffleNet-v2-qdq", + "SqueezeNet 1.0-int8", + "SqueezeNet 1.0-qdq", + "SSD-int8", + "SSD-qdq", + "SSD-MobilenetV1", + "SSD-MobilenetV1-12", + "SSD-MobilenetV1-12-int8", + "SSD-MobilenetV1-12-qdq", + "Super_Resolution", + "T5-decoder-with-lm-head", + "T5-encoder", + "Tiny YOLOv2", + "Tiny YOLOv3", + "Udnie", + "VGG 16", + "VGG 16-bn", + "VGG 16-fp32", + "VGG 16-int8", + "VGG 16-qdq", + "VGG 19", + "VGG 19-bn", + "version-RFB-320", + "version-RFB-320-int8", + "version-RFB-640", + "YOLOv2", + "YOLOv3-12", + "YOLOv3-12-int8", + "YOLOv4", + "ZFNet-512-int8", + "ZFNet-512-qdq"}; + +// Helper function to check if a model is allowed +inline bool IsModelAllowed(const std::string& model_name) { + return ONNX_ZOO_MODELS.count(model_name) > 0; +} + // parameter is provider_name + "_" + model_path class ModelTest : public testing::TestWithParam> {}; @@ -179,17 +321,14 @@ TEST_P(ModelTest, Run) { ortso.SetLogId(ToUTF8String(test_case_name).c_str()); ortso.SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); if (provider_name == "cuda") { - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + + std::unordered_map options; + options["device_id"] = (device_id.empty() ? "0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "rocm") { @@ -199,11 +338,11 @@ TEST_P(ModelTest, Run) { #ifdef USE_DNNL else if (provider_name == "dnnl") { OrtDnnlProviderOptions* ep_option; - ASSERT_ORT_STATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); std::unique_ptr rel_dnnl_options(ep_option, &OrtApis::ReleaseDnnlProviderOptions); ep_option->use_arena = 0; - ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); } #endif else if (provider_name == "tensorrt") { @@ -211,24 +350,17 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { - OrtTensorRTProviderOptionsV2* ep_option = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option)); - std::unique_ptr - rel_cuda_options(ep_option, &OrtApis::ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions ep_option; ortso.AppendExecutionProvider_TensorRT_V2(*ep_option); } // Enable CUDA fallback - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + std::unordered_map options; + options["device_id"] = (device_id.empty() ? "0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "migraphx") { @@ -240,27 +372,27 @@ TEST_P(ModelTest, Run) { } #ifdef USE_NNAPI else if (provider_name == "nnapi") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); } #endif #ifdef USE_VSINPU else if (provider_name == "vsinpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); } #endif #ifdef USE_RKNPU else if (provider_name == "rknpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); } #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN else if (provider_name == "armnn") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); } #endif #ifdef USE_XNNPACK @@ -300,11 +432,11 @@ TEST_P(ModelTest, Run) { std::unordered_map feeds; l->LoadTestData(task_id, holder, feeds, true); size_t output_count; - ASSERT_ORT_STATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); + ASSERT_ORTSTATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); // Create output feed std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { - ASSERT_ORT_STATUS_OK( + ASSERT_ORTSTATUS_OK( OrtApis::SessionGetOutputName(ort_session, i, default_allocator.get(), &output_names[i])); } @@ -676,15 +808,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { // Same as the above, except this one is for large models #if defined(NDEBUG) || defined(RUN_MODELTEST_IN_DEBUG_MODE) #ifdef _WIN32 - ORT_STRING_VIEW model_test_root_path = ORT_TSTR("..\\models"); - // thus, only the root path should be mounted. ORT_STRING_VIEW model_zoo_path = ORT_TSTR("..\\models\\zoo"); #else - ORT_STRING_VIEW model_test_root_path = ORT_TSTR("../models"); ORT_STRING_VIEW model_zoo_path = ORT_TSTR("../models/zoo"); #endif for (auto p : kvp.second) { - paths.push_back(ConcatPathComponent(model_test_root_path, p)); + // ONLY use Model Zoo path - guaranteed public models with public test data paths.push_back(ConcatPathComponent(model_zoo_path, p)); } #endif @@ -770,6 +899,13 @@ ::std::vector<::std::basic_string> GetParameterStrings() { std::basic_string test_case_name = path.parent_path().filename().native(); if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5); + + // Check if model is in the public whitelist + std::string model_name_str = ToUTF8String(test_case_name); + if (!IsModelAllowed(model_name_str)) { + continue; // Skip models not in whitelist + } + if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end()) continue; diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc index f4d8cad90a714..1a71da6d95135 100644 --- a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc @@ -11,7 +11,8 @@ namespace test { // range = [-ve, +ve] TEST(QuantizeLinearOpTest, DynamicQuantizeLinear) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr || + DefaultOpenVINOExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: 26 and 25"; } diff --git a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc index 91a4fe9a54251..af49bd0e3d58d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc @@ -24,7 +24,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { size_t size = 1024; - EXPECT_STREQ(cuda_arena->Info().name, CUDA); + EXPECT_STREQ(cuda_arena->Info().name.c_str(), CUDA); EXPECT_EQ(cuda_arena->Info().device.Id(), cuda_device_id); EXPECT_EQ(cuda_arena->Info().mem_type, OrtMemTypeDefault); EXPECT_EQ(cuda_arena->Info().alloc_type, OrtArenaAllocator); @@ -38,7 +38,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { auto pinned_allocator = CreateAllocator(pinned_memory_info); - EXPECT_STREQ(pinned_allocator->Info().name, CUDA_PINNED); + EXPECT_STREQ(pinned_allocator->Info().name.c_str(), CUDA_PINNED); EXPECT_EQ(pinned_allocator->Info().device.Id(), 0); EXPECT_EQ(pinned_allocator->Info().mem_type, OrtMemTypeCPUOutput); EXPECT_EQ(pinned_allocator->Info().alloc_type, OrtArenaAllocator); @@ -50,7 +50,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { AllocatorCreationInfo cpu_memory_info( [](int) { return std::make_unique(); }, true); const auto& cpu_arena = CreateAllocator(cpu_memory_info); - EXPECT_STREQ(cpu_arena->Info().name, CPU); + EXPECT_STREQ(cpu_arena->Info().name.c_str(), CPU); EXPECT_EQ(cpu_arena->Info().device.Id(), 0); EXPECT_EQ(cpu_arena->Info().mem_type, OrtMemTypeDefault); EXPECT_EQ(cpu_arena->Info().alloc_type, OrtArenaAllocator); diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index 5de7885a9452a..761ddf1975d15 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) { ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); } +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + +#if defined(WIN32) +// Tests autoEP feature to automatically select an EP that supports the GPU. +// Currently only works on Windows. TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("migraphx_basic_test.onnx"); @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 0559699670c4a..2327bc2094d1a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -1,25 +1,16 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" #include "test/framework/test_utils.h" -#include "gtest/gtest.h" + #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" -#include "test/util/include/api_asserts.h" -#include "test/util/include/asserts.h" -#include -#include -#include -#include -#include #include -#include #include using namespace std; @@ -30,200 +21,6 @@ namespace onnxruntime { namespace test { -template -class NvExecutionProviderTest : public ::testing::Test { - protected: - std::string getTypeAsName() { - std::string dtype_name = ""; - if constexpr (std::is_same::value) { - dtype_name = "fp64"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp32"; - } else if constexpr (std::is_same::value) { - dtype_name = "bf16"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp16"; - } else if constexpr (std::is_same::value) { - dtype_name = "int8"; - } else if constexpr (std::is_same::value) { - dtype_name = "uint8"; - } else if constexpr (std::is_same::value) { - dtype_name = "int32"; - } else if constexpr (std::is_same::value) { - dtype_name = "int64"; - } - return dtype_name; - } -}; - -using NvExecutionProviderTestTypes = ::testing::Types; // double, -TYPED_TEST_SUITE(NvExecutionProviderTest, NvExecutionProviderTestTypes); - -std::string PathToUTF8(const PathString& path) { -#ifdef WIN32 - std::wstring_convert> converter; - return converter.to_bytes(path); -#else - return path.c_str(); -#endif -} - -void clearFileIfExists(PathString path) { - if (std::filesystem::exists(path)) { - std::filesystem::remove(path); - } -} - -template -void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values) { - ASSERT_EQ(1, fetches.size()); - auto& rtensor = fetches.front().Get(); - TensorShape expected_shape(expected_dims); - ASSERT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); - ASSERT_EQ(expected_values, found); -} - -/** - * Create a simple model with dynamic or non-dynamic input shape. - * \param model_name - model name - * \param graph_name - graph name - * \param dims - input dimensions - * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. - * - * input: "X", "Y" and "Z" - * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. - * All three inputs have the same dimensions. - * output: "M" - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * Add (+ float scalar "S") - * / - * "O" - * - * or - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * FastGelu (This node will be placed on CUDA EP) - * / - * * Add (+ float scalar "S") - * / - * "O" - */ -static void CreateBaseModel(const PathString& model_name, - std::string graph_name, - std::vector dims, - bool add_fast_gelu = false, - ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); - auto& graph = model.MainGraph(); - std::vector inputs; - std::vector outputs; - - // FLOAT tensor - ONNX_NAMESPACE::TypeProto float_tensor; - float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - for (auto dim : dims) { - float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - ONNX_NAMESPACE::TypeProto dyn_float_tensor; - dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); - auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); - inputs.push_back(&input_arg_1); - inputs.push_back(&input_arg_2); - auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); - outputs.push_back(&output_arg); - graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); - - auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); - inputs.clear(); - inputs.push_back(&output_arg); - inputs.push_back(&input_arg_3); - - auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_2); - graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); - - inputs.clear(); - inputs.push_back(&output_arg_2); - - if (add_fast_gelu) { - auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_3); - - graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, - /* attributes */ nullptr, kMSDomain); - - inputs.clear(); - inputs.push_back(&output_arg_3); - } - - ONNX_NAMESPACE::TypeProto float_scalar; - float_scalar.mutable_tensor_type()->set_elem_type(dtype); - float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); - inputs.push_back(&input_scalar); - - auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); - - outputs.clear(); - outputs.push_back(&output_arg_4); - graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); - - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()); - status = onnxruntime::Model::Save(model, model_name); - ASSERT_TRUE(status.IsOK()); -} - -static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { - Ort::IoBinding binding(session); - auto allocator = Ort::AllocatorWithDefaultOptions(); - for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { - auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); - auto full_tensor_info = session.GetInputTypeInfo(input_idx); - auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); - auto shape = tensor_info.GetShape(); - auto type = tensor_info.GetElementType(); - if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { - for (auto& v : shape) { - if (v == -1) { - v = 1; - } - } - } else { - shape = shape_overwrites[input_name.get()]; - } - auto input_value = Ort::Value::CreateTensor(allocator, - shape.data(), - shape.size(), - type); - binding.BindInput(input_name.get(), input_value); - } - - for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { - auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); - binding.BindOutput(output_name.get(), allocator.GetInfo()); - } - return binding; -} - TEST(NvExecutionProviderTest, ContextEmbedAndReload) { PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx"); PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx"); @@ -233,11 +30,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { std::vector dims = {1, 3, 2}; CreateBaseModel(model_name, graph_name, dims); - - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -246,7 +38,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -261,7 +53,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -280,10 +72,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -292,7 +80,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -307,7 +95,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -329,10 +117,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -341,7 +125,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -356,7 +140,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -368,32 +152,71 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { } } -TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { - std::string dtype_name = this->getTypeAsName(); +std::string getTypeAsName(ONNX_NAMESPACE::TensorProto_DataType dtype) { + switch (dtype) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return "fp64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return "fp32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return "fp16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return "bf16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return "int64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return "int32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return "int8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return "uint8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return "int4"; + default: + return "Unkwon type"; + } +} + +class TypeTests : public ::testing::TestWithParam { + public: +}; + +TEST_P(TypeTests, IOTypes) { + const std::string dtype_name = getTypeAsName(GetParam()); ASSERT_FALSE(dtype_name.empty()); const std::string model_name_str = "nv_execution_provider_" + dtype_name + ".onnx"; const PathString model_name = ToPathString(model_name_str); - std::string graph_name = "test" + dtype_name; - std::vector dims = {1, -1, -1}; + const std::string graph_name = "test" + dtype_name; + const std::vector dims = {1, 5, 10}; - CreateBaseModel(model_name, graph_name, dims); - - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); + CreateBaseModel(model_name, graph_name, dims, false, GetParam()); // AOT time { Ort::SessionOptions so; Ort::RunOptions run_options; so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto io_binding = generate_io_binding(session_object); session_object.Run(run_options, io_binding); } } +INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, + ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + // disabled low precision integer types since a specific quantize/dequantize model is required + // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 + ), + [](const testing::TestParamInfo& info) { return getTypeAsName(info.param); }); + +#ifdef _WIN32 static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. const OrtSession* ort_session = session; @@ -409,31 +232,26 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { return has_ep; } -#if defined(WIN32) // Tests autoEP feature to automatically select an EP that supports the GPU. // Currently only works on Windows. TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { - PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx"); + PathString model_name = ORT_TSTR("nv_execution_provider_auto_ep.onnx"); std::string graph_name = "test"; std::vector dims = {1, 3, 2}; CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - { - env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); + ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider)); } - env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); + ort_env->UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); } TEST(NvExecutionProviderTest, GetSharedAllocator) { @@ -580,7 +398,7 @@ TEST(NvExecutionProviderTest, DataTransfer) { device_tensor = Ort::Value(); } -#endif // defined(WIN32) +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc new file mode 100644 index 0000000000000..ce49ae81c81c0 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/common/path_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include + +extern std::unique_ptr ort_env; + +namespace onnxruntime { + +namespace test { + +RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, std::unordered_map& option_map) { + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; +#ifdef _WIN32 + /// Since this test runs after other tests that use registration interface this test has to use it as well + /// windows as otherwise the kernel registry inside the EP will not be populated. The legacy APis ony call the initialize once. + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + auto ep_devices = ort_env->GetEpDevices(); + Ort::ConstEpDevice selected_device; + for (auto& device : ep_devices) { + if (!std::strcmp(device.EpName(), kNvTensorRTRTXExecutionProvider)) { + selected_device = device; + } + } + session_options.AppendExecutionProvider_V2(*ort_env, {selected_device}, option_map); +#else + session_options.AppendExecutionProvider(onnxruntime::kNvTensorRTRTXExecutionProvider, option_map); +#endif + return nv_tensorrt_rtx_ep; +} + +std::vector readBinaryFile(const PathString& filename) { + std::ifstream file(filename, std::ios::binary); + if (!file.is_open()) { + throw std::runtime_error("Could not open file: " + PathToUTF8String(filename)); + } + + file.seekg(0, std::ios::end); + std::streamsize filesize = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(filesize); + if (!file.read(reinterpret_cast(buffer.data()), filesize)) { + throw std::runtime_error("Could not read file: " + PathToUTF8String(filename)); + } + + return buffer; +} + +struct CompileParam { + bool embed_mode; + bool bytestream_io; + bool external_initialzier_for_parser = false; + const std::string to_string() const { + return "embed_mode_" + std::to_string(embed_mode) + "_bytestream_io_" + std::to_string(bytestream_io) + "_ext_init_" + std::to_string(external_initialzier_for_parser); + ; + } +}; +class CompileApiTest + : public testing::TestWithParam { + public: + const CompileParam& GetCompileParam() const { + return GetParam(); + } +}; + +void SmallModelTest(CompileParam test_param, bool fully_supported_model) { + std::string test_name = test_param.to_string(); + if (!fully_supported_model) + test_name += "_fast_gelu"; + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + ".onnx"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + "_ctx.onnx"); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model_name, graph_name, dims, !fully_supported_model); + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, std::to_string(test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + } + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + Ort::Session session_object{nullptr}; + if (test_param.bytestream_io) { + session_object = Ort::Session(*ort_env, output_context, output_context_size, session_options); + } else { + session_object = Ort::Session(*ort_env, model_name_ctx.c_str(), session_options); + } + auto io_binding = generate_io_binding(session_object); + Ort::RunOptions run_options; + session_object.Run(run_options, io_binding); +} + +TEST_P(CompileApiTest, SmallModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, true); +} + +TEST_P(CompileApiTest, SmallSplitModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, false); +} + +TEST_P(CompileApiTest, LargeModel) { + const auto& test_param = GetCompileParam(); + // with embed mode == 1 the resulting file will be over the 2GB proto limit + if (test_param.embed_mode == 1) { + GTEST_SKIP(); + } + std::string test_name = test_param.to_string(); + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx"); + PathString external_data_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx_data"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx"); + PathString model_name_ctx_data = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx_data"); + clearFileIfExists(model_name_ctx); + clearFileIfExists(model_name_ctx_data); + // This accelerates test iterations if the large model was already generated + if (!std::filesystem::exists(model_name) || !std::filesystem::exists(external_data_name)) { + CreateLargeLLMModel(model_name, external_data_name); + } + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, + std::to_string(test_param.bytestream_io || test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx, input_data; + std::vector file_names; + std::vector file_buffers; + std::vector lengths; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + input_data = readBinaryFile(external_data_name); + file_names = {external_data_name}; + file_buffers = {input_data.data()}; + lengths = {input_data.size()}; + session_options.AddExternalInitializersFromFilesInMemory(file_names, file_buffers, lengths); + + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + model_compile_options.SetOutputModelExternalInitializersFile(model_name_ctx_data.c_str(), 1024); + } + + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + std::unique_ptr session; + if (test_param.bytestream_io) { + session = std::make_unique(*ort_env, output_context, output_context_size, session_options); + } else { + session = std::make_unique(*ort_env, model_name_ctx.c_str(), session_options); + } + + auto io_binding = generate_io_binding(*session); + Ort::RunOptions run_options; + session->Run(run_options, io_binding); +} + +INSTANTIATE_TEST_SUITE_P( + NvExecutionProviderTest, CompileApiTest, + ::testing::Values( + CompileParam{true, false}, + CompileParam{false, false}, + CompileParam{true, true}, + CompileParam{false, true}, + // test with external initializers for parser + CompileParam{true, true, true}, + CompileParam{true, false, true}), + [](const testing::TestParamInfo& info) { + return info.param.to_string(); + }); + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc new file mode 100644 index 0000000000000..d415548876153 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" + +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; +namespace onnxruntime { + +namespace test { +size_t countFilesInDirectory(const std::string& dir_path) { + return std::distance(std::filesystem::directory_iterator(dir_path), std::filesystem::directory_iterator{}); +} + +TEST(NvExecutionProviderTest, RuntimeCaching) { + PathString model_name = ORT_TSTR("nv_execution_provider_runtime_caching.onnx"); + PathString model_name_ctx = ORT_TSTR("nv_execution_provider_runtime_caching_ctx.onnx"); + auto model_name_ctx_str = PathToUTF8(model_name_ctx); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + std::string runtime_cache_name = "./runtime_cache/"; + if (std::filesystem::exists(runtime_cache_name)) { + std::filesystem::remove_all(runtime_cache_name); + } + CreateBaseModel(model_name, graph_name, dims); + // AOT time + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name.c_str(), so); + + auto io_binding = generate_io_binding(session_object); + session_object.Run(run_options, io_binding); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(runtime_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // use existing cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // create new cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + std::string new_cache_name = "/tmp/runtime_cache_new/"; + if (std::filesystem::exists(new_cache_name)) { + std::filesystem::remove_all(new_cache_name); + } + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", new_cache_name.c_str()}}); + { + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(new_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(new_cache_name)); + } +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index f0ce5c0b296ca..7f7894abdf3d5 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -3,18 +3,26 @@ // Licensed under the MIT License. // registration/selection is only supported on windows as there's no device discovery on other platforms -#ifdef _WIN32 #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" #include +#include #include #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/api_asserts.h" +#include "core/graph/basic_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/graph/model_saving_options.h" +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" namespace onnxruntime { namespace test { +#ifdef _WIN32 Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; @@ -51,8 +59,410 @@ void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniqu c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); }); } +#endif // _WIN32 + +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const PathString& external_initializer_file) { + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + std::vector inputs; + std::vector outputs; + + // FLOAT tensor + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + for (auto dim : dims) { + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + ONNX_NAMESPACE::TypeProto dyn_float_tensor; + dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + + auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + inputs.clear(); + inputs.push_back(&output_arg_2); + + if (add_fast_gelu) { + auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_3); + + graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, + /* attributes */ nullptr, kMSDomain); + + inputs.clear(); + inputs.push_back(&output_arg_3); + } + + ONNX_NAMESPACE::TypeProto float_scalar; + float_scalar.mutable_tensor_type()->set_elem_type(dtype); + float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); + inputs.push_back(&input_scalar); + + auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); + + outputs.clear(); + outputs.push_back(&output_arg_4); + graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); + if (!external_initializer_file.empty()) { + ModelSavingOptions save_options(128); + status = Model::SaveWithExternalInitializers(model, model_name, external_initializer_file, save_options); + } else { + status = Model::Save(model, model_name); + } + ASSERT_TRUE(status.IsOK()); +} + +// Helper to create large initializers +ONNX_NAMESPACE::TensorProto CreateLargeWeight( + const std::string& name, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::vector& shape, + float scale = 0.02f) { + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_name(name); + tensor.set_data_type(dtype); + for (auto d : shape) tensor.add_dims(d); + // Here we fill with random floats, but for real data, use your trained weights. + size_t total_size = 1; + for (int64_t d : shape) total_size *= d; + std::random_device rd; + std::default_random_engine rng(rd()); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = dist(rng); + tensor.set_raw_data(data.data(), total_size * sizeof(float)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = MLFloat16(dist(rng)); + tensor.set_raw_data(data.data(), total_size * sizeof(MLFloat16)); + } else { + throw std::runtime_error("Unsupported data type for large weight"); + } + return tensor; +} + +// Helper to add a GroupQueryAttention node +onnxruntime::NodeArg& AddGroupQueryAttention( + onnxruntime::Graph& graph, + onnxruntime::NodeArg& query, + onnxruntime::NodeArg& key, + onnxruntime::NodeArg& value, + int batch_size, + int head_dim, + int seq_len, + int num_heads, + int kv_num_heads, + float scale, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::string& node_name) { + // KV cache + ONNX_NAMESPACE::TypeProto key_type; + key_type.mutable_tensor_type()->set_elem_type(dtype); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_key = graph.GetOrCreateNodeArg(node_name + "_past_key", &key_type); + + ONNX_NAMESPACE::TypeProto value_type; + value_type.mutable_tensor_type()->set_elem_type(dtype); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_value = graph.GetOrCreateNodeArg(node_name + "_past_value", &value_type); + + // Output + auto& output = graph.GetOrCreateNodeArg(node_name + "_output", nullptr); + + // Create required initializers for GroupQueryAttention + ONNX_NAMESPACE::TensorProto seqlens_k_tensor; + seqlens_k_tensor.set_name(node_name + "_seqlens_k"); + seqlens_k_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + seqlens_k_tensor.add_dims(2); + seqlens_k_tensor.set_dims(0, batch_size); + seqlens_k_tensor.set_dims(0, 1); + seqlens_k_tensor.add_int32_data(seq_len - 1); // seqlens_k = total_sequence_length - 1 + graph.AddInitializedTensor(seqlens_k_tensor); + + ONNX_NAMESPACE::TensorProto total_seq_len_tensor; + total_seq_len_tensor.set_name(node_name + "_total_sequence_length"); + total_seq_len_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + total_seq_len_tensor.add_int32_data(seq_len); + graph.AddInitializedTensor(total_seq_len_tensor); + + // Get the initializers that were created for this node + auto* seqlens_k = graph.GetNodeArg(node_name + "_seqlens_k"); + auto* total_sequence_length = graph.GetNodeArg(node_name + "_total_sequence_length"); + + auto& present_value = graph.GetOrCreateNodeArg(node_name + "_present_value", nullptr); + auto& present_key = graph.GetOrCreateNodeArg(node_name + "_present_key", nullptr); + + // Inputs - GroupQueryAttention requires at least 7 inputs (query, key, value, past_key, past_value, seqlens_k, total_sequence_length) + std::vector inputs = { + &query, // 0: query + &key, // 1: key + &value, // 2: value + &past_key, // 3: past_key (optional) + &past_value, // 4: past_value (optional) + seqlens_k, // 5: seqlens_k (required) + total_sequence_length, // 6: total_sequence_length (required) + // nullptr, // 7: cos_cache (optional) + // nullptr, // 8: sin_cache (optional) + // nullptr, // 9: position_ids (optional) + // nullptr, // 10: attention_bias (optional) + // nullptr // 11: head_sink (optional) + }; + + // Attributes + NodeAttributes attrs; + ONNX_NAMESPACE::AttributeProto attr_heads; + attr_heads.set_name("num_heads"); + attr_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_heads.set_i(num_heads); + attrs["num_heads"] = attr_heads; + ONNX_NAMESPACE::AttributeProto attr_kv_num_heads; + attr_kv_num_heads.set_name("kv_num_heads"); + attr_kv_num_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_kv_num_heads.set_i(kv_num_heads); + attrs["kv_num_heads"] = attr_kv_num_heads; + ONNX_NAMESPACE::AttributeProto attr_scale; + attr_scale.set_name("scale"); + attr_scale.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_scale.set_f(scale); + attrs["scale"] = attr_scale; + + // Register node + graph.AddNode( + node_name, + "GroupQueryAttention", + "GroupQueryAttention Node", + inputs, + {&output, &present_key, &present_value}, + &attrs, + "com.microsoft"); + + return output; +} + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path) { + // Model parameters (example: 24 layers, 4096 hidden dim, 32 attention heads, 8 kv heads => GQA) + int batch_size = 1; + int num_layers = 32; + int hidden_dim = 2048; + int q_num_heads = 8; + int kv_num_heads = 1; // GQA: q_num_heads > kv_num_heads, and divisible. + int seq_length = 128; // Short, for demonstration. + int vocab_size = 32000; + auto dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + + // Set up model/graph + onnxruntime::Model model("LLM_With_GQA", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // Input + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(dtype); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_length); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(hidden_dim); + auto& input = graph.GetOrCreateNodeArg("input", &input_type); + + auto* current_arg = &input; + + // Repeated layers: [Attention + MLP] + for (int l = 0; l < num_layers; ++l) { + // KV cache - initialize with zeros for the first forward pass + int head_dim = hidden_dim / q_num_heads; + + // Split Q, K, V + auto& q_split = graph.GetOrCreateNodeArg("q_split_" + std::to_string(l), nullptr); + auto& k_split = graph.GetOrCreateNodeArg("k_split_" + std::to_string(l), nullptr); + auto& v_split = graph.GetOrCreateNodeArg("v_split_" + std::to_string(l), nullptr); + constexpr bool split = false; + if constexpr (split) { + // Attention weights (Q, K, V projections) + auto wqkv = CreateLargeWeight("wqkv_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim * 3}); + graph.AddInitializedTensor(wqkv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + auto& qkv_arg = graph.GetOrCreateNodeArg("qkv_" + std::to_string(l), nullptr); + graph.AddNode("QKV_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wqkv.name())}, {&qkv_arg}); + + NodeAttributes attrs_split; + ONNX_NAMESPACE::AttributeProto attr_split_axis; + attr_split_axis.set_name("axis"); + attr_split_axis.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_axis.set_i(-1); + attrs_split["axis"] = attr_split_axis; + ONNX_NAMESPACE::AttributeProto attr_split_num_outputs; + attr_split_num_outputs.set_name("num_outputs"); + attr_split_num_outputs.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_num_outputs.set_i(3); + attrs_split["num_outputs"] = attr_split_num_outputs; + graph.AddNode("Q_Split_" + std::to_string(l), "Split", "", {&qkv_arg}, {&q_split, &k_split, &v_split}, &attrs_split); + } else { + // Attention weights (Q, K, V projections) + auto wq = CreateLargeWeight("wq_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim}); + graph.AddInitializedTensor(wq); + auto wk = CreateLargeWeight("wk_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wk); + auto wv = CreateLargeWeight("wv_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + graph.AddNode("Q_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wq.name())}, {&q_split}); + graph.AddNode("K_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wk.name())}, {&k_split}); + graph.AddNode("V_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wv.name())}, {&v_split}); + } + // Reshape Q, K, V + auto& q_reshaped = graph.GetOrCreateNodeArg("q_reshaped_" + std::to_string(l), nullptr); + auto& k_reshaped = graph.GetOrCreateNodeArg("k_reshaped_" + std::to_string(l), nullptr); + auto& v_reshaped = graph.GetOrCreateNodeArg("v_reshaped_" + std::to_string(l), nullptr); + + ONNX_NAMESPACE::TensorProto q_shape_tensor; + q_shape_tensor.set_name("q_shape_" + std::to_string(l)); + q_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + q_shape_tensor.add_dims(3); + q_shape_tensor.add_int64_data(batch_size); + q_shape_tensor.add_int64_data(seq_length); + q_shape_tensor.add_int64_data(head_dim * q_num_heads); + graph.AddInitializedTensor(q_shape_tensor); + + ONNX_NAMESPACE::TensorProto k_shape_tensor; + k_shape_tensor.set_name("k_shape_" + std::to_string(l)); + k_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + k_shape_tensor.add_dims(3); + k_shape_tensor.add_int64_data(batch_size); + k_shape_tensor.add_int64_data(seq_length); + k_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(k_shape_tensor); + + ONNX_NAMESPACE::TensorProto v_shape_tensor; + v_shape_tensor.set_name("v_shape_" + std::to_string(l)); + v_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + v_shape_tensor.add_dims(3); + v_shape_tensor.add_int64_data(batch_size); + v_shape_tensor.add_int64_data(seq_length); + v_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(v_shape_tensor); + + graph.AddNode("Q_Reshape_" + std::to_string(l), "Reshape", "", {&q_split, graph.GetNodeArg(q_shape_tensor.name())}, {&q_reshaped}); + graph.AddNode("K_Reshape_" + std::to_string(l), "Reshape", "", {&k_split, graph.GetNodeArg(k_shape_tensor.name())}, {&k_reshaped}); + graph.AddNode("V_Reshape_" + std::to_string(l), "Reshape", "", {&v_split, graph.GetNodeArg(v_shape_tensor.name())}, {&v_reshaped}); + + // Replace standard attention with GQA + auto& attn_out = AddGroupQueryAttention( + graph, q_reshaped, k_reshaped, v_reshaped, + batch_size, head_dim, seq_length, q_num_heads, kv_num_heads, + 1.0f, dtype, + "GQA_" + std::to_string(l)); + + // Add an MLP block: (Linear + Activation + Linear) + auto w1 = CreateLargeWeight("mlp_w1_" + std::to_string(l), dtype, {hidden_dim, hidden_dim * 4}); + auto w2 = CreateLargeWeight("mlp_w2_" + std::to_string(l), dtype, {hidden_dim * 4, hidden_dim}); + graph.AddInitializedTensor(w1); + graph.AddInitializedTensor(w2); + + auto& mlp_hidden = graph.GetOrCreateNodeArg("mlp_hidden_" + std::to_string(l), nullptr); + graph.AddNode("MLP_1_" + std::to_string(l), "MatMul", "", {&attn_out, graph.GetNodeArg(w1.name())}, {&mlp_hidden}); + auto& relu_out = graph.GetOrCreateNodeArg("relu_" + std::to_string(l), nullptr); + graph.AddNode("Relu_" + std::to_string(l), "Relu", "", {&mlp_hidden}, {&relu_out}); + auto& mlp_out = graph.GetOrCreateNodeArg("mlp_out_" + std::to_string(l), nullptr); + graph.AddNode("MLP_2_" + std::to_string(l), "MatMul", "", {&relu_out, graph.GetNodeArg(w2.name())}, {&mlp_out}); + current_arg = &mlp_out; // For next layer. + } + + // Final projection to vocab + auto w_logits = CreateLargeWeight("w_logits", + dtype, {hidden_dim, vocab_size}); + graph.AddInitializedTensor(w_logits); + auto& output = graph.GetOrCreateNodeArg("logits", nullptr); + graph.AddNode("Output_Linear", "MatMul", "", {current_arg, graph.GetNodeArg(w_logits.name())}, {&output}); + + // Validate, Write as large model with external data + auto status = graph.Resolve(); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); + + onnxruntime::ModelSavingOptions save_options(128); + status = onnxruntime::Model::SaveWithExternalInitializers( + model, model_path, external_data_path, save_options); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); +} + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites, + OrtAllocator* allocator) { + Ort::IoBinding binding(session); + auto default_allocator = Ort::AllocatorWithDefaultOptions(); + if (allocator == nullptr) { + allocator = default_allocator; + } + const OrtMemoryInfo* info; + Ort::ThrowOnError(Ort::GetApi().AllocatorGetInfo(allocator, &info)); + Ort::MemoryInfo mem_info(info->name.c_str(), info->alloc_type, static_cast(info->device.Id()), info->mem_type); + + for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { + auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); + auto full_tensor_info = session.GetInputTypeInfo(input_idx); + auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + auto type = tensor_info.GetElementType(); + if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { + for (auto& v : shape) { + if (v == -1) { + v = 1; + } + } + } else { + shape = shape_overwrites[input_name.get()]; + } + auto input_value = Ort::Value::CreateTensor(allocator, + shape.data(), + shape.size(), + type); + binding.BindInput(input_name.get(), input_value); + } + + for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { + auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); + binding.BindOutput(output_name.get(), mem_info); + } + return binding; +} } // namespace test } // namespace onnxruntime - -#endif // _WIN32 diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h index ef14d3cb382c0..0f011af8211ca 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h @@ -5,9 +5,21 @@ #include #include +#include +#include + +#include +#include +#include +#include +#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "core/graph/constants.h" +#include "core/common/path_string.h" +#include "core/framework/tensor.h" +#include "core/framework/ort_value.h" +#include "test/util/include/api_asserts.h" namespace onnxruntime { namespace test { @@ -17,7 +29,7 @@ using RegisteredEpDeviceUniquePtr = std::unique_ptr> converter; + return converter.to_bytes(path); +#else + return path.c_str(); +#endif +} + +[[maybe_unused]] static void clearFileIfExists(PathString path) { + if (std::filesystem::exists(path)) { + std::filesystem::remove(path); + } +} + +template +static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1, fetches.size()); + auto& rtensor = fetches.front().Get(); + TensorShape expected_shape(expected_dims); + ASSERT_EQ(expected_shape, rtensor.Shape()); + const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); + ASSERT_EQ(expected_values, found); +} + +/** + * Create a simple model with dynamic or non-dynamic input shape. + * \param model_name - model name + * \param graph_name - graph name + * \param dims - input dimensions + * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. + * \param external_initializer_file - file name to save external initializers to + * + * input: "X", "Y" and "Z" + * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. + * All three inputs have the same dimensions. + * output: "M" + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * Add (+ float scalar "S") + * / + * "O" + * + * or + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * FastGelu (This node will be placed on CUDA EP) + * / + * * Add (+ float scalar "S") + * / + * "O" + */ +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu = false, + ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + const PathString& external_initializer_file = {}); + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path); + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites = {}, + OrtAllocator* allocator = nullptr); + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc new file mode 100644 index 0000000000000..fc90563a61bb1 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/float16.h" + +#include "test/util/include/test/test_environment.h" +#include "test/optimizer/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +extern std::unique_ptr ort_env; + +class OVEP_BF16_Tests : public ::testing::TestWithParam {}; + +namespace detail { +auto ConstructModel() { + using namespace onnxruntime; + using namespace test; + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 19; + Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + auto dim = 4; + std::vector input_data(dim, 1.0f); + auto* input = builder.MakeInput({dim}, input_data); + builder.graph_.SetInputs({input}); + + auto* cast_to_bf16 = builder.MakeIntermediate(); + Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, ""); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + + std::vector weight_data(dim * dim); + for (std::size_t i = 0; i < weight_data.size(); ++i) + weight_data[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights = builder.MakeInitializer({dim, dim}, weight_data); + + auto* matmul_out = builder.MakeIntermediate(); + builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out}); + + std::vector weight_data_2(dim * dim); + for (std::size_t i = 0; i < weight_data_2.size(); ++i) + weight_data_2[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights_2 = builder.MakeInitializer({dim, dim}, weight_data_2); + + auto* matmul_out_2 = builder.MakeIntermediate(); + builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2}); + + auto* output = builder.MakeOutput(); + Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output}); + cast2_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + builder.SetGraphOutputs(); + auto st = model.MainGraph().Resolve(); + if (st != Status::OK()) + throw std::runtime_error(st.ErrorMessage()); + return model; +} + +auto ProbeDevice(const std::string& device) { + static std::map is_present; + if (is_present.find(device) == is_present.end()) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + ov_options["device_type"] = device; + try { + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + is_present[device] = true; + } catch (...) { + is_present[device] = false; + } + } + return is_present[device]; +} +} // namespace detail + +namespace onnxruntime { +namespace test { + +TEST_P(OVEP_BF16_Tests, TestModelConversion) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + const auto& device = GetParam(); + if (!::detail::ProbeDevice(device)) + GTEST_SKIP() << device + " is not available on this machine"; + + ov_options["device_type"] = device; + auto model = ::detail::ConstructModel(); + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); + } catch (...) { + FAIL(); + } +} +INSTANTIATE_TEST_SUITE_P(OVEP_Tests, + OVEP_BF16_Tests, + ::testing::Values("CPU", "GPU", "NPU")); +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index a206644bc945e..74b37867b0060 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -232,7 +232,7 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { so.AppendExecutionProvider("QNN", options); // Invalid! Adds CPU EP to session, but also disables CPU fallback. - Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "constant_floats.onnx"; @@ -285,7 +285,7 @@ TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { so.AppendExecutionProvider("QNN", options); - Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_external_ini.onnx"; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 739e39a6975e2..1c8cc6f78fe63 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -317,6 +317,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) { Ort::ModelCompilationOptions compile_options(*ort_env, so); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -355,6 +356,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) { Ort::ModelCompilationOptions compile_options(*ort_env, so); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -393,6 +395,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelPath(output_model_file); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -427,6 +430,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) { // Create model compilation options from the session options. Output model is stored in a buffer. Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); compile_options.SetInputModelPath(input_model_file); Ort::AllocatorWithDefaultOptions allocator; @@ -482,6 +486,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -515,6 +520,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -573,6 +579,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); compile_options.SetOutputModelExternalInitializersFile(output_initializers_file, 0); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -2070,6 +2077,278 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } } + +// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. +static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + std::ofstream* outfile = reinterpret_cast(stream_state); + outfile->write(reinterpret_cast(buffer), buffer_num_bytes); + return nullptr; // No error +} + +// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. +static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to custom write stream +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_WriteOutputModelBytes) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_writeoutputmodelbytes.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_inputfile_writeoutputmodelbytes_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Open an output file. Test will incrementally write the output model to file + // via calls to our OrtOutStreamWriteFunc callback. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); + std::ofstream outfile(output_model_file, std::ios::binary); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelWriteFunc(TestWriteToStream, reinterpret_cast(&outfile)); + compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); + + // Check that the compiled model has the expected number of EPContext nodes. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 2, 2); +} + +// Tests using an OrtOutStreamFunc function that returns an error. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { + // Create a test model (in memory). + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + std::string model_data = test_model.Serialize(); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); + compile_options.SetOutputModelWriteFunc(ReturnStatusFromStream, nullptr); // Set output stream that returns error + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect a specific error status. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); +} + +struct CustomInitializerHandlerState { + const ORTCHAR_T* external_file_path = nullptr; + std::ofstream* outfile = nullptr; +}; + +static OrtStatus* ORT_API_CALL TestHandleInitializerDataFunc(void* state, + const char* initializer_name, + const OrtValue* c_initializer_value, + const OrtExternalInitializerInfo* /*c_external_info*/, + OrtExternalInitializerInfo** c_new_external_info) { + Ort::Status final_status{nullptr}; + + ORT_TRY { + CustomInitializerHandlerState* custom_state = reinterpret_cast(state); + + if (std::string("constant") == initializer_name) { + // Keep a specific initializer in the model just to test both scenarios. + // A real implementation may check the byte size and keep small initializers in the model. + *c_new_external_info = nullptr; + return nullptr; + } + + // + // Store other initializers in an external file. + // + Ort::ConstValue value{c_initializer_value}; + size_t byte_size = value.GetTensorSizeInBytes(); + int64_t offset = custom_state->outfile->tellp(); + const ORTCHAR_T* location = custom_state->external_file_path; + + custom_state->outfile->write(static_cast(value.GetTensorRawData()), byte_size); + custom_state->outfile->flush(); + + // Provide caller (ORT) with the new external info. + Ort::ExternalInitializerInfo new_external_info{nullptr}; + if (Ort::Status status = Ort::ExternalInitializerInfo::Create(location, offset, byte_size, new_external_info); + !status.IsOK()) { + return status.release(); + } + + *c_new_external_info = new_external_info.release(); + } + ORT_CATCH(const Ort::Exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status{ex}; + })); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status(ex.what(), ORT_FAIL); + })); + } + + return final_status.release(); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to a file +// - Use callback to specify where each initializer is stored (i.e., external file or within model). +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_OutputFile_InitializerHandler) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler_ctx.onnx"); + const ORTCHAR_T* initializer_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.bin"); + std::filesystem::remove(input_model_file); + std::filesystem::remove(output_model_file); + std::filesystem::remove(initializer_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Open a file to store external initializers. ORT will call our handler function for every initializer. + ASSERT_FALSE(std::filesystem::exists(initializer_file)); + std::ofstream outfile(initializer_file, std::ios::binary); + CustomInitializerHandlerState custom_state = {initializer_file, &outfile}; + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetOutputModelGetInitializerLocationFunc(TestHandleInitializerDataFunc, + reinterpret_cast(&custom_state)); + compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); + + ASSERT_TRUE(std::filesystem::exists(initializer_file)); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 2, 2); +} + +static OrtStatus* ORT_API_CALL ReuseExternalInitializers(void* state, + const char* /*initializer_name*/, + const OrtValue* /*initializer_value*/, + const OrtExternalInitializerInfo* external_info, + OrtExternalInitializerInfo** new_external_info) { + Ort::Status final_status{nullptr}; + + ORT_TRY { + // If the original initializer was stored in an external file, keep it there (just for testing). + if (external_info != nullptr) { + Ort::ConstExternalInitializerInfo info(external_info); + auto location = info.GetFilePath(); + int64_t offset = info.GetFileOffset(); + size_t byte_size = info.GetByteSize(); + + Ort::ExternalInitializerInfo new_info(nullptr); + Ort::Status status = Ort::ExternalInitializerInfo::Create(location.c_str(), offset, byte_size, new_info); + if (!status.IsOK()) { + return status.release(); + } + + *new_external_info = new_info.release(); + + // Keep track of number of reused external initializers so that we can assert + // that we reused the expected number of initializers. + // THIS IS TEST CODE. An application would not do this. + size_t* num_reused_ext_initializers = reinterpret_cast(state); + *num_reused_ext_initializers += 1; + + return nullptr; + } + + // If not originally external, save it within the generated compiled model + *new_external_info = nullptr; + } + ORT_CATCH(const Ort::Exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status{ex}; + })); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status(ex.what(), ORT_FAIL); + })); + } + + return final_status.release(); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to a file +// - Use callback to specify where each initializer is stored. We'll reuse external initializers +// from original model! +TEST_F(QnnHTPBackendTests, CompileApi_InitializerHandler_ReuseExternalInitializers) { + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/conv_qdq_external_ini.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("testdata/conv_qdq_external_ini_reuse_ctx.onnx"); + std::filesystem::remove(output_model_file); + + size_t num_reused_ext_initializers = 0; + + // Create model compilation options from the session options. + Ort::SessionOptions so; + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetOutputModelGetInitializerLocationFunc(ReuseExternalInitializers, + reinterpret_cast(&num_reused_ext_initializers)); + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + std::filesystem::remove(output_model_file); + + ASSERT_EQ(num_reused_ext_initializers, 2); // Reused external conv weight and bias. +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc new file mode 100644 index 0000000000000..b349e0c40882f --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQGemmTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 15); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -1.0f, 1.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // Gemm + NodeArg* gemm_bias = builder.MakeInitializer({output_channels}, -1.0f, 1.0f); + NodeArg* gemm_output = builder.MakeIntermediate(); + builder.AddNode("Gemm", {act_dql_output, w_dql_output, gemm_bias}, {gemm_output}); + + // QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {gemm_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQGemmFusion) { +#else +TEST_F(QnnHTPBackendTests, LPBQGemmFusion) { +#endif + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQGemmTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc new file mode 100644 index 0000000000000..8f63ccd5f2cd1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQMatMulTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 16); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -2.0f, 2.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // MatMul + NodeArg* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {act_dql_output, w_dql_output}, {matmul_output}); + + // QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {matmul_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQMatMulFusion) { +#else +TEST_F(QnnHTPBackendTests, LPBQMatMulFusion) { +#endif + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQMatMulTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/where_htp_test.cc b/onnxruntime/test/providers/qnn/where_htp_test.cc index bb3e229bbc9f8..95a9f3dac9cb7 100644 --- a/onnxruntime/test/providers/qnn/where_htp_test.cc +++ b/onnxruntime/test/providers/qnn/where_htp_test.cc @@ -86,7 +86,8 @@ static void RunWhereQDQTest(const TestInputDef& condition_def, } // Check that QNN compiles DQ -> Where -> Q as a single unit. -TEST_F(QnnHTPBackendTests, WhereQDQU8) { +// Fails since QNN 2.37.1: Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_WhereQDQU8) { RunWhereQDQTest(TestInputDef({4, 3, 2}, false, {true, false, true, false, true, false, true, false, true, false, true, false, @@ -99,7 +100,8 @@ TEST_F(QnnHTPBackendTests, WhereQDQU8) { // Check that QNN compiles DQ -> Where -> Q as a single unit. // Check QNN Where works with broadcast -TEST_F(QnnHTPBackendTests, WhereBroadcastU8) { +// Fails since QNN 2.37.1: Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_WhereBroadcastU8) { RunWhereQDQTest(TestInputDef({2}, false, {true, false}), TestInputDef({4, 3, 2}, true, -2.0f, 2.0f), TestInputDef({1}, true, {3.0f}), diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index 0c52740398b7a..cb31627a87c48 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -183,7 +183,7 @@ def test_example_plugin_ep_devices(self): Test registration of an example EP plugin and retrieval of its OrtEpDevice. """ if sys.platform != "win32": - self.skipTest("Skipping test because it device discovery is only supported on Windows") + self.skipTest("Skipping test because device discovery is only supported on Windows") ep_lib_path = "example_plugin_ep.dll" try: @@ -244,6 +244,44 @@ def test_example_plugin_ep_devices(self): del sess # Delete session before unregistering library self.unregister_execution_provider_library(ep_name) + def test_example_plugin_ep_data_transfer(self): + """ + Test usage of shared data transfer and allocator from plugin EP. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if "DmlExecutionProvider" in onnxrt.get_available_providers(): + self.skipTest("Skipping because DML EP data transfer is broken if we haven't created an inference session") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + data2 = data + 1 + + # the example EP pretends to use GPU memory so we can test data transfer. + # by matching its OrtDevice info we will hit its allocator and data transfer implementations. + # copy data from CPU to the fake GPU memory + gpu_value = onnxrt.OrtValue.ortvalue_from_numpy(data, "gpu", 0, 0xBE57) + # copy back to CPU + cpu_data = gpu_value.numpy() + np.testing.assert_equal(data, cpu_data) + + gpu_value.update_inplace(data2) # update the fake GPU data + cpu_data_2 = gpu_value.numpy() # copy back to CPU + np.testing.assert_equal(data2, cpu_data_2) + + gpu_value = None # Delete OrtValue before unregistering library as the allocator will be destroyed. + + self.unregister_execution_provider_library(ep_name) + if __name__ == "__main__": unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index b102676860444..e46cdb4f98850 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -225,6 +225,199 @@ def test_compile_from_buffer_to_buffer(self): self.assertTrue(isinstance(output_model_bytes, bytes)) self.assertGreater(len(output_model_bytes), 0) + def test_compile_graph_optimization_level(self): + """ + Tests compiling a model with no optimizations (default) vs all optimizations. + """ + input_model_path = get_name("test_cast_back_to_back_non_const_mixed_types_origin.onnx") + output_model_path_0 = os.path.join(self._tmp_dir_path, "cast.disable_all.compiled.onnx") + output_model_path_1 = os.path.join(self._tmp_dir_path, "cast.enable_all.compiled.onnx") + + # Local function that compiles a model with a given graph optimization level and returns + # the count of operator types in the compiled model. + def compile_and_get_op_counts( + output_model_path: str, + graph_opt_level: onnxrt.GraphOptimizationLevel | None, + ) -> dict[str, int]: + session_options = onnxrt.SessionOptions() + if graph_opt_level is not None: + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + graph_optimization_level=graph_opt_level, + ) + else: + # graph optimization level defaults to ORT_DISABLE_ALL if not provided. + model_compiler = onnxrt.ModelCompiler(session_options, input_model_path) + + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + model: onnx.ModelProto = onnx.load(get_name(output_model_path)) + op_counts = {} + for node in model.graph.node: + if node.op_type not in op_counts: + op_counts[node.op_type] = 1 + else: + op_counts[node.op_type] += 1 + + return op_counts + + # Compile model on CPU with no graph optimizations (default). + # Model should have 9 Casts + op_counts_0 = compile_and_get_op_counts(output_model_path_0, graph_opt_level=None) + self.assertEqual(op_counts_0["Cast"], 9) + + # Compile model on CPU with ALL graph optimizations. + # Model should have less casts (optimized out) + op_counts_1 = compile_and_get_op_counts( + output_model_path_1, graph_opt_level=onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC + ) + self.assertEqual(op_counts_1["Cast"], 8) + + def test_compile_from_file_to_stream(self): + """ + Tests compiling a model (from files) to an output stream using a custom write functor. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx") + + with open(output_model_path, "wb") as output_fd: + # User's custom write functor. Writes the model to a file. + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + output_fd.write(buffer) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_stream(my_write_func) + + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_stream_that_raises_exception(self): + """ + Tests compiling a model to an output stream that always raises an exception. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + # User's custom write functor that raises an exception. + test_py_error_message = "My Python Error" + + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + raise ValueError(test_py_error_message) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + + # Try to compile and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_stream(my_write_func) + self.assertIn(test_py_error_message, str(context.exception)) + + def test_compile_with_basic_initializer_location_func(self): + """ + Tests compiling a model using a custom initializer handler that stores initializers + in an external file. + """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.onnx") + initializer_file_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.bin") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + if os.path.exists(initializer_file_path): + os.remove(initializer_file_path) + + with open(initializer_file_path, "wb") as ext_init_file: + + def store_large_initializer_externally( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. + value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=store_large_initializer_externally, + ) + model_compiler.compile_to_file(output_model_path) + + self.assertTrue(os.path.exists(output_model_path)) + self.assertTrue(os.path.exists(initializer_file_path)) + + def test_compile_with_initializer_func_that_reuses(self): + """ + Tests compiling a model using a custom initializer handler that reuses external initializer files. + """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler_reuse.onnx") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + # Function that reuses external initializer files for the compiled model. + def reuse_external_initializers( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + self.assertNotEqual(initializer_value.data_ptr(), 0) + self.assertGreater(initializer_value.tensor_size_in_bytes(), 0) + if external_info is not None: + # Original initializer is stored externally. + # Make the initializer in the compiled model use the same external file + return external_info + + return None # Otherwise, make a copy of the initializer and store it within compiled model. + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=reuse_external_initializers, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + def test_fail_load_uncompiled_model_and_then_compile(self): """ Tests compiling scenario: diff --git a/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py new file mode 100644 index 0000000000000..8e69fdf088103 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import platform +import sys +import unittest + +from onnxruntime.capi.onnxruntime_pybind11_state import ( + OrtCompiledModelCompatibility, + get_ep_devices, + get_model_compatibility_for_ep_devices, +) + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + + +class TestEpCompatibility(unittest.TestCase): + def test_invalid_args(self): + # empty devices + with self.assertRaises(RuntimeError): + get_model_compatibility_for_ep_devices([], "info") + # None compatibility info should raise TypeError before native call + with self.assertRaises(TypeError): + get_model_compatibility_for_ep_devices(get_ep_devices(), None) # type: ignore[arg-type] + + def test_basic_smoke(self): + devices = list(get_ep_devices()) + if not devices: + self.skipTest("No EP devices available in this build") + + # Always select CPUExecutionProvider; skip if not present. + cpu_devices = [d for d in devices if getattr(d, "ep_name", None) == "CPUExecutionProvider"] + if not cpu_devices: + self.skipTest("CPUExecutionProvider not available in this build") + selected = [cpu_devices[0]] + + # API requires all devices belong to the same EP; we pass only one. + status = get_model_compatibility_for_ep_devices(selected, "arbitrary-compat-string") + self.assertEqual(status, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py new file mode 100644 index 0000000000000..d5c80a4a1f4ba --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py @@ -0,0 +1,468 @@ +# Copyright (c) NVIDIA Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import sys +import unittest +from collections.abc import Sequence + +import numpy as np +import torch +from autoep_helper import AutoEpTestCase +from helper import get_name +from numpy.testing import assert_almost_equal +from onnx import TensorProto, helper +from onnx.defs import onnx_opset_version + +import onnxruntime as onnxrt +from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue +from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding + + +class TestNvTensorRTRTXAutoEP(AutoEpTestCase): + """ + Test suite for the NvTensorRTRTX Execution Provider. + + This class contains tests for registering the NvTensorRTRTX EP, + selecting it using different policies, and running inference with various + I/O binding configurations. + """ + + ep_lib_path = "onnxruntime_providers_nv_tensorrt_rtx.dll" + ep_name = "NvTensorRTRTXExecutionProvider" + + def setUp(self): + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + self.register_execution_provider_library(self.ep_name, self.ep_lib_path) + + def tearDown(self): + self.unregister_execution_provider_library(self.ep_name) + + def _create_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 + ) + + def _create_ortvalue_alternate_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), + device, + 0, + ) + + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) + + def _create_numpy_input(self): + return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + + def _create_expected_output(self): + return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + + def _create_expected_output_alternate(self): + return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) + + def torch_to_onnx_type(self, torch_dtype): + if torch_dtype == torch.float32: + return TensorProto.FLOAT + elif torch_dtype == torch.float16: + return TensorProto.FLOAT16 + elif torch_dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + elif torch_dtype == torch.int8: + return TensorProto.int8 + elif torch_dtype == torch.int32: + return TensorProto.INT32 + elif torch_dtype == torch.int64: + return TensorProto.INT64 + else: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + + def test_nv_tensorrt_rtx_ep_register_and_inference(self): + """ + Test registration of NvTensorRTRTX EP, adding its OrtDevice to the SessionOptions, and running inference. + """ + ep_devices = onnxrt.get_ep_devices() + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + self.assertEqual(nv_tensorrt_rtx_ep_device.ep_vendor, "NVIDIA") + + hw_device = nv_tensorrt_rtx_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_prefer_gpu_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the PREFER_GPU policy and running inference. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_selection_delegate_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the custom EP selection delegate function and then run inference. + """ + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 1) + self.assertGreaterEqual(max_selections, 2) + + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + + # Select the NvTensorRTRTX device + return [nv_tensorrt_rtx_ep_device] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_bind_input_only(self): + """ + Test I/O binding with input data only. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Bind output to CPU + io_binding.bind_output("Y") + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] + + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) + + def test_bind_input_and_bind_output_with_ortvalues(self): + """ + Test I/O binding with OrtValues for both input and output. + """ + # Set a policy to prefer GPU. NvTensorRTRTX EP should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue) + + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_output("Y", output_ortvalue) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) + + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue_2) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) + + def test_bind_input_and_non_preallocated_output(self): + """ + Test I/O binding with non-preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu("cuda") + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + io_binding.bind_output("Y", "cuda") + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu("cuda") + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) + + def test_bind_input_and_preallocated_output(self): + """ + Test I/O binding with preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) + + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) + + def test_bind_input_types(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + for dtype in [ + np.float32, + # np.float64, + np.int32, + # np.uint32, + np.int64, + # np.uint64, + # np.int16, + # np.uint16, + # np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + def test_bind_onnx_types_from_torch(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + + for dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + ]: + with self.subTest(dtype=dtype): + proto_dtype = self.torch_to_onnx_type(dtype) + + x_ = helper.make_tensor_value_info("X", proto_dtype, [None]) + y_ = helper.make_tensor_value_info("Y", proto_dtype, [None]) + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [x_], [y_], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=10, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + dev = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + if dev == "cuda" + else C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + ) + + x = torch.arange(8, dtype=dtype, device=dev) + y = torch.empty(8, dtype=dtype, device=dev) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, proto_dtype, x.shape, x.data_ptr()) + bind.bind_output("Y", device, proto_dtype, y.shape, y.data_ptr()) + sess._sess.run_with_iobinding(bind, None) + self.assertTrue(torch.equal(x, y)) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_moe_cuda.py similarity index 53% rename from onnxruntime/test/python/transformers/test_parity_moe.py rename to onnxruntime/test/python/transformers/test_moe_cuda.py index 252d89a2257fc..c09d8bacf1fa2 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -9,6 +9,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import itertools +import os import unittest from collections import OrderedDict @@ -21,38 +23,54 @@ import onnxruntime +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +onnxruntime.preload_dlls() + +# Determine the execution provider and device based on CUDA availability. +use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") +ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] + torch.manual_seed(42) numpy.random.seed(42) -USE_QUANT = False -ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 5e-1 if USE_QUANT else 1e-2 +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} -def value_string_of(numpy_array): - arr = numpy_array.flatten() - lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] - return "{\n " + "f,\n ".join(lines) + "f}" +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} -def print_tensor(name, numpy_array): - print(f"const std::vector {name} = {value_string_of(numpy_array)};") +def quant_dequant(weights, is_4_bit_quantization: bool = True): + type = torch.quint4x2 if is_4_bit_quantization else torch.int8 + import tensorrt_llm # noqa: PLC0415 -def quant_dequant(weights, quant_mode: bool = True): - # use the test version `_symmetric_...` to get the non-interleaved weights - type = torch.quint4x2 if quant_mode else torch.int8 - # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() - # Comment out this line for passing the lintrunner check in the CI. - # import tensorrt_llm + # Avoid lint false alert that the package is not used. Note that this function will not be called in pipeline. + if pipeline_mode: + print("Tensorrt LLM version", tensorrt_llm.__version__) quant_weights, processed_q_weight, torch_weight_scales = ( torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) ) # Unpack the int4s int int8s - if quant_mode: + if is_4_bit_quantization: upper = quant_weights >> 4 lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) @@ -71,6 +89,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, + onnx_dtype, ): nodes = [ helper.make_node( @@ -94,21 +113,21 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] @@ -119,35 +138,35 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ORT_DTYPE, + onnx_dtype, fc1_bias_shape, - fc1_experts_bias.to(torch_type).flatten().tolist(), + fc1_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ORT_DTYPE, + onnx_dtype, fc2_bias_shape, - fc2_experts_bias.to(torch_type).flatten().tolist(), + fc2_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -171,6 +190,7 @@ def create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, + onnx_dtype, ): nodes = [ helper.make_node( @@ -197,46 +217,46 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE, + onnx_dtype, fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), + fc3_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -259,12 +279,14 @@ def create_phi_moe_onnx_graph( fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, - fc1_scales, - fc2_scales, - fc3_scales, topk, + onnx_dtype, + quant_bits=0, + fc1_scales=None, + fc2_scales=None, + fc3_scales=None, ): - use_quant = USE_QUANT + use_quant = quant_bits > 0 if use_quant: assert fc1_experts_weights.dtype == torch.int8 assert fc2_experts_weights.dtype == torch.int8 @@ -276,34 +298,37 @@ def create_phi_moe_onnx_graph( assert fc2_scales.dtype == torch.float16 assert fc3_scales.dtype == torch.float16 + op_name = "QMoE" if use_quant else "MoE" + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + ) + nodes = [ helper.make_node( - "MoE" if not use_quant else "QMoE", - ( - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ] - if not use_quant - else [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - "fc3_experts_weights", - "fc3_scales", - "", - ] - ), + op_name, + inputs, ["output"], "MoE_0", k=topk, @@ -315,37 +340,38 @@ def create_phi_moe_onnx_graph( ] if use_quant: - nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] + components = 2 if quant_bits == 4 else 1 + fc1_shape = [num_experts, hidden_size, inter_size // components] + fc2_shape = [num_experts, inter_size, hidden_size // components] + fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 - if use_quant: - numpy_type = numpy.uint8 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc3_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -358,42 +384,42 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ORT_DTYPE, + onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_type).flatten().tolist(), + fc1_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ORT_DTYPE, + onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_type).flatten().tolist(), + fc2_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ORT_DTYPE, + onnx_dtype, fc3_scale_shape, - fc3_scales.to(torch_type).flatten().tolist(), + fc3_scales.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -546,126 +572,127 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self): + def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 sess_options = SessionOptions() + sess_options.log_severity_level = 2 - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") return None - sess_options.log_severity_level = 2 - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - return ort_session def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass - def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states_flat = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), - } + router_logits = self.gate(hidden_states_flat) - ort_output = None - if self.ort_sess is not None: - if not iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits - else: - self.ort_run_with_iobinding(ort_inputs) - return None + # Determine the correct torch dtype from the onnx_dtype + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - - return None + # Prepare tensors on the correct device for ORT inference with the CORRECT dtype + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } - def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + # Bind inputs and outputs to torch tensors directly. iobinding = self.ort_sess.io_binding() - device_id = torch.cuda.current_device() - - iobinding.bind_input( - name="input", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), - ) - iobinding.bind_input( - name="router_probs", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["router_probs"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - ort_inputs["router_probs"], "cuda", device_id - ).data_ptr(), - ) - - iobinding.bind_output( - name="output", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - numpy.zeros(ort_inputs["input"].shape), "cuda", device_id - ).data_ptr(), - ) - - # warm up - for _ in range(5): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - - import time # noqa: PLC0415 - - s = time.time() - for _ in range(repeat): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - e = time.time() - print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + for name, tensor in tensors.items(): + # Ensure tensor is on the globally defined device + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + import time # noqa: PLC0415 + + repeat = 1000 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + # The output tensor is on `device`. Reshape and return it. + return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + # Maps "ort_type:quant_bits" to (atol, rtol) + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (3.0, 1e-2), + "FP16:8": (2.0, 1e-2), + "BF16:0": (1.0, 1e-2), + "BF16:4": (30.0, 1e-1), + "BF16:8": (20.0, 1e-1), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] if ort_output is not None: print( - "name:", - self.__class__.__name__, - " batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol ) - torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - self.ort_forward(hidden_state, iobinding=True) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) class SwitchMoE(SparseMoeBlockORTHelper): @@ -680,7 +707,7 @@ def __init__( eval_capacity=-1, activation="gelu", ): - super().__init__() + super().__init__(quant_bits=0) # SwitchMoE is not quantized self.batch_size = batch_size self.sequence_length = sequence_length self.num_experts = num_experts @@ -709,6 +736,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -744,7 +772,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): """ def __init__(self, config, batch_size, sequence_length): - super().__init__() + super().__init__(quant_bits=0) # Mixtral test is not quantized self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -778,6 +806,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -874,40 +903,41 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. """ - def __init__(self, config, batch_size, sequence_length): - super().__init__() + def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None): + super().__init__(quant_bits, onnx_dtype) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise + use_quant = self.quant_bits > 0 # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - w1_list = [] - w2_list = [] - w3_list = [] - w1_scale_list = [] - w2_scale_list = [] - w3_scale_list = [] - if not USE_QUANT: + w1_list, w2_list, w3_list = [], [], [] + w1_scale_list, w2_scale_list, w3_scale_list = [], [], [] + + if not use_quant: for i in range(self.num_experts): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) else: + is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + # Corrected quantization logic for per-output-channel quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq + # Transpose quantized weights to match the expected ONNX layout w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) w3_list.append(pre_qweight3) @@ -919,9 +949,9 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length @@ -933,10 +963,12 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight1, self.moe_experts_weight2, self.moe_experts_weight3, + self.top_k, + self.onnx_dtype, + self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, moe_experts_weight_scale3, - self.top_k, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -995,18 +1027,10 @@ def small_test_cases(): yield batch_size, sequence_length -def phi3_test_cases(): - # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. - for batch_size in [1, 4, 16]: - for sequence_length in [128]: - yield batch_size, sequence_length - - +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_switch_moe_parity(self, batch_size, sequence_length): - # if platform.system() == "Windows": - # pytest.skip("Skip on Windows") switch_moe = SwitchMoE( batch_size=batch_size, sequence_length=sequence_length, @@ -1015,26 +1039,412 @@ def test_switch_moe_parity(self, batch_size, sequence_length): hidden_features=1024, out_features=256, ) + switch_moe.to(device) switch_moe.parity_check() - # switch_moe.benchmark_ort() +# quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) +# since qMoE test requires tensorrt_llm for quant_dequant. We disable it in CI pipeline to avoid extra dependency. +quant_bits_list = [0] if pipeline_mode else [0, 8, 4] + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestMixtralMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_mixtral_moe_parity(self, batch_size, sequence_length): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.to(device) mixtral_moe.parity_check() - # mixtral_moe.benchmark_ort() +phi3_test_cases = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + quant_bits_list, + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_cases()) - def test_phi3_moe_parity(self, batch_size, sequence_length): + @parameterized.expand(phi3_test_cases) + def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) + phi3_moe.to(device) phi3_moe.parity_check() - # phi3_moe.benchmark_ort() + + +# --------------------------------------------- +# The following test are for swiglu activation +# --------------------------------------------- +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +# Note that the weight shape might not match the tensor shape in legacy operator spec. +def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, hidden_size, inter_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + # For the ONNX MoE operator, weights must be transposed to [In, Out] format. + # Biases do not require transposition. + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + # quant_dequant expects [Out, In] format, matching nn.Linear.weight + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + # Update the expert's weight with the dequantized version for the PyTorch reference. + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + # Stack the prepared tensors for the graph builder + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + # Build the ONNX graph with the correctly shaped tensors + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + fc1_experts_weight_scale=moe_experts_weight_scale1, + fc2_experts_weight_scale=moe_experts_weight_scale2, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This is the robust PyTorch reference implementation. It directly uses the + nn.Module experts, which is cleaner and less error-prone than manual matmul. + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [1, 3], # sequence_length + quant_bits_list, # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") +class TestSwigluMoE(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +def has_bf16_moe(): + if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +@unittest.skipIf(not has_bf16_moe(), "skipping bf16 moe tests.") +class TestSwigluMoeBf16(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=128, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits, onnx_dtype=TensorProto.BFLOAT16) + moe.to(device) + moe.parity_check() + + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128, 512, 1024, 2048, 4096], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(pipeline_mode or not use_cuda, "skipping performance test in CI pipeline.") +class TestSwigluMoEPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + hidden_size = 2880 + intermediate_size = 2880 + num_experts_per_token = 8 + num_local_experts = 128 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py new file mode 100644 index 0000000000000..efaaca29a01b6 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -0,0 +1,1118 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# +# QMoE quantization implementation notes: +# +# Both CPU and CUDA implementations use symmetric quantization centered around 0: +# - 4-bit: range [-8, 7] with no zero-point (symmetric around 0) +# - 8-bit: range [-128, 127] with no zero-point (symmetric around 0) +# +# This follows the _symmetric_quantize_last_axis_of_batched_matrix pattern. +# Tolerance values account for numerical differences between implementations. +# +# Routing Logic: CPU implementation uses top-k selection first, then softmax +# normalization on the selected experts. This provides proper weight distribution +# while maintaining computational efficiency. +# -------------------------------------------------------------------------- +import time +import unittest +from collections import OrderedDict + +import numpy +import torch +import torch.nn.functional as F +from onnx import helper +from parameterized import parameterized +from torch import nn + +import onnxruntime + +try: + from onnx import TensorProto + + has_onnx = True +except ImportError: + has_onnx = False + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, + "gelu": nn.GELU, +} +ACT2FN = ClassInstantier(ACT2CLS) + +if not has_onnx: + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + UINT8 = 2 + + TensorProto = TensorProtoPlaceholder + +onnxruntime.preload_dlls() + +device = torch.device("cpu") + +ort_provider = ["CPUExecutionProvider"] + +torch.manual_seed(42) +numpy.random.seed(42) + +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} + +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", +} + + +def quant_dequant(weights, is_4_bit_quantization: bool = True): + """ + Quantize and dequantize weights for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). + + This uses symmetric quantization similar to _symmetric_quantize_last_axis_of_batched_matrix: + - 4-bit: range = [-8, 7], no zero-point (symmetric around 0) + - 8-bit: range = [-128, 127], no zero-point (symmetric around 0) + """ + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + if is_4_bit_quantization: + packed_size = (weights.shape[-1] + 1) // 2 + return ( + torch.zeros_like(weights[..., 0:1]), + torch.zeros( + (weights.shape[0], weights.shape[1], packed_size), + dtype=torch.uint8, + device=weights.device, + ), + torch.zeros_like(weights), + ) + else: + return ( + torch.zeros_like(weights[..., 0:1]), + torch.zeros_like(weights, dtype=torch.uint8), + torch.zeros_like(weights), + ) + + # Calculate scale like C++ implementation + abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + abs_max = torch.clamp(abs_max, min=1e-8) # More conservative clamping for better precision + + if is_4_bit_quantization: + # 4-bit: scale = abs_max / 7.0 (using 7.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 7.0).float() + 1e-12 + + # Handle potential edge cases for zero or very small weights + if torch.max(abs_max) < 1e-8: + packed_size = (weights.shape[-1] + 1) // 2 + return ( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros( + (weights.shape[0], weights.shape[1], packed_size), + dtype=torch.uint8, + device=weights.device, + ), + torch.zeros_like(weights), + ) + + # Quantize: round(weight / scale) then clamp to [-8, 7] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-8, 7).float() + + # For symmetric quantization, we use signed int4 representation + # Convert to uint8 storage for packing: shift [-8,7] -> [0,15] for storage only + storage_weights = (quantized_weights + 8).to(torch.uint8) + + # Pack 4-bit values into uint8 (every two elements) + even_indices = torch.arange(0, weights.shape[-1], 2) + odd_indices = torch.arange(1, weights.shape[-1], 2) + + # Handle odd length by padding with zero (which is 8 in storage representation) + if odd_indices.shape[0] < even_indices.shape[0]: + padding = torch.full( + (storage_weights.shape[0], storage_weights.shape[1], 1), + fill_value=8, # 0 in symmetric quantization, stored as 8 + dtype=torch.uint8, + device=storage_weights.device, + ) + storage_weights = torch.cat([storage_weights, padding], dim=-1) + odd_indices = torch.arange(1, storage_weights.shape[-1], 2) + + even_weights = storage_weights[..., even_indices] + odd_weights = storage_weights[..., odd_indices] + + # Pack: low nibble = even, high nibble = odd + packed_weights = (even_weights & 0xF) | ((odd_weights & 0xF) << 4) + + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Unpack for dequantization + lower = packed_weights & 0xF + upper = (packed_weights >> 4) & 0xF + + # Restore original shape and convert back to signed representation + unpacked_weights = torch.zeros_like(weights, dtype=torch.uint8) + unpacked_weights[..., even_indices] = lower + + valid_odd_length = min(odd_indices.shape[0], weights.shape[-1] - even_indices.shape[0]) + if valid_odd_length > 0: + valid_odd_indices = odd_indices[:valid_odd_length] + unpacked_weights[..., valid_odd_indices] = upper[..., :valid_odd_length] + + # Convert back to signed values: [0,15] -> [-8,7] and apply scale + signed_weights = unpacked_weights.float() - 8.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), packed_weights, result.to(weights.dtype) + else: + # 8-bit: scale = abs_max / 127.0 (using 127.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 127.0).float() + 1e-12 + + # Handle potential edge cases for zero or very small weights + if torch.max(abs_max) < 1e-8: + return ( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros_like(weights, dtype=torch.uint8), + torch.zeros_like(weights), + ) + + # Quantize: round(weight / scale) then clamp to [-128, 127] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-128, 127).float() + + # For symmetric quantization, we use signed int8 representation + # Convert to uint8 storage: shift [-128,127] -> [0,255] for storage only + storage_weights = (quantized_weights + 128).to(torch.uint8) + + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Convert back to signed values: [0,255] -> [-128,127] and apply scale + signed_weights = storage_weights.float() - 128.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), storage_weights, result.to(weights.dtype) + + +def create_cpu_moe_onnx_graph( + hidden_size, + sequence_length, + num_experts, + top_k, + intermediate_size, + torch_dtype, + onnx_dtype, + fc1_experts_weights, + fc2_experts_weights, + fc1_bias=None, + fc2_bias=None, + fc1_scales=None, + fc2_scales=None, + use_swiglu=False, + use_quant=False, + quant_bits=4, + swiglu_interleaved=False, +): + if not has_onnx: + return None + + inter_size = intermediate_size + topk = top_k + + use_quant = True + + if fc1_scales is None and use_quant: + return None + if fc2_scales is None and use_quant: + return None + if not has_onnx: + return None + + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" + assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + + if not has_onnx: + return None + + op_name = "QMoE" + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + ] + + activation = "swiglu" if use_swiglu else "silu" + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior + activation_type=activation, + # Add new attributes with backwards-compatible default values + swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_limit=7.0, + activation_alpha=1.702, + activation_beta=1.0, + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + # Weights are store in column major order. Need pack 2 int4 values into uint8. + # Use the actual tensor shapes instead of calculating them to avoid size mismatches + fc1_shape = list(fc1_experts_weights.shape) + fc2_shape = list(fc2_experts_weights.shape) + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + weight_onnx_type, + fc1_shape, + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + ), + helper.make_tensor( + "fc2_experts_weights", + weight_onnx_type, + fc2_shape, + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + ), + ] + + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [sequence_length, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +class PhiMoEConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + hidden_act="silu", + num_experts_per_tok=2, + num_local_experts=8, + router_jitter_noise=0.01, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_jitter_noise = router_jitter_noise + + +class SwigluMoeConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + num_local_experts=8, + num_experts_per_token=2, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_token = num_experts_per_token + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class MoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: PhiMoEConfig): + super().__init__(config) + + +class PhiMoESwiGLUMLP(nn.Module): + """ + Phi3 MoE expert converted to 2-weight SwiGLU structure for CPU compatibility. + This converts the traditional 3-weight Phi3 structure to SwiGLU format. + """ + + def __init__(self, config: PhiMoEConfig): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + """ + Updated to match the CUDA implementation's routing logic for fair comparison. + This now uses the same complex jitter-based masking approach as the CUDA tests. + """ + assert top_k == 2 + assert not training + + mask_logits_threshold, selected_experts = torch.topk(scores, 2) + + mask_logits_threshold_1 = mask_logits_threshold[:, 0].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_1) + logits_mask = ((mask_logits_threshold_1 - scores) / factor) > (2 * jitter_eps) + + multiplier_1 = torch.softmax(scores.masked_fill(logits_mask, float("-inf")), dim=-1).gather( + dim=-1, index=selected_experts[:, 0].unsqueeze(-1) + ) + + mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_2) + logits_mask = ((mask_logits_threshold_2 - scores) / factor) > (2 * jitter_eps) + + multiplier_2 = torch.softmax( + torch.scatter(scores, -1, selected_experts[:, 0].unsqueeze(-1), float("-inf")).masked_fill( + logits_mask, float("-inf") + ), + dim=-1, + ).gather(dim=-1, index=selected_experts[:, 1].unsqueeze(-1)) + + multiplier = torch.concat((multiplier_1, multiplier_2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self, quant_bits=0, onnx_dtype=None): + super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 + + def create_ort_session(self, moe_onnx_graph): + if moe_onnx_graph is None: + return None + + sess_options = onnxruntime.SessionOptions() + sess_options.log_severity_level = 2 + + try: + ort_session = onnxruntime.InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception: + return None + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_flat) + + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] + + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } + + try: + iobinding = self.ort_sess.io_binding() + + for name, tensor in tensors.items(): + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + repeat = 100 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + time_ms = (e - s) / repeat * 1000 + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") + + return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) + + except Exception as e: + raise + + def recreate_onnx_model(self): + """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" + + w1_list, w2_list = [], [] + w1_scale_list, w2_scale_list = [], [] + + is_4_bit = self.quant_bits == 4 + for i in range(self.num_experts): + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + + if self.use_swiglu: + if self.swiglu_interleaved: + pass + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + + gate_weights = pre_qweight1 + value_weights = pre_qweight3 + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() + + self.experts[i].w2.weight.data = w2_qdq.contiguous().clone() + + w1_list.append(pre_qweight1) + w2_list.append(pre_qweight2) + w1_scale_list.append(w1_scale) + w2_scale_list.append(w2_scale) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) + + if moe_experts_weight_scale1.dim() == 3: + moe_experts_weight_scale1 = moe_experts_weight_scale1.squeeze(-1) + if moe_experts_weight_scale2.dim() == 3: + moe_experts_weight_scale2 = moe_experts_weight_scale2.squeeze(-1) + + try: + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + top_k=self.top_k, + intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=self.moe_experts_weight1, + fc2_experts_weights=self.moe_experts_weight2, + # Biases are not used in QMoE + fc1_bias=None, + fc2_bias=None, + # Scales are used for dequantization + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=True, # Always use QMoE + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + ) + except Exception: + self.moe_onnx_graph = None + return False + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + return self.ort_sess is not None + + def parity_check(self): + model_updated = self.recreate_onnx_model() + if not model_updated: + return + + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + if ort_output is None: + return + + torch_has_nan = torch.isnan(torch_output).any() + ort_has_nan = torch.isnan(ort_output).any() + torch_has_inf = torch.isinf(torch_output).any() + ort_has_inf = torch.isinf(ort_output).any() + + if torch_has_nan or ort_has_nan or torch_has_inf or ort_has_inf: + torch_output_clean = torch.where( + torch.isnan(torch_output) | torch.isinf(torch_output), torch.zeros_like(torch_output), torch_output + ) + ort_output_clean = torch.where( + torch.isnan(ort_output) | torch.isinf(ort_output), torch.zeros_like(ort_output), ort_output + ) + max_diff = (torch_output_clean.cpu() - ort_output_clean.cpu()).abs().max() + + if (torch_has_nan and ort_has_nan) or (torch_has_inf and ort_has_inf): + problematic_torch = torch.isnan(torch_output) | torch.isinf(torch_output) + problematic_ort = torch.isnan(ort_output) | torch.isinf(ort_output) + if torch.equal(problematic_torch, problematic_ort): + max_diff = 0.0 + else: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() + + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + + print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (0.05, 0.01), + "FP16:8": (0.02, 0.01), + "FP32:4": (0.11, 0.01), + "FP32:8": (0.11, 0.01), + } + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + tolerance_key = f"{dtype_str}:{self.quant_bits}" + if tolerance_key in ort_dtype_quant_bits_tolerance_map: + base_atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] + + if max_diff > base_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"tolerance {base_atol:.6f} for {tolerance_key}" + ) + else: + fallback_atol = 0.1 + if max_diff > fallback_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"fallback tolerance {fallback_atol:.6f} for unknown config {tolerance_key}" + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +def small_test_cases(): + for batch_size in [1, 4]: + for sequence_length in [32, 128]: + yield batch_size, sequence_length + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = None + self.ort_sess = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([PhiMoESwiGLUMLP(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + top_k=self.top_k, + intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=fc1_experts_weights, + fc2_experts_weights=fc2_experts_weights, + fc1_bias=fc1_experts_bias, + fc2_bias=fc2_experts_bias, + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=use_quant, + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """PyTorch reference forward pass using SwiGLU-style routing""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +disable_cpu_qmoe_tests = False + +# Define test cases for different MoE types +phi3_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] + + +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestPhiQMoECPU(unittest.TestCase): + @parameterized.expand(phi3_test_cases) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running Phi3 QMoE test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + + +disable_cpu_qmoe_tests = False + +swiglu_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] + + +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestSwigluQMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running SwiGLU test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/shared_lib/test_allocator.cc b/onnxruntime/test/shared_lib/test_allocator.cc index 29f3dfad0f11d..bf9e54e8b3c7b 100644 --- a/onnxruntime/test/shared_lib/test_allocator.cc +++ b/onnxruntime/test/shared_lib/test_allocator.cc @@ -45,12 +45,10 @@ TEST(CApiTest, DefaultAllocator) { TEST(CApiTest, CustomAllocator) { constexpr PATH_TYPE model_path = TSTR("testdata/mul_1.onnx"); - const auto& api = Ort::GetApi(); - // Case 1: Register a custom allocator. { MockedOrtAllocator mocked_allocator; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -62,14 +60,14 @@ TEST(CApiTest, CustomAllocator) { ASSERT_EQ(mocked_allocator.NumAllocations(), std::stoll(stats.GetValue("NumAllocs"))); ASSERT_EQ(mocked_allocator.NumReserveAllocations(), std::stoll(stats.GetValue("NumReserves"))); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } // Case 2: Register a custom allocator with an older API version which does not support GetStats. { MockedOrtAllocator mocked_allocator; mocked_allocator.version = 22; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -81,7 +79,7 @@ TEST(CApiTest, CustomAllocator) { auto stats = allocator.GetStats(); ASSERT_EQ(0, stats.GetKeyValuePairs().size()); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } } #endif diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc index 2294bb8d6fdff..e7d9d7715092b 100644 --- a/onnxruntime/test/shared_lib/test_data_copy.cc +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -31,9 +31,6 @@ using StreamUniquePtr = std::unique_ptrGetApi(ORT_API_VERSION); - #ifdef _WIN32 std::string cuda_lib = "onnxruntime_providers_cuda.dll"; #else @@ -47,10 +44,10 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // register the provider bridge based CUDA EP so allocator and data transfer is available // not all the CIs have the provider library in the expected place so we allow for that const char* ep_registration_name = "ORT CUDA"; - ASSERT_ORTSTATUS_OK(api->RegisterExecutionProviderLibrary(env, ep_registration_name, - ORT_TSTR("onnxruntime_providers_cuda"))); + ort_env->RegisterExecutionProviderLibrary(ep_registration_name, + ORT_TSTR("onnxruntime_providers_cuda")); - const OrtEpDevice* cuda_device = nullptr; + Ort::ConstEpDevice cuda_device{nullptr}; for (const auto& ep_device : ort_env->GetEpDevices()) { std::string vendor{ep_device.EpVendor()}; std::string name = {ep_device.EpName()}; @@ -70,13 +67,11 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // we pass in the CUDA cudaStream_t from the OrtSyncStream via provider options so need to create it upfront. // in the future the stream should be an input to the Session Run. - OrtSyncStream* stream = nullptr; - StreamUniquePtr stream_ptr; + Ort::SyncStream stream{nullptr}; if (use_streams) { - ASSERT_ORTSTATUS_OK(api->CreateSyncStreamForEpDevice(cuda_device, /*options*/ nullptr, &stream)); - stream_ptr = StreamUniquePtr(stream, [api](OrtSyncStream* stream) { api->ReleaseSyncStream(stream); }); + stream = cuda_device.CreateSyncStream(); - size_t stream_addr = reinterpret_cast(api->SyncStream_GetHandle(stream)); + size_t stream_addr = reinterpret_cast(stream.GetHandle()); options.AddConfigEntry("ep.cudaexecutionprovider.user_compute_stream", std::to_string(stream_addr).c_str()); // we explicitly specify user_compute_stream, so why do we also need to set has_user_compute_stream? options.AddConfigEntry("ep.cudaexecutionprovider.has_user_compute_stream", "1"); @@ -87,24 +82,27 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { size_t num_inputs = session.GetInputCount(); // find the input location so we know which inputs can be provided on device. - std::vector input_locations; - input_locations.resize(num_inputs, nullptr); - ASSERT_ORTSTATUS_OK(api->SessionGetMemoryInfoForInputs(session, input_locations.data(), num_inputs)); + auto input_locations = session.GetMemoryInfoForInputs(); + ASSERT_EQ(session.GetInputCount(), input_locations.size()); + + // Testing coverage + auto input_ep_devices = session.GetEpDeviceForInputs(); + ASSERT_EQ(session.GetInputCount(), input_ep_devices.size()); + + // This is for testing + auto output_locations = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(session.GetOutputCount(), output_locations.size()); std::vector cpu_tensors; // info for device copy - std::vector src_tensor_ptrs; - std::vector dst_tensor_ptrs; - - // values we'll call Run with - std::vector input_tensors; + std::vector device_tensors; ASSERT_EQ(num_inputs, 1); // create cpu based input data. Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{1, 1, 28, 28}; + constexpr const std::array shape{1, 1, 28, 28}; std::vector input_data(28 * 28, 0.5f); Ort::Value input_value = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), input_data.data(), input_data.size(), @@ -112,15 +110,13 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { cpu_tensors.push_back(std::move(input_value)); for (size_t idx = 0; idx < num_inputs; ++idx) { - const OrtMemoryInfo* mem_info = input_locations[idx]; - OrtDeviceMemoryType mem_type = api->MemoryInfoGetDeviceMemType(mem_info); - OrtMemoryInfoDeviceType device_type; - api->MemoryInfoGetDeviceType(mem_info, &device_type); + auto mem_info = input_locations[idx]; + OrtDeviceMemoryType mem_type = mem_info.GetDeviceMemoryType(); + OrtMemoryInfoDeviceType device_type = mem_info.GetDeviceType(); if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) { // copy to device - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(api->GetSharedAllocator(env, mem_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(mem_info); // allocate new on-device memory auto src_shape = cpu_tensors[idx].GetTensorTypeAndShapeInfo().GetShape(); @@ -137,18 +133,12 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); */ - src_tensor_ptrs.push_back(cpu_tensors[idx]); - dst_tensor_ptrs.push_back(device_value); - input_tensors.push_back(std::move(device_value)); - } else { - // input is on CPU accessible memory. move to input_tensors - input_tensors.push_back(std::move(cpu_tensors[idx])); + device_tensors.push_back(std::move(device_value)); } } - if (!src_tensor_ptrs.empty()) { - ASSERT_ORTSTATUS_OK(api->CopyTensors(env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), stream, - src_tensor_ptrs.size())); + if (!device_tensors.empty()) { + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(cpu_tensors, device_tensors, stream)); // Stream support is still a work in progress. // @@ -160,18 +150,19 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // iobinding.SynchronizeInputs(); // this doesn't actually require any bound inputs } - std::vector input_names = {"Input3"}; - std::vector output_names = {"Plus214_Output_0"}; + const auto& input_tensors = (!device_tensors.empty()) ? device_tensors : cpu_tensors; + + constexpr const std::array input_names = {"Input3"}; + constexpr const std::array output_names = {"Plus214_Output_0"}; Ort::Value output; session.Run(Ort::RunOptions{}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), &output, 1); - const float* results = nullptr; - ASSERT_ORTSTATUS_OK(api->GetTensorData(output, reinterpret_cast(&results))); + const float* results = output.GetTensorData(); // expected results from the CPU EP. can check/re-create by running with PREFER_CPU. - std::vector expected = { + constexpr const std::array expected = { -0.701670527f, -0.583666623f, 0.0480501056f, @@ -192,7 +183,7 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { run_test(/*use_streams*/ true); run_test(/*use_streams*/ false); - ASSERT_ORTSTATUS_OK(api->UnregisterExecutionProviderLibrary(env, ep_registration_name)); + ort_env->UnregisterExecutionProviderLibrary(ep_registration_name); } #endif // USE_CUDA diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 56cc234a63832..b7a9da8e1b658 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1901,14 +1901,6 @@ TEST(CApiTest, test_pyop_kwarg) { } #endif -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -TEST(CApiTest, create_session_without_session_option) { - constexpr PATH_TYPE model_uri = TSTR("../models/opset8/test_squeezenet/model.onnx"); - Ort::Session ret(*ort_env, model_uri, Ort::SessionOptions{nullptr}); - ASSERT_NE(nullptr, ret); -} -#endif - #ifdef REDUCED_OPS_BUILD TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. @@ -1973,7 +1965,7 @@ static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model TEST(CApiTest, get_allocator_cpu) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); Ort::Allocator cpu_allocator(session, info_cpu); @@ -2018,8 +2010,7 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); @@ -2101,7 +2092,7 @@ TEST(CApiTest, get_allocator_qnn_htp_shared) { TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); @@ -2175,11 +2166,10 @@ TEST(CApiTest, io_binding) { TEST(CApiTest, io_binding_cuda) { Ort::SessionOptions session_options; #ifdef USE_TENSORRT - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); + session_options.AppendExecutionProvider_TensorRT({}); #else - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); - session_options.AppendExecutionProvider_CUDA_V2(*options); + Ort::CUDAProviderOptions cuda_options; + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -2376,35 +2366,25 @@ TEST(CApiTest, io_binding_qnn_htp_shared) { #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) || defined(USE_DML) TEST(CApiTest, basic_cuda_graph) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #if defined(USE_TENSORRT) // Enable cuda graph in TRT provider option. - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - std::vector keys{"trt_cuda_graph_enable"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = {{"trt_cuda_graph_enable", + "1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", + "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. OrtROCMProviderOptions* rocm_options = nullptr; @@ -2699,7 +2679,7 @@ static void RunWithCudaGraphAnnotation(T& cg_data, } TEST(CApiTest, basic_cuda_graph_with_annotation) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #ifdef USE_DML @@ -2712,17 +2692,11 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. @@ -2771,21 +2745,15 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { #ifndef REDUCED_OPS_BUILD #if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + const std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); @@ -3316,13 +3284,9 @@ TEST(CApiTest, model_metadata) { } TEST(CApiTest, get_available_providers) { - const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - int len = 0; - char** providers; - ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr); - ASSERT_GT(len, 0); - ASSERT_STREQ(providers[len - 1], "CPUExecutionProvider"); - ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr); + std::vector providers = Ort::GetAvailableProviders(); + ASSERT_GT(providers.size(), 0); + ASSERT_STREQ(providers.back().c_str(), "CPUExecutionProvider"); } TEST(CApiTest, get_available_providers_cpp) { @@ -3348,8 +3312,6 @@ TEST(CApiTest, get_build_info_string) { } TEST(CApiTest, TestSharedAllocators) { - OrtEnv* env_ptr = (OrtEnv*)(*ort_env); - // prepare inputs std::vector> inputs(1); auto& input = inputs.back(); @@ -3367,28 +3329,17 @@ TEST(CApiTest, TestSharedAllocators) { // Turn on sharing of the allocator between sessions session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1"); - const auto& api = Ort::GetApi(); - // CASE 1: We test creating and registering an ORT-internal allocator implementation instance // for sharing between sessions { - OrtMemoryInfo* mem_info = nullptr; - ASSERT_TRUE(api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info) == nullptr); - std::unique_ptr rel_info(mem_info, api.ReleaseMemoryInfo); - - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); // This creates an ORT-internal allocator instance and registers it in the environment for sharing // NOTE: On x86 builds arenas are not supported and will default to using non-arena based allocator - ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr); - + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); { // create session 1 @@ -3414,7 +3365,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator for part 2 of this test // where-in we will register a custom allocator for the same device. - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, mem_info) == nullptr); + ort_env->UnregisterAllocator(mem_info); } // CASE 2: We test registering a custom allocator implementation @@ -3425,15 +3376,9 @@ TEST(CApiTest, TestSharedAllocators) { // need to be aligned for certain devices/build configurations/math libraries. // See docs/C_API.md for details. MockedOrtAllocator custom_allocator; - ASSERT_TRUE(api.RegisterAllocator(env_ptr, &custom_allocator) == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); // Registration is always a replace operation - std::unique_ptr - status_releaser( - api.RegisterAllocator(env_ptr, &custom_allocator), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); { // Keep this scoped to destroy the underlying sessions after use // This should trigger frees in our custom allocator @@ -3472,7 +3417,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator from the global environment // (common to all tests) to prevent its accidental usage elsewhere - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, custom_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(custom_allocator.Info()); // Ensure that the registered custom allocator was indeed used for both sessions // We should have seen 2 allocations per session (one for the sole initializer @@ -3488,22 +3433,15 @@ TEST(CApiTest, TestSharedAllocators) { } #ifdef USE_CUDA { - OrtMemoryInfo* cuda_meminfo = nullptr; - ASSERT_TRUE(api.CreateMemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault, &cuda_meminfo) == nullptr); - std::unique_ptr rel_info(cuda_meminfo, api.ReleaseMemoryInfo); - - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + auto cuda_meminfo = Ort::MemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault); - std::vector keys, values; - ASSERT_TRUE(api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0) == nullptr); + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); { // create session 1 @@ -3530,7 +3468,7 @@ TEST(CApiTest, TestSharedAllocators) { nullptr); } - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, cuda_meminfo) == nullptr); + ort_env->UnregisterAllocator(cuda_meminfo); } #endif } @@ -3558,16 +3496,10 @@ TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { Ort::Value val = Ort::Value::CreateTensor(mem_info, data, data_len, shape, shape_len); session_options.AddInitializer("W", val); - const auto& api = Ort::GetApi(); - - OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr; - ASSERT_TRUE(api.CreatePrepackedWeightsContainer(&prepacked_weights_container) == nullptr); - std::unique_ptr - rel_prepacked_weights_container(prepacked_weights_container, api.ReleasePrepackedWeightsContainer); - auto default_allocator = std::make_unique(); // create session 1 (using model path) + Ort::PrepackedWeightsContainer prepacked_weights_container; Ort::Session session1(*ort_env, MATMUL_MODEL_URI, session_options, prepacked_weights_container); RunSession(default_allocator.get(), session1, @@ -3664,12 +3596,11 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); #endif // disable using arena for the sole initializer in the model @@ -3685,16 +3616,18 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { // Usage example showing how to use CreateArenaCfgV2() API to configure the default memory CUDA arena allocator TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { - const auto& api = Ort::GetApi(); - Ort::SessionOptions session_options; - const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes", "max_power_of_two_extend_bytes"}; - const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256, 1L << 24}; + const std::unordered_map config_map = { + {"max_mem", 0}, // let ort pick default max memory + {"arena_extend_strategy", 0}, // use default extend strategy + {"initial_chunk_size_bytes", 1024}, // initial chunk size in bytes + {"max_dead_bytes_per_chunk", 0}, // no dead bytes per chunk + {"initial_growth_chunk_size_bytes", 256}, // initial growth chunk size in bytes + {"max_power_of_two_extend_bytes", 1L << 24} // max power of two extend bytes + }; - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfgV2(keys, values, 5, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + Ort::ArenaCfg arena_cfg(config_map); OrtCUDAProviderOptions cuda_provider_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(nullptr); cuda_provider_options.default_memory_arena_cfg = arena_cfg; @@ -3718,24 +3651,16 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { #ifdef USE_TENSORRT TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { - const auto& api = Ort::GetApi(); - // Test input tensor which is shape tensor with explicit trt profile shapes Ort::SessionOptions session_options; - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - - const char* trt_profile_min_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_max_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_opt_shapes = "data:2x2,shape:4x1"; - std::vector keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"}; - std::vector values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + + std::unordered_map trt_options_map = { + {"trt_profile_min_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_max_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_opt_shapes", "data:2x2,shape:4x1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); auto model_path = ORT_TSTR("testdata/trt_reshape.onnx"); @@ -3758,37 +3683,24 @@ TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { // Test input tensor which is shape tensor with implicit trt profile shapes Ort::SessionOptions session_options_2; - OrtTensorRTProviderOptionsV2* trt_options_2; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr); - std::unique_ptr - rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options_2), - rel_trt_options_2.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options_2; + session_options_2.AppendExecutionProvider_TensorRT_V2(*trt_options_2); Ort::Session session_2(*ort_env, model_path, session_options_2); - session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); + session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, std::size(output_names)); } TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { - const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; - - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions trt_options; // updating provider option with user provided compute stream cudaStream_t compute_stream = nullptr; - void* user_compute_stream = nullptr; cudaStreamCreate(&compute_stream); - ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + trt_options.UpdateWithValue("user_compute_stream", compute_stream); + void* user_compute_stream = trt_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -3901,36 +3813,30 @@ class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterf TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { std::string param = GetParam(); size_t pos = param.find("="); - std::string option_name = param.substr(0, pos); - std::string option_value = param.substr(pos + 1); + const std::string option_name = param.substr(0, pos); + const std::string option_value = param.substr(pos + 1); ASSERT_NE(pos, std::string::npos); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* trt_options; - OrtAllocator* allocator; - char* trt_options_str; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - const char* engine_cache_path = "./trt_engine_folder"; - std::vector keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable", - "trt_engine_cache_path", option_name.c_str()}; - - std::vector values{"0", "0", "1", "0", "1", - engine_cache_path, option_value.c_str()}; + Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + {"trt_fp16_enable", "1"}, + {"trt_int8_enable", "0"}, + {"trt_engine_cache_enable", "1"}, + {"trt_engine_cache_path", engine_cache_path}, + {option_name, option_value}}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + trt_options.Update(trt_options_map); - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsAsString(rel_trt_options.get(), allocator, &trt_options_str) == nullptr); - std::string s(trt_options_str); - ASSERT_TRUE(s.find(engine_cache_path) != std::string::npos); - ASSERT_TRUE(s.find(param.c_str()) != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)trt_options_str) == nullptr); + std::string trt_options_str = trt_options.GetTensorRTProviderOptionsAsString(); + ASSERT_NE(trt_options_str.find(engine_cache_path), std::string::npos); + ASSERT_NE(trt_options_str.find(param), std::string::npos); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); // simple inference test // prepare inputs @@ -3973,40 +3879,30 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, // This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance TEST(CApiTest, TestConfigureCUDAProviderOptions) { - const auto& api = Ort::GetApi(); - - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; // Only test updating OrtCUDAProviderOptionsV2 instance with user provided compute stream not running the inference cudaStream_t compute_stream = nullptr; void* user_compute_stream = nullptr; cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); - ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + cuda_options.UpdateWithValue("user_compute_stream", compute_stream); + user_compute_stream = cuda_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); cudaStreamDestroy(compute_stream); - std::vector keys{ - "device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy", - "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"}; - - std::vector values{ - "0", "0", "1024", "kSameAsRequested", - "DEFAULT", "1", "1"}; - - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr); + std::unordered_map cuda_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + {"gpu_mem_limit", "1024"}, + {"arena_extend_strategy", "kSameAsRequested"}, + {"cudnn_conv_algo_search", "DEFAULT"}, + {"do_copy_in_default_stream", "1"}, + {"cudnn_conv_use_max_workspace", "1"}, + {"cudnn_conv1d_pad_to_nc1d", "1"}}; - OrtAllocator* allocator; - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); + cuda_options.Update(cuda_options_map); - char* cuda_options_str = nullptr; - ASSERT_TRUE(api.GetCUDAProviderOptionsAsString(rel_cuda_options.get(), allocator, &cuda_options_str) == nullptr); - std::string s; - if (cuda_options_str != nullptr) { - s = std::string(cuda_options_str, strnlen(cuda_options_str, 2048)); - } + std::string s = cuda_options.GetCUDAProviderOptionsAsString(); ASSERT_TRUE(s.find("device_id=0") != std::string::npos); ASSERT_TRUE(s.find("gpu_mem_limit=1024") != std::string::npos); ASSERT_TRUE(s.find("arena_extend_strategy=kSameAsRequested") != std::string::npos); @@ -4015,10 +3911,8 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(s.find("cudnn_conv_use_max_workspace=1") != std::string::npos); ASSERT_TRUE(s.find("cudnn_conv1d_pad_to_nc1d") != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)cuda_options_str) == nullptr); - Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2(static_cast(session_options), rel_cuda_options.get()) == nullptr); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // if session creation passes, model loads fine std::basic_string model_uri = MODEL_URI; @@ -4117,9 +4011,8 @@ TEST(CApiTest, GitHubIssue10179) { auto load_model_thread_fn = []() { try { const auto* model_path = MODEL_URI; - Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::SessionOptions session_options; + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { @@ -4150,8 +4043,7 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 0fe747cdd84e5..cffa0efc39d45 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -420,7 +420,7 @@ TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { // typically this isn't needed. we replace this input but need to read info from it later on in the test // validation so we save the info locally to keep it accessible. - auto orig_input_name = graph_inputs[0].Name(); + auto orig_input_name = graph_inputs[0].GetName(); auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const std::string new_input_name = "Int64Input"; @@ -589,7 +589,7 @@ TEST(ModelEditorAPITest, InvalidModelEdit) { Node node("Cast", domain, "NewInputNode", {new_input_name}, // the existing node will now consume the output from the Cast instead of a graph input - {graph_inputs[0].Name()}, + {graph_inputs[0].GetName()}, attributes); graph.AddNode(node); diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index 89b12ec61649e..7268c351877f3 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,7 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index 3fbb294e1af49..d12a586f662ac 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -54,20 +54,15 @@ TEST(CApiTest, session_options_provider_interface_fail_add_openvino) { #if defined(USE_CUDA_PROVIDER_INTERFACE) // Test that loading CUDA EP when only the interface is built (but not the full EP) fails. TEST(CApiTest, session_options_provider_interface_fail_add_cuda) { - const OrtApi& api = Ort::GetApi(); Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - Ort::Status status1 = Ort::Status{api.CreateCUDAProviderOptions(&cuda_options)}; - ASSERT_TRUE(status1.IsOK()); - - Ort::Status status2 = Ort::Status{api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, - cuda_options)}; - ASSERT_FALSE(status2.IsOK()); - EXPECT_EQ(status2.GetErrorCode(), ORT_FAIL); - EXPECT_THAT(status2.GetErrorMessage(), testing::HasSubstr("Failed to load")); - - api.ReleaseCUDAProviderOptions(cuda_options); + Ort::CUDAProviderOptions cuda_options; + try { + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); + FAIL() << "Appending CUDA options have thrown exception"; + } catch (const Ort::Exception& ex) { + ASSERT_THAT(ex.what(), testing::HasSubstr("Failed to load")); + } } #endif // defined(USE_CUDA_PROVIDER_INTERFACE) diff --git a/onnxruntime/test/testdata/conv_default_attrs.onnx b/onnxruntime/test/testdata/conv_default_attrs.onnx new file mode 100644 index 0000000000000..fc7ee58dee15e Binary files /dev/null and b/onnxruntime/test/testdata/conv_default_attrs.onnx differ diff --git a/onnxruntime/test/testdata/make_conv_default_attrs.py b/onnxruntime/test/testdata/make_conv_default_attrs.py new file mode 100644 index 0000000000000..fc092bf8b25fb --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_default_attrs.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def main(): + inp_shape = (1, 2, 8, 8) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + + weight_data = [ + [[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]], + [[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias") + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + + onnx.checker.check_model(model, True) + onnx.save_model(model, "conv_default_attrs.onnx") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 23c3a922326cb..43f6e480672ba 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -776,7 +776,8 @@ //TODO: Resolve as a graph implementation that returns a constant inf tensor with appropriate strides "^test_reduce_max_empty_set_cpu", // DNNL result in "(shapes (2, 1, 4), (1, 0, 1) mismatch)". this is the same for test_reduce_min_empty_set which is already in the list "^test_reduce_min_empty_set_cpu", - "^test_resize_upsample_sizes_nearest_not_smaller_cpu" + "^test_resize_upsample_sizes_nearest_not_smaller_cpu", + "^test_clip_min_greater_than_max_cpu" ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 2e4aa3923b649..bae7a14908916 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,21 +80,7 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - OrtMIGraphXProviderOptions params{ - 0, - 0, - 0, - 0, - 0, - nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, - SIZE_MAX, - 0}; - return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); + return MIGraphXProviderFactoryCreator::Create(ProviderOptions{})->CreateProvider(); #else return nullptr; #endif @@ -102,7 +88,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params) { #ifdef USE_MIGRAPHX - if (auto factory = MIGraphXProviderFactoryCreator::Create(params)) + if (const auto factory = MIGraphXProviderFactoryCreator::Create(params); factory != nullptr) return factory->CreateProvider(); #else ORT_UNUSED_PARAMETER(params); diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h index 423135f96fbcd..0be3b8bbb0764 100644 --- a/onnxruntime/test/util/include/api_asserts.h +++ b/onnxruntime/test/util/include/api_asserts.h @@ -10,36 +10,29 @@ #include "core/session/onnxruntime_cxx_api.h" // asserts for the public API -#define ASSERT_ORTSTATUS_OK(function) \ - do { \ - OrtStatusPtr _tmp_status = (function); \ - ASSERT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define ASSERT_ORTSTATUS_OK(function) \ + do { \ + Ort::Status _tmp_status{(function)}; \ + ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ } while (false) -#define EXPECT_ORTSTATUS_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - EXPECT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define EXPECT_ORTSTATUS_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ } while (false) -#define ASSERT_ORTSTATUS_NOT_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - ASSERT_NE(_tmp_status, nullptr); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define ASSERT_ORTSTATUS_NOT_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + ASSERT_TRUE(_tmp_status.IsOK()); \ } while (false) -#define EXPECT_ORTSTATUS_NOT_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - EXPECT_NE(_tmp_status, nullptr); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define EXPECT_ORTSTATUS_NOT_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + EXPECT_FALSE(_tmp_status.IsOK()); \ } while (false) -#define ASSERT_CXX_ORTSTATUS_OK(function) \ - do { \ - Ort::Status _tmp_status = (function); \ - ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ - } while (false) +#define ASSERT_CXX_ORTSTATUS_OK(function) \ + ASSERT_ORTSTATUS_OK(function) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index 90be9e24d3dd4..ff220fcb067b8 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -121,7 +121,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); - auto& axes_input = graph_utils::AddInitializerWithExternalData(graph, initializer_proto); + auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 55286379fd273..7c429ae5cb643 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -453,15 +453,15 @@ Status MegatronTransformer::TransformGPT2MLP(Graph& graph, bool& modified, return skip_status; } - NodeArg& a_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_weight_initializer_partition); + NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg); updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()}); - NodeArg& a_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_bias_initializer_partition); + NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg); updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()}); - NodeArg& b_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, b_weight_initializer_partition); + NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg); updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()}); @@ -600,15 +600,15 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, return skip_status; } - NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_weight_initializer_partition); + NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition); graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg); updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()}); - NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_bias_initializer_partition); + NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition); graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg); updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()}); - NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wo_weight_initializer_partition); + NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition); graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg); updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()}); @@ -814,15 +814,15 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, [](Node* node_ptr) { return node_ptr != nullptr; }); // Replace by the partition weights. - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partition); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg); updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()}); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partition); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg); updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()}); - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -849,7 +849,7 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[2] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); + NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } @@ -1130,7 +1130,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, size_t i = 0; for (auto trans_ptr : weight_transpose_node_ptrs) { auto weight_name = trans_ptr->MutableInputDefs()[0]->Name(); - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partitions[i]); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg); graph.RemoveInitializedTensor(weight_name); updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()}); @@ -1139,14 +1139,14 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, i = 0; for (auto add_ptr : bias_add_node_ptrs) { auto bias_name = add_ptr->MutableInputDefs()[1]->Name(); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partitions[i]); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg); graph.RemoveInitializedTensor(bias_name); updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()}); i++; } - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg); graph.RemoveInitializedTensor(dense_weight_arg->Name()); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -1178,7 +1178,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[idx] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); + NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } diff --git a/orttraining/orttraining/core/optimizer/qdq_fusion.cc b/orttraining/orttraining/core/optimizer/qdq_fusion.cc index 4a5bdc1f8fcd2..42720dbbb11e5 100644 --- a/orttraining/orttraining/core/optimizer/qdq_fusion.cc +++ b/orttraining/orttraining/core/optimizer/qdq_fusion.cc @@ -45,7 +45,7 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { // Since the quantize node has the zero point initializer input, replace it graph_utils::ReplaceNodeInput(quantize_node, 2, - graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); + graph_utils::AddInitializer(graph, zero_point_tensor_float)); } else { // The quantize node does not have the zero point optional input. // Create the zero point initializer to be 0. @@ -55,7 +55,7 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { // Since the input did not exist, add the newly created initializer as an input graph_utils::AddNodeInput(quantize_node, 2, - graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); + graph_utils::AddInitializer(graph, zero_point_tensor_float)); } return zero_point_type; diff --git a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc index 8c9c12ceb4497..84bf715c7c85a 100644 --- a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc +++ b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc @@ -83,7 +83,7 @@ Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ ignore_index_initializer_proto.set_name(graph.GenerateNodeArgName("sce_grad_ignore_index")); ignore_index_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); ignore_index_initializer_proto.add_int64_data(static_cast(-1)); - new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializerWithExternalData(graph, ignore_index_initializer_proto)); + new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializer(graph, ignore_index_initializer_proto)); } new_scegrad_node_inputs.emplace_back(bias_def); if (!p_reshape) { diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 2ec77c96dc2d5..6f5d6e1389443 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1235,7 +1235,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) { auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, -1.0f, 1.0f); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); @@ -1280,8 +1280,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) { auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); - auto* bias_arg = builder.MakeInitializer({out_channel}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, -1.0f, 1.0f); + auto* bias_arg = builder.MakeInitializer({out_channel}, -1.0f, 1.0f); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output}); @@ -1314,7 +1314,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) { auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, -1.0f, 1.0f); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); @@ -1347,7 +1347,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) { auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, -1.0f, 1.0f); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 99c36aee85df7..83c29239151c8 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -2,5 +2,5 @@ # When any package below is changed, you shall run "lintrunner init" again. lintrunner==0.12.7 lintrunner-adapters==0.12.5 -ruff==0.12.4 +ruff==0.12.9 clang-format==20.1.8 diff --git a/setup.py b/setup.py index 5ab1ac5b840d4..6bfb53329f319 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ def parse_arg_remove_string(argv, arg_name_equal): is_cuda_version_12 = cuda_version.startswith("12.") elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): is_migraphx = True + package_name = "onnxruntime-migraphx" elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -90,8 +91,6 @@ def parse_arg_remove_string(argv, arg_name_equal): is_qnn = True package_name = "onnxruntime-qnn" qnn_version = parse_arg_remove_string(sys.argv, "--qnn_version=") -elif is_migraphx: - package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 @@ -283,7 +282,6 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) self._rewrite_ld_preload_tensorrt(to_preload_nv_tensorrt_rtx) self._rewrite_ld_preload(to_preload_cann) - else: pass @@ -412,6 +410,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_nv_tensorrt_rtx.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_migraphx.dll"]) libs.extend(["onnxruntime_providers_vitisai.dll"]) libs.extend(["onnxruntime_providers_qnn.dll"]) # DirectML Libs @@ -435,6 +434,26 @@ def finalize_options(self): libs.extend(qnn_deps) if nightly_build: libs.extend(["onnxruntime_pywrapper.dll"]) + migraphx_deps = [ + "amd_comgr0602.dll", + "amd_comgr0604.dll", + "amd_comgr0700.dll", + "hiprtc0602.dll", + "hiprtc0604.dll", + "hiprtc0700.dll", + "hiprtc-builtins0602.dll", + "hiprtc-builtins0604.dll", + "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver.exe", + "migraphx.dll", + "migraphx_c.dll", + "migraphx_cpu.dll", + "migraphx_device.dll", + "migraphx_gpu.dll", + "migraphx_onnx.dll", + "migraphx_tf.dll", + ] + libs.extend(migraphx_deps) if is_manylinux: if is_openvino: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..d22c8587a82b5 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -723,8 +723,6 @@ def generate_build_tree( cmake_args += ["-Donnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD=ON"] if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) - cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) - cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) @@ -840,8 +838,6 @@ def generate_build_tree( if is_macOS() and not args.android: add_default_definition(cmake_extra_defines, "CMAKE_OSX_ARCHITECTURES", args.osx_arch) - if args.apple_deploy_target: - cmake_args += ["-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target] # Code sign the binaries, if the code signing development identity and/or team id are provided if args.xcode_code_signing_identity: cmake_args += ["-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY=" + args.xcode_code_signing_identity] @@ -932,7 +928,6 @@ def generate_build_tree( cmake_args += [ "-Donnxruntime_BUILD_SHARED_LIB=ON", "-DCMAKE_OSX_SYSROOT=" + args.apple_sysroot, - "-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target, # we do not need protoc binary for ios cross build "-Dprotobuf_BUILD_PROTOC_BINARIES=OFF", "-DPLATFORM_NAME=" + platform_name, @@ -948,16 +943,15 @@ def generate_build_tree( if args.macos == "Catalyst": macabi_target = f"{args.osx_arch}-apple-ios{args.apple_deploy_target}-macabi" cmake_args += [ - "-DCMAKE_CXX_COMPILER_TARGET=" + macabi_target, - "-DCMAKE_C_COMPILER_TARGET=" + macabi_target, - "-DCMAKE_CC_COMPILER_TARGET=" + macabi_target, f"-DCMAKE_CXX_FLAGS=--target={macabi_target}", - f"-DCMAKE_CXX_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}", f"-DCMAKE_C_FLAGS=--target={macabi_target}", - f"-DCMAKE_C_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}", - f"-DCMAKE_CC_FLAGS=--target={macabi_target}", - f"-DCMAKE_CC_FLAGS_RELEASE=-O3 -DNDEBUG --target={macabi_target}", + f"-DCMAKE_ASM_FLAGS=--target={macabi_target}", ] + else: + cmake_args += [ + "-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target, + ] + if args.visionos: cmake_args += [ "-DCMAKE_SYSTEM_NAME=visionOS", @@ -1513,8 +1507,8 @@ def adb_push(src, dest, **kwargs): def adb_shell(*args, **kwargs): return run_subprocess([sdk_tool_paths.adb, "shell", *args], **kwargs) - def adb_install(*args, **kwargs): - return run_subprocess([sdk_tool_paths.adb, "install", *args], **kwargs) + def adb_logcat(*args, **kwargs): + return run_subprocess([sdk_tool_paths.adb, "logcat", *args], **kwargs) def run_adb_shell(cmd): # GCOV_PREFIX_STRIP specifies the depth of the directory hierarchy to strip and @@ -1540,6 +1534,17 @@ def run_adb_shell(cmd): ) context_stack.callback(android.stop_emulator, emulator_proc) + all_android_tests_passed = False + + def dump_logs_on_failure(): + if not all_android_tests_passed: + log.warning("Android test failed. Dumping logs.") + adb_logcat("-d") # dump logs + + context_stack.callback(dump_logs_on_failure) + + adb_logcat("-c") # clear logs + adb_push("testdata", device_dir, cwd=cwd) if is_linux() and os.path.exists("/data/onnx"): adb_push("/data/onnx", device_dir + "/test", cwd=cwd) @@ -1591,6 +1596,8 @@ def run_adb_shell(cmd): f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_customopregistration_test" ) + all_android_tests_passed = True + def run_ios_tests(args, source_dir, config, cwd): is_targeting_iphone_simulator = "iphonesimulator" in args.apple_sysroot.lower() @@ -1697,8 +1704,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_ios_tests(args, source_dir, config, cwd) continue dll_path_list = [] - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: dll_path_list.append(os.path.join(args.tensorrt_home, "lib")) + if args.use_nv_tensorrt_rtx: + dll_path_list.append(os.path.join(args.tensorrt_rtx_home, "lib")) dll_path = None if len(dll_path_list) > 0: @@ -1994,6 +2003,7 @@ def build_nuget_package( use_winml, use_qnn, use_dml, + use_migraphx, enable_training_apis, msbuild_extra_options, ): @@ -2031,6 +2041,9 @@ def build_nuget_package( elif use_tensorrt: execution_provider = "/p:ExecutionProvider=tensorrt" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.TensorRT" + elif use_migraphx: + execution_provider = "/p:ExecutionProvider=migraphx" + package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.MIGraphX" elif use_dnnl: execution_provider = "/p:ExecutionProvider=dnnl" package_name = "/p:OrtPackageId=Microsoft.ML.OnnxRuntime.DNNL" @@ -2622,6 +2635,7 @@ def main(): getattr(args, "use_winml", False), args.use_qnn, getattr(args, "use_dml", False), + args.use_migraphx, args.enable_training_apis, normalize_arg_list(args.msbuild_extra_options), ) diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index e656cedae5916..90947e534918d 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -71,11 +71,14 @@ def main(): log.info(f"Image: {full_image_name}") - dst_deps_file = Path(args.context) / "scripts" / "deps.txt" + dst_scripts_dir = Path(args.context) / "scripts" + dst_deps_file = dst_scripts_dir / "deps.txt" # The docker file may provide a special deps.txt in its docker context dir and uses that one. # Otherwise, copy a generic one from this repo's cmake dir. if not dst_deps_file.exists(): log.info(f"Copy deps.txt to : {dst_deps_file}") + if not dst_scripts_dir.exists(): + dst_scripts_dir.mkdir(parents=True, exist_ok=True) shutil.copyfile(Path(REPO_DIR) / "cmake" / "deps.txt", str(dst_deps_file)) if "manylinux" in args.dockerfile and args.multiple_repos: diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json index d99191e4f45d8..eec9a753f6dcf 100644 --- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -24,7 +24,7 @@ ], "macosx": [ "--macos=MacOSX", - "--apple_deploy_target=13.3" + "--apple_deploy_target=13.4" ], "iphoneos": [ "--ios", diff --git a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json index e35ddb93a173d..bcc23296b7d3a 100644 --- a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json @@ -33,7 +33,7 @@ ], "macosx": [ "--macos=MacOSX", - "--apple_deploy_target=13.3" + "--apple_deploy_target=13.4" ] } } diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 91f35d2b54033..b062a3b64f6f3 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 5cf5cd8c936fa..53b62762319ba 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -7,7 +7,6 @@ parameters: default: true stages: - # build binaries for Android - ${{ if parameters.BuildAndroidBinaries }}: - stage: BuildAndroidBinaries diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 028777756352d..e5319b068a1fc 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -55,7 +55,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 resources: repositories: @@ -122,12 +122,12 @@ extends: PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - - template: stages/download-java-tools-stage.yml - - template: templates/c-api-cpu.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ${{ if eq(parameters.NugetPackageSuffix, 'NONE') }}: OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime' ${{ else }}: @@ -135,16 +135,10 @@ extends: AdditionalBuildFlags: '' AdditionalWinBuildFlags: '--enable_onnx_tests ${{parameters.AdditionalBuildFlag}}' BuildVariant: 'default' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} QnnSDKVersion: ${{ parameters.QnnSdk }} is1ES: true - template: stages/java-cuda-packaging-stage.yml - parameters: - CudaVersion: 12.2 - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - template: stages/nuget-combine-cuda-stage.yml parameters: @@ -159,6 +153,8 @@ extends: buildNodejs: true SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: stages/nodejs-win-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 3772b5e9c4c20..b846cc8bb9e80 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -83,7 +83,7 @@ stages: artifactName: 'onnxruntime-android-qnn-aar' packageName: 'onnxruntime-android-qnn' #TODO: get this information from the setup stage - QnnSDKVersion: '2.36.1.250708' + QnnSDKVersion: '2.37.1.250807' - template: nuget/templates/test_win.yml parameters: @@ -121,109 +121,6 @@ stages: parameters: StageSuffix: 'macOS_CPU_x64' -- template: templates/final-jar-testing.yml - parameters: - OS: Windows - PoolName: 'onnxruntime-Win-CPU-2022' - -- template: templates/final-jar-testing.yml - parameters: - OS: Linux - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - -- template: templates/final-jar-testing.yml - parameters: - OS: MacOS - PoolName: 'macOS-14' - - -- stage: GPU_JAR_Testing - dependsOn: Setup - jobs: - - job: Final_Jar_Testing_Windows_GPU - workspace: - clean: all - pool: 'onnxruntime-Win2022-GPU-A10' - timeoutInMinutes: 60 - variables: - - name: runCodesignValidationInjection - value: false - - steps: - - template: templates/set-version-number-variables-step.yml - - - template: templates/jobs/download_win_gpu_library.yml - parameters: - CudaVersion: 12.2 - DownloadCUDA: true - DownloadTRT: true - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)\final-jar' - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Jar Tools' - ArtifactName: onnxruntime-java-tools - TargetPath: '$(Build.BinariesDirectory)\final-jar' - - - task: CmdLine@2 - inputs: - script: | - mkdir test - pushd test - jar xf $(Build.BinariesDirectory)\final-jar\testing.jar - popd - java -DUSE_CUDA=1 -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime_gpu-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)\final-jar' - - - job: Final_Jar_Testing_Linux_GPU - workspace: - clean: all - pool: - name: 'Onnxruntime-Linux-GPU-A10' - os: linux - variables: - - name: runCodesignValidationInjection - value: false - - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - timeoutInMinutes: 60 - steps: - - checkout: self - submodules: false - - template: templates/set-version-number-variables-step.yml - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 - Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: " - --build-arg BUILD_UID=$( id -u ) - --build-arg BASEIMAGE=${{ variables.docker_base_image }} - --build-arg TRT_VERSION=${{ variables.linux_trt_version }} - " - Repository: onnxruntimeubi8packagestest - - - bash: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --gpus all \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - onnxruntimeubi8packagestest \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) - displayName: 'Test' - - template: nuget/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win2022-GPU-A10' @@ -274,13 +171,18 @@ stages: clean: true submodules: none - - template: templates/flex-downloadPipelineArtifact.yml + - download: build + artifact: 'Windows_Packaging_cuda_build_artifacts' + displayName: 'Download Windows GPU Packages Build' + + - template: templates/setup-build-tools.yml parameters: - ArtifactName: "Windows_Packaging_cuda_build_artifacts" - StepName: 'Download Pipeline Artifact - Windows GPU Packages Build' - TargetPath: '$(Build.BinariesDirectory)/RelWithDebInfo/' + host_cpu_arch: 'x64' - - template: templates/telemetry-steps.yml + - task: CmdLine@2 + inputs: + script: | + move $(Pipeline.Workspace)/build/Windows_Packaging_cuda_build_artifacts $(Build.BinariesDirectory)/RelWithDebInfo - template: templates/set-version-number-variables-step.yml @@ -290,17 +192,6 @@ stages: jdkArchitectureOption: x64 jdkSourceOption: 'PreInstalled' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - task: PythonScript@0 displayName: 'Update CTest Path References' inputs: @@ -309,10 +200,6 @@ stages: "$(Build.BinariesDirectory)/RelWithDebInfo/CTestTestfile.cmake" "$(Build.BinariesDirectory)/RelWithDebInfo" - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - template: templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: 12.2 @@ -323,14 +210,8 @@ stages: displayName: 'test' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' workingDirectory: '$(Build.BinariesDirectory)' - # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - - template: templates/make_java_win_binaries.yml - parameters: - msbuildPlatform: x64 - java_artifact_id: onnxruntime_gpu - buildOnly: false - stage: Windows_Packaging_Tensorrt_Testing dependsOn: Setup @@ -344,14 +225,19 @@ stages: - checkout: self clean: true submodules: none + + - download: build + artifact: 'Windows_Packaging_tensorrt_build_artifacts' + displayName: 'Download Windows GPU Packages Build' - - template: templates/flex-downloadPipelineArtifact.yml + - template: templates/setup-build-tools.yml parameters: - ArtifactName: "Windows_Packaging_tensorrt_build_artifacts" - StepName: 'Download Pipeline Artifact - Windows GPU Packages Build' - TargetPath: '$(Build.BinariesDirectory)/RelWithDebInfo/' + host_cpu_arch: 'x64' - - template: templates/telemetry-steps.yml + - task: CmdLine@2 + inputs: + script: | + move $(Pipeline.Workspace)/build/Windows_Packaging_tensorrt_build_artifacts $(Build.BinariesDirectory)/RelWithDebInfo - template: templates/set-version-number-variables-step.yml @@ -360,18 +246,7 @@ stages: versionSpec: "17" jdkArchitectureOption: x64 jdkSourceOption: 'PreInstalled' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - + - task: PythonScript@0 displayName: 'Update CTest Path References' inputs: @@ -380,10 +255,6 @@ stages: "$(Build.BinariesDirectory)/RelWithDebInfo/CTestTestfile.cmake" "$(Build.BinariesDirectory)/RelWithDebInfo" - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - template: templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: 12.2 @@ -394,11 +265,5 @@ stages: displayName: 'test' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' - workingDirectory: '$(Build.BinariesDirectory)' - # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - - template: templates/make_java_win_binaries.yml - parameters: - msbuildPlatform: x64 - java_artifact_id: onnxruntime_gpu - buildOnly: false + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' + workingDirectory: '$(Build.BinariesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 46695403fd854..95f55f52f9a68 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -123,11 +123,9 @@ extends: buildNodejs: false SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: stages/download-java-tools-stage.yml - template: stages/java-cuda-packaging-stage.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 0e807b7beb6e9..257f554dd200e 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,13 +6,28 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. type: boolean default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + - name: PackageName displayName: What is the package name? Override using an environment variable CustomPackageName. type: string @@ -69,8 +84,16 @@ extends: exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' stages: + - template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + - template: templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoCompliance: false DoEsrp: true @@ -103,7 +126,6 @@ extends: - template: templates/mac-cpu-packaging-pipeline.yml parameters: AllowReleasedOpsetOnly: 1 - BuildForAllArchs: true AdditionalBuildFlags: '--use_webgpu --skip_tests' DoEsrp: true @@ -157,7 +179,7 @@ extends: targetType: 'inline' script: | mkdir -p $(Build.BinariesDirectory)/osx-x64 - Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-x86_64* -Destination $(Build.BinariesDirectory)/osx-x64 + Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-x86_64* -Destination $(Build.BinariesDirectory)/osx-x64 mkdir -p $(Build.BinariesDirectory)/osx-arm64 Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-arm64* -Destination $(Build.BinariesDirectory)/osx-arm64 diff --git a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml new file mode 100644 index 0000000000000..19b40cb7c549a --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml @@ -0,0 +1,175 @@ +resources: + pipelines: + - pipeline: build + source: 'Zip-Nuget-Java-Nodejs Packaging Pipeline' + trigger: true + branch: main + +variables: + mavenVersion: '3.9.8' + +stages: +- template: templates/final-jar-testing-win.yml + parameters: + PoolName: 'onnxruntime-Win-CPU-2022' + +- template: templates/final-jar-testing-linux.yml + parameters: + OS: Linux + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + +- template: templates/final-jar-testing-linux.yml + parameters: + OS: MacOS + PoolName: 'macOS-14' + +- stage: GPU_JAR_Testing + dependsOn: [] + jobs: + - job: Final_Jar_Testing_Windows_GPU + workspace: + clean: all + pool: 'onnxruntime-Win2022-GPU-A10' + timeoutInMinutes: 60 + variables: + - name: runCodesignValidationInjection + value: false + + steps: + - template: templates/set-version-number-variables-step.yml + + - template: templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: 12.2 + DownloadCUDA: true + DownloadTRT: true + + - template: templates/setup-maven.yml + + - task: Maven@4 + displayName: 'Download Java Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + - script: | + move $(Pipeline.Workspace)\build\onnxruntime-java-gpu\*.jar $(Pipeline.Workspace)\build\onnxruntime-java\ + + - task: PowerShell@2 + displayName: 'Run Java Tests with PowerShell' + inputs: + targetType: 'inline' + script: | + # Exit script on any error + $ErrorActionPreference = "Stop" + + cd $(Pipeline.Workspace)/build/onnxruntime-java + del *.asc + del *.sha256 + del *.sha512 + del *.pom + del *.sha1 + del *.pom + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + dir $(Pipeline.Workspace)/build/tests + Write-Host "Running JUnit Tests..." + & java -DUSE_CUDA=1 ` + -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` + --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true + + + - job: Final_Jar_Testing_Linux_GPU + workspace: + clean: all + pool: + name: 'Onnxruntime-Linux-GPU-A10' + variables: + - name: runCodesignValidationInjection + value: false + - name: docker_base_image + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 + timeoutInMinutes: 60 + steps: + - checkout: self + submodules: false + + - template: templates/set-version-number-variables-step.yml + + - bash: | + sudo apt-get install -y msopenjdk-17 + dpkg -l msopenjdk-17 + + - bash: | + echo "Downloading and installing Maven $(mavenVersion) for Linux..." + MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + + # Extract to the temp directory + mkdir -p ${MAVEN_DIR} + tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]${MAVEN_DIR}/bin" + displayName: 'Install Maven (Linux)' + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + + # Rename the downloaded folder + - script: | + mv $(Pipeline.Workspace)/build/onnxruntime-java-gpu $(Pipeline.Workspace)/build/onnxruntime-java + + - task: Maven@4 + displayName: 'Download Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'Path' + jdkDirectory: '/usr/lib/jvm/msopenjdk-17-amd64' + jdkVersionOption: 'Default' + mavenVersionOption: 'Default' + + # Now all the jars are in the $(Pipeline.Workspace)/build folder + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ variables.docker_base_image }} --build-arg TRT_VERSION=${{ variables.linux_trt_version }}" + Repository: onnxruntimeubi8packagestest + + - bash: | + docker run --network=none --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Pipeline.Workspace)/build:/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) + displayName: 'Test' diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 526ed71df2006..ae2602c77d7a2 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml index 6c998f9c3da13..ae595bbf0c96b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml @@ -4,8 +4,13 @@ steps: displayName: 'Download NPM_packages' artifact: 'NPM_packages' -- script: | - mv $(Pipeline.Workspace)/build/NPM_packages '$(Build.BinariesDirectory)/nodejs-artifact' + +- task: PowerShell@2 + displayName: 'Move Artifact Directory' + inputs: + targetType: 'inline' + script: | + Move-Item -Path "$(Pipeline.Workspace)/build/NPM_packages" -Destination "$(Build.BinariesDirectory)/nodejs-artifact" - script: mkdir e2e_test workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index bb4f600395ac9..4dd19ce2c250c 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -11,9 +11,7 @@ stages: clean: all timeoutInMinutes: 120 pool: - name: 'Azure Pipelines' - image: 'macOS-15' - os: 'macOS' + vmImage: 'macOS-15' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index 3615f9f7c0960..7fdddfa0d03f5 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -50,6 +50,11 @@ extends: sourceAnalysisPool: name: onnxruntime-Win-CPU-2022 os: windows + codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL causes the React Native Android tests to fail when trying to load Linux x64 .so' + stages: - template: templates/web-ci.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 757b8ac6e9a16..574302bb11fe3 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -49,21 +49,10 @@ stages: clean: true submodules: none - - template: ../../templates/telemetry-steps.yml - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + + - template: ../../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly - task: UseDotNet@2 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml index 9c5c796d2983d..dcfad3c7721a0 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml @@ -27,8 +27,6 @@ stages: variables: - name: OnnxRuntimeBuildDirectory value: '$(Build.BinariesDirectory)' - - name: SKIPNONPACKAGETESTS - value: 'ON' - name: runCodesignValidationInjection value: false - name: CUDA_MODULE_LOADINGL @@ -69,6 +67,11 @@ stages: - download: build displayName: 'Download Nuget' artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' + + - download: build + displayName: 'Download CustomOp DLL' + artifact: 'onnxruntime-win-x64' + patterns: "testdata/custom_op_library.dll" - template: get-nuget-package-version-as-variable.yml @@ -76,7 +79,10 @@ stages: packageFolder: '$(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }}' - script: | - mklink /D /J models C:\local\models + mklink /D /J models C:\local\models + mkdir $(Build.BinariesDirectory)\Debug\Debug + dir $(Pipeline.Workspace)\build\onnxruntime-win-x64 + copy $(Pipeline.Workspace)\build\onnxruntime-win-x64\testdata\custom_op_library.dll $(Build.BinariesDirectory)\Debug\Debug workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Create models link' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 379b20ce8a0c4..39d3577c3524f 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -25,14 +25,56 @@ stages: - stage: Packages_Somking_Test dependsOn: [] jobs: - - template: templates/py-package-smoking-test.yml + - template: templates/py-package-smoking-test-macos.yml parameters: - job_name: Test_MAC_Wheels - machine_pool: - vmImage: 'macOS-14' - itemPattern: '*/*mac*x86_64.whl' - arch: 'x86_64' - - template: templates/py-package-smoking-test.yml + python_version: 3.10 + os_version: 14 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.11 + os_version: 14 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.12 + os_version: 14 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.13 + os_version: 14 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.10 + os_version: 13 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.11 + os_version: 13 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.12 + os_version: 13 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.13 + os_version: 13 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.10 + os_version: 15 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.11 + os_version: 15 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.12 + os_version: 15 + - template: templates/py-package-smoking-test-macos.yml + parameters: + python_version: 3.13 + os_version: 15 + + - template: templates/py-package-smoking-test-linux.yml parameters: job_name: Test_LINUX_x86_64_Wheels itemPattern: '*/*manylinux*x86_64.whl' diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index b4edf78c3b7bd..02aead3b3d3c7 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.1.250708 + default: 2.37.1.250807 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 7af5334793c30..02fae0b10ac39 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: build_config displayName: Build Configuration @@ -14,6 +14,22 @@ parameters: type: boolean default: false + +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + - name: DoEsrp displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. type: boolean @@ -68,6 +84,11 @@ extends: enabled: true exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' stages: + - template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: templates/qnn-ep-win.yml parameters: @@ -77,4 +98,4 @@ extends: DoEsrp: ${{ parameters.DoEsrp }} ArtifactName: 'drop-nuget-qnn-arm64x' StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64x' - build_config: ${{ parameters.build_config }} \ No newline at end of file + build_config: ${{ parameters.build_config }} diff --git a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml deleted file mode 100644 index 949d29d27da9d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml +++ /dev/null @@ -1,26 +0,0 @@ -stages: -- stage: Download_Java_Tools - dependsOn: [] - jobs: - - job: Download_Java_Tools - pool: - name: 'onnxruntime-Ubuntu2404-AMD-CPU' - os: linux - steps: - - checkout: none - - task: CmdLine@2 - displayName: Download Java Tools - inputs: - script: | - mkdir -p java-tools - pushd java-tools - wget --tries=3 https://oss.sonatype.org/service/local/repositories/releases/content/org/junit/platform/junit-platform-console-standalone/1.6.2/junit-platform-console-standalone-1.6.2.jar -P ./ - wget --tries=3 https://oss.sonatype.org/service/local/repositories/releases/content/com/google/protobuf/protobuf-java/3.25.5/protobuf-java-3.25.5.jar -P ./ - popd - workingDirectory: '$(Agent.TempDirectory)' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Java Tools Artifact' - inputs: - targetPath: '$(Agent.TempDirectory)/java-tools' - artifact: 'onnxruntime-java-tools' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 63aaf328e1426..a58d74bf80a86 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -1,80 +1,31 @@ -parameters: -- name: CudaVersion - type: string -- name: SpecificArtifact - type: string -- name: BuildId - type: string - stages: - stage: Jar_Packaging_GPU dependsOn: - Linux_C_API_Packaging_GPU - Windows_Packaging_CUDA - Windows_Packaging_TensorRT - - Download_Java_Tools jobs: - job: Jar_Packaging_GPU workspace: clean: all + templateContext: + inputs: + - input: pipelineArtifact + artifactName: drop-onnxruntime-java-win-x64-tensorrt + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + + - input: pipelineArtifact + artifactName: drop-onnxruntime-java-linux-x64-tensorrt + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' + + outputs: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 + artifactName: onnxruntime-java-gpu pool: 'onnxruntime-Win-CPU-2022' dependsOn: [] condition: succeeded() steps: - - checkout: self - submodules: false - - template: ../templates/set-version-number-variables-step.yml - - - template: ../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Win x64' - ArtifactName: 'drop-onnxruntime-java-win-x64-tensorrt' - TargetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: ../templates/flex-downloadPipelineArtifact.yml - parameters: - stepName: 'Download Pipeline Artifact - Linux x64' - artifactName: 'drop-onnxruntime-java-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: ../templates/flex-downloadPipelineArtifact.yml + - template: ../templates/jar-packaging.yml parameters: - StepName: 'Download Pipeline Artifact - Linux x64' - ArtifactName: 'drop-onnxruntime-java-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64-tensorrt' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_gpu_packaging.ps1 - failOnStderr: true - showWarnings: true - workingDirectory: '$(Build.BinariesDirectory)\java-artifact' - - - template: ../templates/jar-esrp-dll.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - JarFileName: 'onnxruntime_gpu-$(OnnxRuntimeVersion).jar' - - - template: ../templates/jar-maven-signing-win.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - - - task: CopyFiles@2 - displayName: 'Copy Java Files to Artifact Staging Directory' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact' - inputs: - path: '$(Build.ArtifactStagingDirectory)' - artifact: 'onnxruntime-java-gpu' + package_type: gpu \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml index 76eb5f150ad44..3187a7fb759c2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -71,23 +71,9 @@ stages: clean: true submodules: none - - template: ../templates/telemetry-steps.yml - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: ${{ parameters.BuildArch }} - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: ${{ parameters.BuildArch }} # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly - task: UseDotNet@2 diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index e33d3dbf9e107..168432283fa51 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -32,6 +32,19 @@ parameters: - name: BuildId type: string +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + stages: - template: nuget-linux-cuda-packaging-stage.yml parameters: @@ -52,6 +65,8 @@ stages: win_trt_home: ${{ parameters.win_trt_home }} win_cuda_home: ${{ parameters.win_cuda_home }} buildJava: ${{ parameters.buildJava }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: nuget-cuda-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 4175a339535e4..121e80fca1021 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -28,6 +28,10 @@ stages: value: ${{ parameters.CudaVersion }} steps: - template: ../templates/set-version-number-variables-step.yml + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile @@ -45,10 +49,8 @@ stages: arch: 'linux-x64' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-x64-cuda' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: @@ -85,6 +87,10 @@ stages: - checkout: self clean: true submodules: recursive + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile @@ -106,10 +112,8 @@ stages: arch: 'linux-x64' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-x64-tensorrt' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index ed6c4c799c26d..61afeba2d302b 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -34,10 +34,25 @@ parameters: - name: buildJava type: boolean +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + stages: # Windows CUDA without TensorRT Packaging - template: ../templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: CUDA @@ -45,16 +60,19 @@ stages: msbuildPlatform: x64 packageName: x64-cuda CudaVersion: ${{ parameters.CudaVersion }} - buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual" + buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90a-virtual" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + # Windows CUDA with TensorRT Packaging - template: ../templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: TensorRT @@ -62,7 +80,7 @@ stages: msbuildPlatform: x64 CudaVersion: ${{ parameters.CudaVersion }} packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual" + buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90a-virtual" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index c1b83c5e579dc..999ace8d3e345 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.1.250708 + default: 2.37.1.250807 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: @@ -90,20 +90,12 @@ stages: matrix: Python310_x64: PythonVersion: '3.10' - MsbuildPlatform: x64 - buildArch: x64 Python311_x64: PythonVersion: '3.11' - MsbuildPlatform: x64 - buildArch: x64 Python312_x64: PythonVersion: '3.12' - MsbuildPlatform: x64 - buildArch: x64 Python313_x64: PythonVersion: '3.13' - MsbuildPlatform: x64 - buildArch: x64 variables: OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' ExtraParam: ${{ parameters.build_py_parameters }} @@ -116,17 +108,10 @@ stages: clean: true submodules: recursive - - template: ../templates/telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: $(buildArch) - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: $(PythonVersion) - template: ../templates/set-nightly-build-option-variable-step.yml @@ -208,80 +193,61 @@ stages: - stage: Python_Packaging_MacOS dependsOn: [] jobs: - - job: MacOS_py_Wheels - timeoutInMinutes: 360 - workspace: - clean: all - pool: - name: "Azure Pipelines" - image: "macOS-14" - os: macOS - templateContext: - outputs: - - output: pipelineArtifact - targetPath: $(Build.SourcesDirectory)/build/Release/dist/fixed_wheels - artifactName: onnxruntime-macos-$(PythonVersion) - variables: - MACOSX_DEPLOYMENT_TARGET: '13.3' - strategy: - matrix: - Python310: - PythonVersion: '3.10' - Python311: - PythonVersion: '3.11' - Python312: - PythonVersion: '3.12' - Python313: - PythonVersion: '3.13' - steps: - - checkout: self - clean: true - submodules: recursive - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: $(PythonVersion) - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - template: ../templates/use-xcode-version.yml + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.10' + + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.11' - - script: | - set -e -x - export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-universal2 - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - # Note: There is a build error when we set CMAKE_OSX_ARCHITECTURES="arm64;x86_64" and KleidiAI is enabled. - # Disable KleidiAI as a workaround with --no_kleidiai. - # TODO Re-enable KleidiAI once https://github.com/microsoft/onnxruntime/issues/24152 is fixed. - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ - --build_dir $(Build.SourcesDirectory)/build \ - --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ - --use_binskim_compliant_compile_flags \ - --config Release \ - --build_wheel \ - --use_coreml \ - --no_kleidiai \ - ${{ parameters.build_py_parameters }} \ - --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" \ - --update --skip_submodule_sync --build --parallel - displayName: 'Command Line Script' + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.12' + + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.13' - - script: | - set -ex - python -m pip install --upgrade delocate - cd '$(Build.SourcesDirectory)/build/Release/dist' - ls - for file in *.whl - do - delocate-listdeps "$file" - delocate-wheel --require-archs=x86_64,arm64 -w fixed_wheels -v "$file" - done - displayName: 'delocate wheel' + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.10' + + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.11' + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.12' + + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.13' - ${{ if eq(parameters.enable_linux_arm, true) }}: - stage: Python_Packaging_Linux_ARM @@ -316,7 +282,21 @@ stages: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true + PYTHON_VERSION: '3.11' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.12' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.13' - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - stage: Python_Packaging_Windows_arm64ec_QNN @@ -327,7 +307,6 @@ stages: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - stage: Python_Packaging_Windows_x64_QNN diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 9c063f561eefc..e2683c04f21f2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -3,6 +3,7 @@ parameters: type: string default: 'onnxruntime-Win2022-GPU-A10' +# Package name suffix - name: EP_NAME type: string @@ -92,18 +93,10 @@ stages: clean: true submodules: none - - template: ../templates/telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: ${{ parameters.PYTHON_VERSION }} - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: ${{ parameters.PYTHON_VERSION }} - template: ../templates/jobs/download_win_gpu_library.yml parameters: @@ -220,9 +213,10 @@ stages: TMPDIR: "$(Agent.TempDirectory)" - powershell: | - - python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq - Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} + $ErrorActionPreference = "Stop" + python -m pip uninstall -y onnxruntime onnxruntime-${{ parameters.EP_NAME }} -qq + dir $(Build.ArtifactStagingDirectory) + python -m pip --disable-pip-version-check install --no-index --find-links $(Build.ArtifactStagingDirectory) onnxruntime-${{ parameters.EP_NAME }} mkdir -p $(Agent.TempDirectory)\ort_test_data Copy-Item -Path $(Build.sourcesDirectory)/onnxruntime/test/python/onnx_backend_test_series.py -Destination $(Agent.TempDirectory)\ort_test_data Copy-Item -Recurse -Path $(Build.sourcesDirectory)/onnxruntime/test/testdata -Destination $(Agent.TempDirectory)\ort_test_data diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 6e6fb98e6e68c..be61f652f7fc5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.37.1.250807' - name: enableWebGpu displayName: Enable WebGPU test @@ -58,7 +58,7 @@ jobs: mkdir -p android_test/android/app/libs cd android_test/android cp -av $(Build.SourcesDirectory)/java/src/test/android/* ./ - cp $(Pipeline.Workspace)/build/onnxruntime-android-full-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar + cp $(Pipeline.Workspace)/build/${{parameters.artifactName}}/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar displayName: Copy Android test files and AAR to android_test directory workingDirectory: $(Build.BinariesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index e4bfe20238770..1f402160dc4d5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.37.1.250807' - name: is1ES displayName: Is 1ES pipeline @@ -74,10 +74,11 @@ jobs: ${{ if contains(parameters.pool_name, 'mac')}}: os: macOS - - variables: - artifacts_directory: $(Build.BinariesDirectory)/.artifacts - + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)/.artifacts + artifactName: ${{parameters.artifactName}} steps: - checkout: self clean: true @@ -88,7 +89,7 @@ jobs: inputs: script: | # Create a folder for artifacts - mkdir -p $(artifacts_directory) + mkdir -p $(Build.BinariesDirectory)/.artifacts workingDirectory: $(Build.BinariesDirectory) - template: get-docker-image-steps.yml @@ -131,7 +132,7 @@ jobs: --volume $(Build.BinariesDirectory):/build \ --volume $ANDROID_HOME:/android_home \ --volume $NDK_HOME:/ndk_home \ - --volume $(artifacts_directory):/home/onnxruntimedev/.artifacts \ + --volume $(Build.BinariesDirectory)/.artifacts:/home/onnxruntimedev/.artifacts \ --volume $(Build.BinariesDirectory)/.build_settings:/home/onnxruntimedev/.build_settings \ $QNN_VOLUME \ -e NIGHTLY_BUILD \ @@ -145,18 +146,6 @@ jobs: /bin/bash /onnxruntime_src/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh $USE_QNN workingDirectory: $(Build.SourcesDirectory) - - - ${{ if eq(parameters['enable_code_sign'], 'true') }}: - - template: jar-maven-signing-linux.yml - parameters: - JarFileDirectory: '$(artifacts_directory)' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 - inputs: - targetPath: '$(artifacts_directory)' - artifactName: '${{parameters.artifactName}}' - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(artifacts_directory)' - artifactName: '${{parameters.artifactName}}' + - template: jar-maven-signing-linux.yml + parameters: + JarFileDirectory: $(Build.BinariesDirectory)/.artifacts \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index bf65b0c54cf27..9509a40cda4e9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -9,6 +9,19 @@ parameters: type: boolean default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + - name: AdditionalBuildFlags displayName: Additional build flags for build.py type: string @@ -28,22 +41,12 @@ parameters: type: string default: 'default' -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - # Do not update this to a version that does not exist for the qnn-runtime Maven package: # https://mvnrepository.com/artifact/com.qualcomm.qti/qnn-runtime - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: is1ES displayName: Is 1ES pipeline @@ -58,9 +61,6 @@ stages: - template: mac-cpu-packaging-pipeline.yml parameters: AllowReleasedOpsetOnly: 1 - BuildForAllArchs: true - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} DoEsrp: true - stage: Android_Java_API_AAR_Packaging_Full @@ -108,14 +108,26 @@ stages: clean: all pool: name: 'Azure Pipelines' - image: 'macOS-14' + image: 'macOS-15' os: 'macOS' timeoutInMinutes: 300 steps: - template: set-version-number-variables-step.yml + - task: JavaToolInstaller@0 + inputs: + versionSpec: "17" + jdkArchitectureOption: "x64" + jdkSourceOption: 'PreInstalled' + - template: use-xcode-version.yml + parameters: + xcodeVersion: 16.4 + + - template: setup-build-tools.yml + parameters: + host_cpu_arch: arm64 - script: | set -e -x @@ -155,6 +167,8 @@ stages: runTests: false buildJava: false buildNodejs: false + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: win-ci.yml parameters: @@ -167,13 +181,14 @@ stages: runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true buildNodejs: false + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - stage: Jar_Packaging dependsOn: - Linux_C_API_Packaging_CPU - - MacOS_C_API_Package_Publish + - MacOS_C_API_Packaging_CPU - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Download_Java_Tools condition: succeeded() jobs: - job: Jar_Packaging @@ -204,38 +219,13 @@ stages: targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-osx-arm64' outputs: - output: pipelineArtifact - targetPath: $(Build.ArtifactStagingDirectory) + targetPath: $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 artifactName: onnxruntime-java steps: - - checkout: self - submodules: false - - template: set-version-number-variables-step.yml - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_packaging.ps1 - failOnStderr: true - showWarnings: true - workingDirectory: '$(Build.BinariesDirectory)\java-artifact' - - - template: jar-esrp-dll.yml + - template: jar-packaging.yml parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - JarFileName: 'onnxruntime-$(OnnxRuntimeVersion).jar' - - - template: jar-maven-signing-win.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - - - task: CopyFiles@2 - displayName: 'Copy Java Files to Artifact Staging Directory' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - + package_type: cpu - stage: NuGet_Packaging_CPU dependsOn: @@ -262,6 +252,28 @@ stages: binskim: enabled: true scanOutputDirectoryOnly: true + inputs: + - input: pipelineArtifact + artifactName: onnxruntime-win-x64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-win-arm64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-osx + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-linux-x64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-linux-aarch64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-ios-full-xcframework + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-android-full-aar + targetPath: $(Build.BinariesDirectory)/nuget-artifact outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) @@ -277,62 +289,6 @@ stages: - checkout: self submodules: true - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Win x64' - ArtifactName: 'onnxruntime-win-x64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download win-arm64 Pipeline Artifact' - ArtifactName: 'onnxruntime-win-arm64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download osx-x64 Pipeline Artifact' - ArtifactName: 'onnxruntime-osx' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download linux-x64 Pipeline Artifact' - ArtifactName: 'onnxruntime-linux-x64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download linux-aarch64 Pipeline Artifact' - ArtifactName: 'onnxruntime-linux-aarch64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download iOS Pipeline Artifact' - ArtifactName: 'onnxruntime-ios-full-xcframework' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Android-full-aar Pipeline Artifact' - ArtifactName: 'onnxruntime-android-full-aar' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - script: | dir workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' @@ -443,7 +399,7 @@ stages: - Windows_Nodejs_Packaging_arm64 - Linux_Nodejs_Packaging_x64 - Linux_C_API_Packaging_CPU - - MacOS_C_API_Package_Publish + - MacOS_C_API_Packaging_CPU condition: succeeded() jobs: - job: Nodejs_Packaging diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index aa1e38f8b0159..f1599b6843fb5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -45,6 +45,14 @@ jobs: - checkout: self clean: true submodules: none + + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 + ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: + architecture: arm64 + - template: set-version-number-variables-step.yml - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: - template: get-docker-image-steps.yml @@ -82,10 +90,8 @@ jobs: arch: 'linux-${{parameters.OnnxruntimeArch}}' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-${{parameters.OnnxruntimeArch}}' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: c-api-artifacts-package-and-publish-steps-posix.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml new file mode 100644 index 0000000000000..7d4cc9550ce54 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml @@ -0,0 +1,18 @@ +steps: +- pwsh: | + echo "Downloading and installing Maven $(mavenVersion) for Windows..." + $MAVEN_DIR = "$(Agent.TempDirectory)\apache-maven-$(mavenVersion)" + # Download Maven binary + Invoke-WebRequest -Uri "https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.zip" -OutFile "$(Agent.TempDirectory)\maven.zip" + + # Extract to the temp directory + Expand-Archive -Path "$(Agent.TempDirectory)\maven.zip" -DestinationPath "$(Agent.TempDirectory)" + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]$MAVEN_DIR\bin" + + +- script: | + echo "Maven is now on the PATH." + mvn --version + diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml new file mode 100644 index 0000000000000..5a25232a90c39 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml @@ -0,0 +1,109 @@ +# Run Java tests on CPU machines with the CPU java package + +parameters: +- name: OS + displayName: Operating System + type: string + +- name: PoolName + type: string + +stages: +- stage: Final_Jar_Testing_${{parameters.OS}} + dependsOn: [] + jobs: + - job: Final_Jar_Testing_${{parameters.OS}} + workspace: + clean: all + ${{ if eq(parameters.OS, 'MacOS') }}: + pool: + vmImage: 'macOS-15' + ${{ if eq(parameters.OS, 'Linux') }}: + pool: + name: ${{ parameters.PoolName }} + variables: + - name: runCodesignValidationInjection + value: false + timeoutInMinutes: 60 + steps: + - template: set-version-number-variables-step.yml + + - bash: | + echo "Downloading and installing Maven $(mavenVersion) for Linux..." + MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + + # Extract to the temp directory + mkdir -p ${MAVEN_DIR} + tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]${MAVEN_DIR}/bin" + displayName: 'Install Maven (Linux)' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux')) + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java' + displayName: 'Download Final Jar' + + - task: Maven@4 + displayName: 'Download Dependencies into App Folder' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + + - task: Bash@3 + displayName: 'Run Java Tests on Linux/macOS' + condition: and(succeeded(), in(variables['Agent.OS'], 'Linux', 'Darwin')) + inputs: + targetType: 'inline' + script: | + set -e -x + cd $(Pipeline.Workspace)/build/onnxruntime-java + rm -f *.asc + rm -f *.sha256 + rm -f *.sha512 + rm -f *.sha1 + rm -f *.md5 + rm -f *.pom + ls + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + rm -f $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + ls $(Pipeline.Workspace)/build/tests + echo "Java Version" + java -version + + # Set the correct library path based on the OS + os_name=$(uname) + if [[ "$os_name" == "Linux" ]]; then + echo "Platform: Linux. Setting LD_LIBRARY_PATH." + export LD_LIBRARY_PATH="$(pwd):$LD_LIBRARY_PATH" + java -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + elif [[ "$os_name" == "Darwin" ]]; then + echo "Platform: macOS. Setting DYLD_LIBRARY_PATH." + export DYLD_LIBRARY_PATH="$(pwd):$DYLD_LIBRARY_PATH" + java -DUSE_WEBGPU=1 -DUSE_COREML=1 -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + fi + + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml new file mode 100644 index 0000000000000..de07e9e89dc81 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml @@ -0,0 +1,84 @@ +parameters: +- name: PoolName + type: string + +stages: +- stage: Final_Jar_Testing_Windows + dependsOn: [] + jobs: + - job: Final_Jar_Testing_Windows + workspace: + clean: all + pool: + name: ${{ parameters.PoolName }} + variables: + - name: runCodesignValidationInjection + value: false + timeoutInMinutes: 60 + steps: + - template: set-version-number-variables-step.yml + + - pwsh: | + echo "Downloading and installing Maven $(mavenVersion) for Windows..." + $MAVEN_DIR = "$(Agent.TempDirectory)\apache-maven-$(mavenVersion)" + # Download Maven binary + Invoke-WebRequest -Uri "https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.zip" -OutFile "$(Agent.TempDirectory)\maven.zip" + + # Extract to the temp directory + Expand-Archive -Path "$(Agent.TempDirectory)\maven.zip" -DestinationPath "$(Agent.TempDirectory)" + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]$MAVEN_DIR\bin" + displayName: 'Install Maven (Windows)' + + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java' + displayName: 'Download Final Jar' + + - task: Maven@4 + displayName: 'Download Dependencies into App Folder' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + + - task: PowerShell@2 + displayName: 'Run Java Tests on Windows' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) + inputs: + targetType: 'inline' + script: | + $ErrorActionPreference = "Stop" + cd $(Pipeline.Workspace)/build/onnxruntime-java + del *.asc + del *.sha256 + del *.sha512 + del *.md5 + del *.sha1 + del *.pom + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + dir $(Pipeline.Workspace)/build/tests + Write-Host "Running JUnit Tests..." + & java ` + -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` + --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml deleted file mode 100644 index bc40de130740a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml +++ /dev/null @@ -1,80 +0,0 @@ -parameters: -- name: OS - displayName: Opserating System - type: string - -- name: PoolName - type: string - -stages: -- stage: Final_Jar_Testing_${{parameters.OS}} - jobs: - - job: Final_Jar_Testing_${{parameters.OS}} - workspace: - clean: all - ${{ if eq(parameters.OS, 'MacOS') }}: - pool: - name: 'Azure Pipelines' - image: macOS-14 - os: macOS - ${{ if eq(parameters.OS, 'Linux') }}: - pool: - name: ${{ parameters.PoolName }} - os: linux - ${{ if eq(parameters.OS, 'Windows') }}: - pool: - name: ${{ parameters.PoolName }} - os: windows - variables: - - name: runCodesignValidationInjection - value: false - timeoutInMinutes: 60 - - steps: - - template: set-version-number-variables-step.yml - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Jar Tools' - ArtifactName: onnxruntime-java-tools - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - ${{ if eq(parameters.OS, 'Windows') }}: - - task: CmdLine@2 - inputs: - script: | - mkdir test - pushd test - jar xf $(Build.BinariesDirectory)\final-jar\testing.jar - popd - java -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)\final-jar' - - ${{ else }}: - - task: Bash@3 - inputs: - targetType: 'inline' - script: | - set -e -x - echo "Java Version" - java -version - mkdir test - pushd test - jar xf '$(Build.BinariesDirectory)/final-jar/testing.jar' - popd - # if you want to run the tests in the power shell, you need to replace ':' to ';', that is, "-cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime-$(OnnxRuntimeVersion).jar" - java -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.25.5.jar:./onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)/final-jar' - env: - ${{ if eq(parameters.OS, 'MacOS') }}: - DYLD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(DYLD_LIBRARY_PATH)' - ${{ if eq(parameters.OS, 'Linux') }}: - LD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(LD_LIBRARY_PATH)' - - - ${{ if eq(parameters['OS'], 'MacOS') }}: - - template: use-xcode-version.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml index b59ba551c222f..dd0e0898ecc3b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml @@ -3,28 +3,25 @@ parameters: type: string default: '' -- name: JarFileName - type: string - default: '' - steps: - - task: PowerShell@2 - displayName: 'ESRP Jar - Extract Jar File' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 - arguments: extract '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' - workingDirectory: '$(Build.BinariesDirectory)' +- task: PowerShell@2 + displayName: 'ESRP Jar - Extract Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: extract '${{ parameters.JarFileDirectory }}' + workingDirectory: '$(Build.BinariesDirectory)' - - template: win-esrp-dll.yml - parameters: - FolderPath: '${{ parameters.JarFileDirectory }}\jar_extracted_full_files' - DisplayName: 'ESRP Jar - Sign Dlls' +- template: win-esrp-dll.yml + parameters: + FolderPath: '${{ parameters.JarFileDirectory }}\jar_extracted_full_files' + DisplayName: 'ESRP Jar - Sign Dlls' + DoEsrp: true # Assuming ESRP should always run when this template is called - - task: PowerShell@2 - displayName: 'ESRP Jar - Repack Jar File' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 - arguments: repack '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' - workingDirectory: '$(Build.BinariesDirectory)' +- task: PowerShell@2 + displayName: 'ESRP Jar - Repack Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: repack '${{ parameters.JarFileDirectory }}' + workingDirectory: '$(Build.BinariesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml index df2aff0634819..98a52b08f32f2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -4,54 +4,25 @@ parameters: steps: - task: AzureKeyVault@2 - displayName: 'Get GnuPG signing keys' + displayName: "Get GnuPG signing keys" inputs: #The value below is the name of an ADO service connection. - azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' - KeyVaultName: 'ort-release' - SecretsFilter: 'java-pgp-pwd,java-pgp-key' + azureSubscription: "AIInfraBuildOnnxRuntimeOSS" + KeyVaultName: "ort-release" + SecretsFilter: "java-pgp-pwd,java-pgp-key" RunAsPreJob: false - - task: CmdLine@2 - displayName: 'Sign jar files: GnuPG and sha256' + - task: UsePythonVersion@0 + displayName: "Use Python 3.12" inputs: - workingDirectory: '$(Build.SourcesDirectory)' - script: | - #!/bin/bash - set -e + versionSpec: "3.12" - jar_file_directory='${{ parameters.JarFileDirectory }}' - working_directory='$(Build.SourcesDirectory)' - original_private_key='$(java-pgp-key)' - original_passphrase='$(java-pgp-pwd)' - - private_key_file=$working_directory/private_key.txt - passphrase_file=$working_directory/passphrase.txt - - echo "Generating GnuPG key files." - printf "%s" "$original_private_key" >$private_key_file - printf "%s" "$original_passphrase" >$passphrase_file - echo "Generated GnuPG key files." - - echo "Importing GnuPG private key file." - gpg --batch --import $private_key_file - echo "Imported GnuPG private key file." - - for file in $(find $jar_file_directory -type f); do - echo "GnuPG signing to file: $file" - gpg --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file - echo "GnuPG signed to file: $file" - done - - for file in $(find $jar_file_directory -type f); do - echo "Adding checksum of sha256 to file: $file" - sha256_value=$(sha256sum $file | awk '{print $1}') - echo $sha256_value" *"$(basename "$file") >$file.sha256 - echo "Added checksum of sha256 to file: $file" - done - - echo "GnuPG and sha256 signing to files completed." - echo "Deleting GnuPG key files." - rm -f $private_key_file - rm -f $passphrase_file - echo "Deleted GnuPG key files." + - task: PythonScript@0 + displayName: "Sign files: GnuPG, sha1, and md5" + env: + JAVA_PGP_PWD: $(java-pgp-pwd) + JAVA_PGP_KEY: $(java-pgp-key) + inputs: + scriptPath: "$(Build.SourcesDirectory)/tools/ci_build/github/windows/sign_java_artifacts.py" + arguments: "${{ parameters.JarFileDirectory }}" + workingDirectory: "$(Build.SourcesDirectory)" \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml deleted file mode 100644 index ef845dc3bf243..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml +++ /dev/null @@ -1,78 +0,0 @@ -parameters: - - name: JarFileDirectory - type: string - -steps: - - task: AzureKeyVault@2 - displayName: 'Get GnuPG signing keys' - inputs: - azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' - KeyVaultName: 'ort-release' - SecretsFilter: 'java-pgp-pwd,java-pgp-key' - RunAsPreJob: false - - - task: PowerShell@2 - displayName: 'Sign jar files: GnuPG and sha256' - inputs: - targetType: 'inline' - pwsh: true - workingDirectory: '$(Build.SourcesDirectory)' - script: | - $jar_file_directory = '${{ parameters.JarFileDirectory }}' - $working_directory = '$(Build.SourcesDirectory)' - - $original_passphrase='$(java-pgp-pwd)' - $original_private_key='$(java-pgp-key)' - - $gpg_exe_path = "C:\Program Files (x86)\gnupg\bin\gpg.exe" - - $passphrase_file = Join-Path -Path $working_directory -ChildPath "passphrase.txt" - $private_key_file = Join-Path -Path $working_directory -ChildPath "private_key.txt" - - Write-Host "Generating GnuPG key files." - Out-File -FilePath $passphrase_file -InputObject $original_passphrase -NoNewline -Encoding ascii - Out-File -FilePath $private_key_file -InputObject $original_private_key -NoNewline -Encoding ascii - Write-Host "Generated GnuPG key files." - - Write-Host "Importing GnuPG private key file." - & $gpg_exe_path --batch --import $private_key_file - if ($lastExitCode -ne 0) { - Write-Host -Object "GnuPG importing private key command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "Imported GnuPG private key file." - - $targeting_original_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name - foreach ($file in $targeting_original_files) { - $file_path = Join-Path $jar_file_directory -ChildPath $file - Write-Host "GnuPG signing to file: "$file_path - & $gpg_exe_path --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file_path - if ($lastExitCode -ne 0) { - Write-Host -Object "GnuPG signing file command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "GnuPG signed to file: "$file_path - } - - $PSDefaultParameterValues['Out-File:Encoding'] = 'utf8NoBOM' - $sha256sum_exe_path = "C:\Program Files\Git\usr\bin\sha256sum.exe" - $targeting_asc_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name - $original_location = Get-Location - Set-Location $jar_file_directory - foreach ($file in $targeting_asc_files) { - Write-Host "Adding checksum of sha256 to file: "$file - $file_path_sha256 = $file + ".sha256" - & $sha256sum_exe_path $file 1>$file_path_sha256 - if ($lastExitCode -ne 0) { - Write-Host -Object "sha256sum command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "Added checksum of sha256 to file: "$file - } - Set-Location $original_location - - Write-Host "GnuPG and sha256 signing to files completed." - Write-Host "Deleting GnuPG key files." - Remove-Item -Path $passphrase_file - Remove-Item -Path $private_key_file - Write-Host "Deleted GnuPG key files." diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml b/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml new file mode 100644 index 0000000000000..098d7e3162d1f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml @@ -0,0 +1,61 @@ +# This template packages the Java artifacts for either CPU or GPU. +# It calls the PowerShell script with the correct package type and ensures +# that the correct final JAR file is signed and published. +# Currently this file only runs on Windows x64. + +parameters: + - name: package_type + type: string + default: 'cpu' + values: + - 'cpu' + - 'gpu' + +steps: +- checkout: self + submodules: false + +- task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + addToPath: true + +- task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + +- template: set-version-number-variables-step.yml + +- script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt + +- task: PythonScript@0 + displayName: 'Package Java Artifacts' + inputs: + scriptPath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_packaging.py + arguments: '--package_type ${{ parameters.package_type }} --build_dir $(Build.BinariesDirectory)' + workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + +- script: dir $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 + +- template: jar-esrp-dll.yml + parameters: + JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + +- task: AzureKeyVault@2 + displayName: 'Get GnuPG signing keys' + inputs: + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' + KeyVaultName: 'ort-release' + SecretsFilter: 'java-pgp-pwd,java-pgp-key' + RunAsPreJob: false + +- task: PythonScript@0 + displayName: 'Sign files: GnuPG, sha1, and md5' + env: + JAVA_PGP_PWD: $(java-pgp-pwd) + JAVA_PGP_KEY: $(java-pgp-key) + inputs: + scriptPath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/sign_java_artifacts.py' + arguments: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + workingDirectory: '$(Build.SourcesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml index 1c4b0ae5f4137..166b03f6b55e1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml +++ b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml @@ -1,28 +1,50 @@ # sets up common build tools for the windows build machines before build parameters: - arch: 'linux-x64' - buildConfig: 'RelWithDebInfo' - artifactName: 'onnxruntime-java-linux-x64' - libraryName: 'libonnxruntime.so' - nativeLibraryName: 'libonnxruntime4j_jni.so' - version: '' - is1ES: false +- name: buildConfig + displayName: Build Configuration + type: string + values: + - 'Release' + - 'Debug' + - 'RelWithDebInfo' + +- name: artifactName + displayName: Artifact Name + type: string + #default: 'onnxruntime-java' + +- name: libraryName + displayName: Main Library Name + type: string + #default: 'libonnxruntime.so' + +- name: nativeLibraryName + displayName: JNI Library Name + type: string + #default: 'libonnxruntime4j_jni.so' + +- name: arch + displayName: Architecture + type: string + #default: 'linux-x64' + steps: -- task: ShellScript@2 - displayName: 'Copy build artifacts for zipping' +- task: PythonScript@0 + inputs: + scriptSource: 'filePath' + scriptPath: 'tools/ci_build/linux_java_copy_strip_binary.py' + arguments: >- + --binary-dir $(Build.BinariesDirectory) + --build-config ${{parameters.buildConfig}} + --artifact-name ${{parameters.artifactName}} + --lib-name ${{parameters.libraryName}} + --native-lib-name ${{parameters.nativeLibraryName}} + --arch ${{parameters.arch}} + displayName: 'Package ONNX Runtime Java Native Libs' + +- task: 1ES.PublishPipelineArtifact@1 inputs: - scriptPath: 'tools/ci_build/github/linux/java_copy_strip_binary.sh' - args: '-r $(Build.BinariesDirectory) -c ${{parameters.buildConfig}} -a ${{parameters.artifactName}} -l ${{parameters.libraryName}} -n ${{parameters.nativeLibraryName}} -v ${{parameters.version}} -h ${{parameters.arch}}' - workingDirectory: '$(Build.BinariesDirectory)/${{parameters.buildConfig}}' + targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' + artifactName: 'drop-${{parameters.artifactName}}' -- ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' - artifactName: 'drop-${{parameters.artifactName}}' -- ${{ if eq(parameters.is1ES, false) }}: - - task: PublishBuildArtifacts@1 - inputs: - pathtoPublish: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' - artifactName: 'drop-${{parameters.artifactName}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 57703239fc594..73c774b9a45e9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml deleted file mode 100644 index 8f6434f7ac40d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml +++ /dev/null @@ -1,8 +0,0 @@ -steps: - - script: | - azcopy cp --recursive https://lotusscus.blob.core.windows.net/orttrainingtestdatascus/mnist/ $(Agent.TempDirectory) - displayName: 'Download Training Test Data MNIST' - - - script: | - ls -al $(Agent.TempDirectory)/mnist - displayName: 'Print contents of Training Test Data MNIST' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 674f16d8e9332..681138a5ab3d1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -20,9 +20,16 @@ parameters: steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: - - powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + - task: AzureCLI@2 displayName: 'Download CUDA SDK v${{ parameters.CudaVersion }}' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\bin;$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\extras\CUPTI\lib64" displayName: 'Append CUDA SDK Directory to PATH' @@ -31,6 +38,7 @@ steps: inputs: script: | echo %PATH% + dir $(Agent.TempDirectory) displayName: 'Print PATH after download CUDA SDK' - ${{ if eq(parameters.DownloadTRT, true) }}: @@ -51,9 +59,16 @@ steps: echo $(trtCudaVersion) && echo TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) displayName: Get trtCudaVersion and Directory Name - - powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) + + - task: AzureCLI@2 displayName: 'Download TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)\lib" diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index d2e401f3f6ab4..8c15fe111593f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml index b7fb8a51f28be..0ce6f3ec50a06 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index ef0f4c6e0883c..7b184308bfc66 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,6 +83,9 @@ jobs: versionSpec: '3.12' addToPath: true architecture: $(buildArch) + - task: NodeTool@0 + inputs: + versionSpec: '22.x' - ${{if eq(parameters.WithCache, true)}}: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 7547b841c7480..56cc84a90dc68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -1,3 +1,5 @@ +# This stage fetch built macOS binaries from other stages, sign the binaries, then repack them + parameters: - name: AdditionalBuildFlags displayName: Additional build flags for build.py @@ -13,31 +15,11 @@ parameters: - 1 - 0 -- name: BuildForAllArchs - displayName: Build for all CPU ARCHs - type: boolean - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - - name: DoESRP displayName: Do ESRP type: boolean default: false -# these 2 parameters are used for debugging. -- name: SpecificArtifact - displayName: Use Specific Artifact (Debugging only) - type: boolean - default: false - -- name: BuildId - displayName: Pipeline BuildId, you could find it in the URL - type: string - default: '0' - stages: - stage: MacOS_C_API_Packaging_CPU dependsOn: [] @@ -47,21 +29,12 @@ stages: MacosArch: 'x86_64' AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildForAllArchs, true) }}: - - template: mac-cpu-packing-jobs.yml - parameters: - MacosArch: 'arm64' - AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} - - template: mac-cpu-packing-jobs.yml - parameters: - MacosArch: 'universal2' - AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} + - template: mac-cpu-packing-jobs.yml + parameters: + MacosArch: 'arm64' + AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - stage: MacOS_C_API_Package_Publish dependsOn: MacOS_C_API_Packaging_CPU @@ -71,68 +44,56 @@ stages: name: 'Azure Pipelines' image: 'macOS-14' os: 'macOS' + templateContext: + inputs: + - input: pipelineArtifact + artifactName: onnxruntime-osx-x86_64 # The files in this artifact are not signed + targetPath: $(Build.ArtifactStagingDirectory) + - input: pipelineArtifact + artifactName: onnxruntime-osx-arm64 # The files in this artifact are not signed + targetPath: $(Build.ArtifactStagingDirectory) + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: 'onnxruntime-osx' # The files in this artifact are signed steps: - - checkout: none - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-x86_64' - ArtifactName: 'onnxruntime-osx-x86_64' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - checkout: self - - ${{ if eq(parameters.BuildForAllArchs, true) }}: - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-arm64' - ArtifactName: 'onnxruntime-osx-arm64' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-universal2' - ArtifactName: 'onnxruntime-osx-universal2' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + addToPath: true - - ${{ if eq(parameters.DoESRP, true)}}: - - script: | - pushd '$(Build.ArtifactStagingDirectory)' - find . '*.tgz' -exec tar -zxvf {} \; - rm -f *.tgz; - find . -type d -name 'onnxruntime-osx-*' -exec zip -FSr --symlinks {}.zip {} \; - find . -type d -name 'onnxruntime-osx-*' -exec rm -rf {} \; - ls -l - popd - displayName: tgz to zip + - task: PythonScript@0 + displayName: 'Prepare, Create Universal Binary, and Zip with Python' + inputs: + scriptSource: 'filePath' + scriptPath: 'tools/ci_build/prepare_macos_package.py' + arguments: '--staging_dir $(Build.ArtifactStagingDirectory)' - - template: mac-esrp-dylib.yml - parameters: - FolderPath: '$(Build.ArtifactStagingDirectory)' - Pattern: '*.zip' + - template: mac-esrp-dylib.yml + parameters: + FolderPath: '$(Build.ArtifactStagingDirectory)' + Pattern: '*.zip' - - script: | - pushd '$(Build.ArtifactStagingDirectory)' - find . '*.zip' -exec unzip {} \; - rm -f *.zip; - find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; - find . -type d -name 'onnxruntime-osx-*' -exec rm -rf {} \; - ls -l - popd - displayName: zip to tgz - - bash: | - set -ex - mkdir -p $(Agent.TempDirectory)/macpackage - find $(Build.ArtifactStagingDirectory) -name "*.tgz" -exec tar -zxvf {} -C $(Agent.TempDirectory)/macpackage \; - find $(Agent.TempDirectory)/macpackage -name "*.dylib" -exec codesign -dvvv {} \; - find $(Agent.TempDirectory)/macpackage -name "*.dylib" -exec ls -l {} \; - rm -rf $(Agent.TempDirectory)/macpackage - displayName: 'Verify code signing' + - script: | + set -ex + mkdir temp + cd temp + find $(Build.ArtifactStagingDirectory) -name '*.zip' -exec unzip {} \; + rm -rf $(Build.ArtifactStagingDirectory)/*; + find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; + ls -l + mv *.tgz $(Build.ArtifactStagingDirectory) + displayName: 'Unzip Signed Files and Repackage to TGZ' + workingDirectory: $(Agent.TempDirectory) - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)' - artifactName: 'onnxruntime-osx' - condition: 'succeededOrFailed()' + - bash: | + set -ex + mkdir -p macpackage + find $(Build.ArtifactStagingDirectory) -name "*.tgz" -exec tar -zxvf {} -C macpackage \; + find macpackage -name "*.dylib" -exec codesign -dvvv {} \; + find macpackage -name "*.dylib" -exec ls -l {} \; + rm -rf macpackage + displayName: 'Verify Code Signing' + workingDirectory: $(Agent.TempDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml index 9a8264a288582..c43bfe2886f22 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml @@ -4,56 +4,22 @@ parameters: values: - 'x86_64' - 'arm64' - - 'universal2' - default: 'x86_64' - name: AdditionalBuildFlags displayName: Additional build flags for build.py type: string default: '' -- name: BuildJava - displayName: Build with Java - type: boolean - default: true -- name: BuildNodejs - displayName: Build with Nodejs - type: boolean - default: false - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - -- name: CacheDir - displayName: Cache Directory - type: string - default: '' - -- name: Today - type: string - default: "" steps: -- template: mac-build-step-with-cache.yml - parameters: - WithCache: ${{ parameters.WithCache }} - Today: ${{ parameters.Today }} - AdditionalKey: onnxruntime_${{ parameters.MacosArch }} - CacheDir: ${{ parameters.CacheDir }} - ChangeEveryCommit: true - BuildStep: - - script: | - set -e -x - rm -rf $(Build.BinariesDirectory)/Release - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --build_shared_lib --config Release --use_vcpkg --use_vcpkg_ms_internal_asset_cache - cd $(Build.BinariesDirectory)/Release - make install DESTDIR=$(Build.BinariesDirectory)/installed - displayName: 'Build ${{ parameters.MacosArch }}' - env: - CCACHE_DIR: ${{ parameters.CacheDir }} +- script: | + set -e -x + rm -rf $(Build.BinariesDirectory)/Release + python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel 3 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --build_shared_lib --config Release --use_vcpkg --use_vcpkg_ms_internal_asset_cache + cd $(Build.BinariesDirectory)/Release + make install DESTDIR=$(Build.BinariesDirectory)/installed + displayName: 'Build ${{ parameters.MacosArch }}' - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - script: | @@ -77,9 +43,9 @@ steps: replaceExistingArchive: true - script: | - set -e -x - mkdir -p $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.ArtifactStagingDirectory)/testdata + set -e -x + mkdir -p $(Build.ArtifactStagingDirectory)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.ArtifactStagingDirectory)/testdata displayName: 'Copy libcustom_op_library.dylib to ArtifactStagingDirectory' condition: and(succeeded(), eq('${{ parameters.MacosArch }}', 'x86_64')) @@ -88,23 +54,19 @@ steps: targetPath: '$(Build.ArtifactStagingDirectory)' artifactName: 'onnxruntime-osx-${{ parameters.MacosArch }}' -- ${{ if eq(parameters.BuildJava, true) }}: - - template: java-api-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'osx-${{ parameters.MacosArch }}' - buildConfig: 'Release' - artifactName: 'onnxruntime-java-osx-${{ parameters.MacosArch }}' - version: '$(OnnxRuntimeVersion)' - libraryName: 'libonnxruntime.dylib' - nativeLibraryName: 'libonnxruntime4j_jni.dylib' - is1ES: true +- template: java-api-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'osx-${{ parameters.MacosArch }}' + buildConfig: 'Release' + artifactName: 'onnxruntime-java-osx-${{ parameters.MacosArch }}' + libraryName: 'libonnxruntime.dylib' + nativeLibraryName: 'libonnxruntime4j_jni.dylib' -- ${{ if eq(parameters.BuildNodejs, true) }}: - - template: nodejs-artifacts-package-and-publish-steps-posix.yml - parameters: - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - arch: x64 - ${{ if eq(parameters.MacosArch, 'arm64') }}: - arch: arm64 - os: 'darwin' - artifactName: 'drop-onnxruntime-nodejs-osx-${{ parameters.MacosArch }}' +- template: nodejs-artifacts-package-and-publish-steps-posix.yml + parameters: + ${{ if eq(parameters.MacosArch, 'x86_64') }}: + arch: x64 + ${{ if eq(parameters.MacosArch, 'arm64') }}: + arch: arm64 + os: 'darwin' + artifactName: 'drop-onnxruntime-nodejs-osx-${{ parameters.MacosArch }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 2c660e23d8648..c63c74fb997fe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -4,13 +4,6 @@ parameters: values: - 'x86_64' - 'arm64' - - 'universal2' - default: 'x86_64' - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - name: AdditionalBuildFlags displayName: Additional build flags for build.py @@ -31,11 +24,8 @@ jobs: workspace: clean: all variables: - MACOSX_DEPLOYMENT_TARGET: '13.3' + MACOSX_DEPLOYMENT_TARGET: '13.4' ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto - ORT_CACHE_DIR: $(Pipeline.Workspace)/ccache_ort pool: name: "Azure Pipelines" image: 'macOS-14' @@ -46,65 +36,34 @@ jobs: clean: true submodules: none - - task: UsePythonVersion@0 - displayName: Use Python 3.10 - inputs: - versionSpec: 3.10 - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - task: JavaToolInstaller@0 inputs: versionSpec: "17" jdkArchitectureOption: "x64" jdkSourceOption: 'PreInstalled' - - template: set-version-number-variables-step.yml - - template: use-xcode-version.yml + - template: setup-build-tools.yml + parameters: + host_cpu_arch: ${{ parameters.MacosArch }} + + - template: set-version-number-variables-step.yml + - script: | set -e -x - export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - - - - ${{ if eq(parameters.MacosArch, 'universal2') }}: - - template: mac-cpu-packaging-steps.yml - parameters: - MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --no_kleidiai --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" - BuildJava: false - BuildNodejs: false - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - ${{ if eq(parameters.MacosArch, 'arm64') }}: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 - BuildJava: true - BuildNodejs: true - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu - BuildJava: true - BuildNodejs: true - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=x86_64 diff --git a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml index 0d62ed7907a67..d1ea61ada90c3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml +++ b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml @@ -1,59 +1,50 @@ parameters: - - name: msbuildPlatform - type: string - - name: java_artifact_id - type: string - - name: buildOnly - type: boolean +- name: msbuildPlatform + type: string +- name: java_artifact_id + type: string +- name: buildOnly + type: boolean + default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number steps: - - task: CmdLine@2 - displayName: 'Gradle cmakeCheck' - inputs: - ${{ if eq(parameters.buildOnly, true) }}: - script: | - call gradlew.bat testClasses -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo - call gradlew.bat cmakeCheck -x test -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all - workingDirectory: $(Build.SourcesDirectory)\java - ${{ else }}: - script: | - call gradlew.bat cmakeCheck -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all - workingDirectory: $(Build.SourcesDirectory)\java +- task: PowerShell@2 + displayName: 'Build and Package Java Artifacts' + inputs: + targetType: 'inline' + script: | + # Define arguments for the Python script + $scriptArgs = @( + "--sources-dir", "$(Build.SourcesDirectory)", + "--binaries-dir", "$(Build.BinariesDirectory)", + "--platform", "${{ parameters.msbuildPlatform }}", + "--build-config", "RelWithDebInfo", + "--java-artifact-id", "${{ parameters.java_artifact_id }}", + "--pre-release-version-suffix-string", "${{ parameters.PreReleaseVersionSuffixString }}", + "--pre-release-version-suffix-number", "${{ parameters.PreReleaseVersionSuffixNumber }}", + "--commit-hash", "$(OnnxRuntimeGitCommitHash)" + ) + + # Conditionally add the --build-only flag if the parameter is true + if ('${{ parameters.buildOnly }}' -eq 'True') { + $scriptArgs += "--build-only" + } + + # Define the path to the python script within your repository + $scriptPath = "$(Build.SourcesDirectory)/tools/ci_build/manage_java_artifacts.py" - - task: CmdLine@2 - displayName: 'Add symbols and notices to Java' - inputs: - script: | - @echo on - cd $(Build.BinariesDirectory)\RelWithDebInfo - set NATIVE_FOLDER=$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\ai\onnxruntime\native\win-x64 - mkdir %NATIVE_FOLDER% - echo "Directories created" - copy .\java\build\libs\*.jar $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} - pushd $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} - set artifact_id=${{ parameters.java_artifact_id }} - jar xf onnxruntime-$(OnnxRuntimeVersion).jar META-INF\maven\com.microsoft.onnxruntime\%artifact_id%\pom.xml - move META-INF\maven\com.microsoft.onnxruntime\%artifact_id%\pom.xml onnxruntime-$(OnnxRuntimeVersion).pom - rd /s /q META-INF - popd - copy .\RelWithDebInfo\onnxruntime.pdb %NATIVE_FOLDER% - copy .\RelWithDebInfo\onnxruntime4j_jni.pdb %NATIVE_FOLDER% - copy $(Build.SourcesDirectory)\docs\Privacy.md $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\Privacy.md - copy $(Build.SourcesDirectory)\ThirdPartyNotices.txt $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\ThirdPartyNotices.txt - @echo $(OnnxRuntimeGitCommitHash) > $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\GIT_COMMIT_ID - pushd $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar ai\onnxruntime\native\win-x64\onnxruntime.pdb - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar ai\onnxruntime\native\win-x64\onnxruntime4j_jni.pdb - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar Privacy.md ThirdPartyNotices.txt GIT_COMMIT_ID - popd - pushd $(Build.SourcesDirectory)\java\build\classes\java\test - if %errorlevel% neq 0 exit /b %errorlevel% - jar cvf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\testing.jar . - if %errorlevel% neq 0 exit /b %errorlevel% - popd - pushd $(Build.SourcesDirectory)\java\build\resources\test - rd /s /q ai\onnxruntime\native - jar uvf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\testing.jar . - popd - rd /s /q $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage - dir /s /b $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} + # Execute the Python script, passing all arguments + Write-Host "Executing Python script: $scriptPath with arguments: $($scriptArgs -join ' ')" + python $scriptPath $scriptArgs \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 2168214527c91..3d662ffbb18dd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-macos.yml b/tools/ci_build/github/azure-pipelines/templates/py-macos.yml new file mode 100644 index 0000000000000..c8a26481d6205 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-macos.yml @@ -0,0 +1,75 @@ +parameters: +- name: arch + type: string + +- name: python_version + type: string + +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: extra_build_arg + type: string + default: '' + +jobs: +- job: Mac_${{ parameters.arch }}_${{ replace(parameters.python_version,'.','_') }} + timeoutInMinutes: 240 + workspace: + clean: all + pool: + name: "Azure Pipelines" + image: "macOS-15" + os: macOS + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.SourcesDirectory)/build/Release/dist/fixed_wheels + artifactName: onnxruntime-macos-${{ parameters.arch }}_${{ replace(parameters.python_version,'.','_') }} + + variables: + - name: MACOSX_DEPLOYMENT_TARGET + value: '13.4' + + steps: + - checkout: self + clean: true + submodules: none + + - template: use-xcode-version.yml + parameters: + xcodeVersion: '16.4.0' + + + - template: setup-build-tools.yml + parameters: + host_cpu_arch: ${{ parameters.arch }} + python_version: ${{ parameters.python_version }} + + - script: | + set -e -x + export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-${{ parameters.arch }} + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' + python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ + --build_dir $(Build.SourcesDirectory)/build \ + --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ + --use_binskim_compliant_compile_flags \ + --config Release \ + --build_wheel \ + --use_coreml ${{ parameters.extra_build_arg }} \ + --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=${{ parameters.arch }} \ + --update --skip_submodule_sync --build --parallel + python -m pip install --upgrade delocate + cd '$(Build.SourcesDirectory)/build/Release/dist' + ls + for file in *.whl + do + delocate-listdeps "$file" + delocate-wheel --require-archs=${{ parameters.arch }} -w fixed_wheels -v "$file" + done \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-linux.yml similarity index 89% rename from tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml rename to tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-linux.yml index be9707e8f3f65..ad3116b45d52a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-linux.yml @@ -54,7 +54,8 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }}" $PYTHON_PACKAGE_NAME + python3 -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy + python3 -m pip install --no-index --find-links . $PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" workingDirectory: $(Pipeline.Workspace)/build/onnxruntime-${{ parameters.arch }}-${{ parameters.ep }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-macos.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-macos.yml new file mode 100644 index 0000000000000..1e369eb74f3ec --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test-macos.yml @@ -0,0 +1,41 @@ +parameters: +- name: python_version + type: string + +- name: os_version + type: string + +jobs: +- job: Test_MAC_Wheels_${{ parameters.os_version }}_${{ replace(parameters.python_version,'.','_') }} + timeoutInMinutes: 30 + workspace: + clean: all + pool: + vmImage: 'macOS-${{ parameters.os_version }}' + steps: + - checkout: none + + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: ${{ parameters.python_version }} + + - download: build # pipeline resource identifier. + artifact: onnxruntime-macos-x86_64_${{ replace(parameters.python_version,'.','_') }} + + - task: Bash@3 + inputs: + targetType: 'inline' + script: | + set -ex + files=(*.whl) + FILE_NAME="${files[0]}" + FILE_NAME=$(basename $FILE_NAME) + PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') + python3 -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy + python3 -m pip install --no-index --find-links . $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime-macos-x86_64_${{ replace(parameters.python_version,'.','_') }} + displayName: Test Package Installation + diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 761c551e9f4d9..9ad59ba90402d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -4,10 +4,14 @@ parameters: type: string default: 'onnxruntime-qnn-windows-vs-2022-arm64' +- name: PYTHON_VERSION + type: string + default: '3.11' + - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: ENV_SETUP_SCRIPT type: string @@ -19,13 +23,8 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: -- job: Win_py_arm64_qnn_Wheels +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 210 workspace: clean: all @@ -48,48 +47,20 @@ jobs: outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) + artifactName: onnxruntime_qnn_arm64_${{ parameters.PYTHON_VERSION }} - strategy: - matrix: - Python311_arm64: - PythonVersion: '3.11.0' - LocalPythonDir: 'C:\Python\Python311' - Python312_arm64: - PythonVersion: '3.12.6' - LocalPythonDir: 'C:\Python\Python312' - Python313_arm64: - PythonVersion: '3.13.2' - LocalPythonDir: 'C:\Python\Python313' variables: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - checkout: self clean: true - submodules: recursive - - - template: telemetry-steps.yml - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory + submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'arm64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'arm64' + python_version: ${{ parameters.PYTHON_VERSION }} - task: PythonScript@0 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 74cae38393ea6..aad24661b868c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: ENV_SETUP_SCRIPT type: string @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 @@ -50,18 +45,10 @@ jobs: clean: true submodules: recursive - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: $(PythonVersion) - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index e631e9d391a67..49f6fc662aa75 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: ENV_SETUP_SCRIPT type: string @@ -49,19 +49,11 @@ jobs: clean: true submodules: recursive - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: $(PythonVersion) + - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt - template: set-nightly-build-option-variable-step.yml @@ -137,4 +129,4 @@ jobs: - script: | 7z x *.whl workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' \ No newline at end of file + displayName: 'unzip the package' diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 52d9eb139fab7..3836db5ee7ba0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.36.1.250708' + QnnSdk: '2.37.1.250807' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -11,7 +11,8 @@ parameters: stages: - stage: ${{ parameters.StageName }} - dependsOn: [] + dependsOn: + - Setup jobs: - job: ${{ parameters.StageName }} timeoutInMinutes: 300 @@ -45,15 +46,15 @@ stages: artifactName: "drop-signed-nuget-qnn" variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] commonBuildArgs: '--skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' - template: jobs/download_win_qnn_sdk.yml parameters: @@ -107,7 +108,12 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.DesktopOnly.CSharp.sln' platform: 'Any CPU' configuration: ${{ parameters.build_config }} - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}' + msbuildArguments: > + -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" + -p:OrtPackageId=$(OrtPackageId) + -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:PackageVersion=$(OnnxRuntimeVersion) workingDirectory: '$(Build.SourcesDirectory)\csharp' - ${{ if eq(parameters.DoEsrp, true) }}: @@ -123,7 +129,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:TargetArchitecture=arm64' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index fe26dc20106f7..2e7f5122cbdc6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,12 +80,9 @@ stages: steps: - template: use-xcode-version.yml - - task: UsePythonVersion@0 - displayName: Use python 3.12 - inputs: - versionSpec: "3.12" - addToPath: true - architecture: "x64" + - template: setup-build-tools.yml + parameters: + host_cpu_arch: x64 # All ADO agents are currently x64 - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt @@ -113,10 +110,7 @@ stages: CCACHE_DEPEND: 1 CCACHE_SLOPPINESS: modules CCACHE_DIR: $(ORT_CACHE_DIR) - # Test the iOS package - - task: NodeTool@0 - inputs: - versionSpec: '22.x' + - script: brew install coreutils ninja npm displayName: Install coreutils, ninja, npm diff --git a/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml b/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml new file mode 100644 index 0000000000000..df7fea537ce6f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml @@ -0,0 +1,65 @@ +# Setup python/nodejs/cmake/vcpkg tools. Also, setup telemetry header file if the current OS is Windows. + +parameters: +# for selecting python binary +- name: host_cpu_arch + type: string + +- name: python_version + type: string + default: '3.12' + +- name: action_version + type: string + default: 'v0.0.9' + +steps: +- template: telemetry-steps.yml + +# Currently all ADO macOS machines are x64 machines +- task: UsePythonVersion@0 + displayName: 'Use Python ${{ parameters.host_cpu_arch }} (macOS)' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Darwin')) + inputs: + versionSpec: ${{ parameters.python_version }} + architecture: 'x64' + +- task: UsePythonVersion@0 + displayName: 'Use Python ${{ parameters.host_cpu_arch }} (non-macOS)' + condition: and(succeeded(), ne(variables['Agent.OS'], 'Darwin')) + inputs: + versionSpec: ${{ parameters.python_version }} + architecture: ${{ parameters.host_cpu_arch }} + +- task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + +# The following task does not support different arches. +- task: UseNode@1 + condition: and(succeeded(), ne(variables['Agent.OS'], 'Windows_NT')) + inputs: + version: '22.x' + +- task: PowerShell@2 + displayName: 'Setup Latest Node.js v20 (Win)' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) + inputs: + filePath: '$(System.DefaultWorkingDirectory)\tools\ci_build\github\windows\setup_nodejs.ps1' + arguments: '-MajorVersion 22' + +- script: | + node -v + npm -v + + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) + displayName: 'Verify Node.js Version' + +- script: python3 -m pip install requests + +- task: PythonScript@0 + displayName: 'Run GitHub Action via Python Wrapper' + inputs: + scriptPath: 'tools/ci_build/run_gh_action.py' + arguments: '${{ parameters.action_version }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml b/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml new file mode 100644 index 0000000000000..7ad755c50e541 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml @@ -0,0 +1,47 @@ +steps: +- task: AzureCLI@2 + displayName: 'Download and Extract Maven using Azure CLI' + inputs: + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' + scriptType: 'pscore' # Use PowerShell Core + scriptLocation: 'inlineScript' + inlineScript: | + # Define the scope for the access token + $authScope = "https://mspmecloud.onmicrosoft.com/RebuildManager.Web/.default" + + Write-Host "Requesting access token for scope: $authScope" + $tokenInfo = az account get-access-token --scope $authScope | ConvertFrom-Json + + # Set the token as an environment variable for the next tool to use + $env:TRT_UPLOAD_AUTH_TOKEN = $tokenInfo.accessToken + Write-Host "Successfully configured TRT_UPLOAD_AUTH_TOKEN environment variable." + + # Execute the Terrapin Retrieval Tool to download Maven + Write-Host "Downloading Maven..." + & C:\local\Terrapin\TerrapinRetrievalTool.exe -b https://vcpkg.storage.devpackages.microsoft.io/artifacts/ -a true -u Environment -p https://dlcdn.apache.org/maven/maven-3/3.9.11/binaries/apache-maven-3.9.11-bin.zip -s 03e2d65d4483a3396980629f260e25cac0d8b6f7f2791e4dc20bc83f9514db8d0f05b0479e699a5f34679250c49c8e52e961262ded468a20de0be254d8207076 -d $(Agent.TempDirectory)\maven.zip + + # Check if the download was successful + if ($LASTEXITCODE -ne 0) { + throw "Error downloading maven. Exit code: $LASTEXITCODE" + } + Write-Host "Maven downloaded successfully." + + # Extract the downloaded maven zip file + $arguments = "x", "$(Agent.TempDirectory)\maven.zip", "-y", "-o$(Agent.TempDirectory)" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + + # Check if the extraction was successful + if ($LASTEXITCODE -ne 0) { + throw "Error extracting maven.zip. Exit code: $LASTEXITCODE" + } + Write-Host "Maven extracted successfully." + + # Prepend the Maven bin directory to the PATH for subsequent steps in the job + Write-Host "Adding Maven to the pipeline PATH." + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\apache-maven-3.9.11\bin" + +- script: | + echo "Verifying Maven installation..." + mvn --version + displayName: 'Verify Maven Version' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 5eef1ae8e8e93..f377ad863cbe0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -87,11 +87,9 @@ stages: removeProfile: true displayName: 'Install ORT Mobile Test Provisioning Profile' - - task: UsePythonVersion@0 - inputs: - versionSpec: "3.12" - addToPath: true - architecture: "x64" + - template: ../setup-build-tools.yml + parameters: + host_cpu_arch: arm64 - template: ../use-xcode-version.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml b/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml index a8bc789e1cffe..8db4a8f8c8658 100644 --- a/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml @@ -5,6 +5,7 @@ steps: # TELEMETRYGUID is a runtime variable that is stored on the pipeline in an old-fashioned way. So it cannot be used in # template expressions. We access it through env variables. - task: PowerShell@2 + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) displayName: 'Set TelemetryOption variable and optionally create TraceLoggingConfigPrivate.h for WinML Telemetry' inputs: targetType: filePath diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index eec0f273581a2..c54b13b8dec6a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -40,6 +40,19 @@ parameters: type: string default: '' +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + # for inference packages '', for training packages '-training' # used for drop-extra and c api artifacts (onnxruntime-win-* or onnxrutime-training-win-*) - name: artifact_name_suffix @@ -110,6 +123,11 @@ stages: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) artifactName: 'onnxruntime${{ parameters.artifact_name_suffix }}-win-${{ parameters.packageName }}' + + - ${{ if eq(parameters.buildJava, 'true') }}: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} + artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' # GPU build has two jobs. This is the first one. - ${{ if contains(parameters.ort_build_pool_name, 'GPU') }}: - output: pipelineArtifact @@ -134,18 +152,9 @@ stages: clean: true submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: ${{ parameters.buildArch }} - - - template: telemetry-steps.yml - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 @@ -154,12 +163,6 @@ stages: jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' - - - task: NodeTool@0 - condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) - inputs: - versionSpec: '22.x' - - ${{ if ne(parameters.CudaVersion, '') }}: - template: jobs/download_win_gpu_library.yml parameters: @@ -183,21 +186,15 @@ stages: # For CPU job, tests are run in the same machine as building - ${{ if eq(parameters.buildJava, 'true') }}: + - template: setup-maven.yml - template: make_java_win_binaries.yml parameters: msbuildPlatform: ${{ parameters.msbuildPlatform }} java_artifact_id: ${{ parameters.java_artifact_id }} - ${{ if or(contains(parameters.buildparameter, 'use_cuda'), contains(parameters.buildparameter, 'use_tensorrt')) }}: - # When it is a GPU build, we only assemble the java binaries, testing will be done in the later stage with GPU machine - buildOnly: true - ${{ else }}: - buildOnly: false - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Java temp binaries' - inputs: - targetPath: '$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}' - artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + buildOnly: true + # All GPU builds will be tested in the next stage with GPU machine - ${{ if contains(parameters.ort_build_pool_name, 'CPU') }}: - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 01f73a63075e3..8b2504d61def1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -76,9 +76,11 @@ jobs: git checkout -- .gitattributes workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' - - task: NodeTool@0 - inputs: - versionSpec: '22.x' + + - template: setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_wasm/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index a084d28e84c1e..915ff517742fd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -23,19 +23,9 @@ jobs: inputs: version: '6.x' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - ${{ if eq(parameters.BuildArch, 'x86') }}: - architecture: 'x86' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - template: telemetry-steps.yml + - template: setup-build-tools.yml + parameters: + host_cpu_arch: ${{ parameters.BuildArch }} - task: NuGetCommand@2 displayName: 'NuGet restore' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..a01e2bc921aea 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 jobs: - job: 'BUILD_QNN_EP' @@ -61,16 +61,6 @@ jobs: # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\3.11.0 - DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - displayName: Copy python 3.11.0 version to agent tools directory - - task: UsePythonVersion@0 inputs: versionSpec: '3.x' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml deleted file mode 100644 index ffeb577547f69..0000000000000 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ /dev/null @@ -1,112 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -parameters: - -- name: QnnSdk - displayName: QNN SDK version - type: string - default: 2.36.1.250708 - -jobs: -- job: 'BUILD_QNN_EP' - pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' - variables: - MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - buildArch: x64 - setVcvars: true - BuildConfig: 'RelWithDebInfo' - ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - timeoutInMinutes: 120 - workspace: - clean: all - strategy: - matrix: - SHARED_LIB: - QnnLibKind: 'shared_lib' - STATIC_LIB: - QnnLibKind: 'static_lib' - steps: - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: $(buildArch) - - - template: templates/jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} - - - template: templates/jobs/win-ci-build-steps.yml - parameters: - WithCache: True - Today: $(TODAY) - AdditionalKey: "win-qnn | $(BuildConfig)" - BuildPyArguments: >- - --config $(BuildConfig) - --build_dir $(Build.BinariesDirectory) - --cmake_generator "Visual Studio 17 2022" - --build_java - --build_shared_lib - --use_qnn $(QnnLibKind) - --qnn_home $(QnnSDKRootDir) - --use_binskim_compliant_compile_flags - --update --parallel - MsbuildArguments: $(MsbuildArguments) - BuildArch: $(buildArch) - Platform: 'x64' - BuildConfig: $(BuildConfig) - - - script: | - python $(Build.SourcesDirectory)\tools\ci_build\build.py ^ - --config $(BuildConfig) ^ - --build_dir $(Build.BinariesDirectory) ^ - --cmake_generator "Visual Studio 17 2022" ^ - --build_java ^ - --build_shared_lib ^ - --use_qnn $(QnnLibKind) ^ - --qnn_home $(QnnSDKRootDir) ^ - --use_binskim_compliant_compile_flags ^ - --test --enable_onnx_tests - displayName: 'Run unit tests' - - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run ONNX Tests' - - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run float32 model tests' diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index fe417db7f2559..9cc140f41cf91 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -2,4 +2,4 @@ set -e -x docker run -e SYSTEM_COLLECTIONURI --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \ -/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index fd23cd9bc37f1..5684029c72049 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -70,7 +70,7 @@ fi if [ "$BUILD_DEVICE" == "GPU" ]; then SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90a-real;90a-virtual") + BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90a-real;90a-virtual" "onnxruntime_USE_FPA_INTB_GEMM=OFF") fi if [ "$BUILD_DEVICE" == "NPU" ]; then diff --git a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh index 54e671a8196be..b8d968c82d002 100755 --- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh @@ -3,4 +3,4 @@ set -e -x mkdir -p $HOME/.onnx docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ -/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 177df14d6eaee..2a65e7c26b20b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 957eef8046eaf..3337af3be6074 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ARG ROCM_VERSION=6.2.3 #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index 56d67599f0bce..0007a4e06f7c0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index c8e164282a2f0..8b2083c2ccfc1 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 +FROM $BASEIMAGE ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 31bd41226263f..f5143d5ac9ab9 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 461464093688a..cfc2ce7079148 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 +FROM $BASEIMAGE ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 043291065736d..8401393a661b1 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 +FROM $BASEIMAGE ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 43da13df2fe8b..b923febc1227f 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index b52d0cbcf3fea..a52e57138117a 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,7 +3,7 @@ numpy==2.2.6 mypy pytest setuptools==78.1.1 -wheel==0.42.0 +wheel==0.45.1 onnx==1.18.0 argparse sympy==1.14 diff --git a/tools/ci_build/github/linux/java_copy_strip_binary.sh b/tools/ci_build/github/linux/java_copy_strip_binary.sh deleted file mode 100755 index 329c1b0ab9b9e..0000000000000 --- a/tools/ci_build/github/linux/java_copy_strip_binary.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash -set -e -o -x - -while getopts r:a:l:n:c:h:v: parameter_Option -do case "${parameter_Option}" -in -r) BINARY_DIR=${OPTARG};; -a) ARTIFACT_NAME=${OPTARG};; -c) BUILD_CONFIG=${OPTARG};; -l) LIB_NAME=${OPTARG};; -n) NATIVE_LIB_NAME=${OPTARG};; -h) ARCH=${OPTARG};; #must match the JAVA_OS_ARCH variable in onnxruntime_java.cmake -v) VERSION_NUMBER=${OPTARG};; -esac -done - -EXIT_CODE=1 - -uname -a - -echo "Version: $VERSION_NUMBER" -if [[ $LIB_NAME == *.dylib ]] && [[ $ARCH == 'osx-x86_64' ]]; then - ARCH='osx-x64' -elif [[ $LIB_NAME == *.dylib ]] && [[ $ARCH == 'osx-arm64' ]]; then - ARCH='osx-aarch64' -fi -NATIVE_FOLDER=ai/onnxruntime/native/$ARCH - -mkdir -p $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER - -echo "Directories created" - -echo "Copy debug symbols in a separate file and strip the original binary." - -if [[ $LIB_NAME == *.dylib ]] -then - # ORT LIB - dsymutil $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME -o $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/$LIB_NAME.dSYM - cp $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.dylib - strip -S $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.dylib - # JNI Lib - dsymutil $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME -o $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/$NATIVE_LIB_NAME.dSYM - cp $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.dylib - strip -S $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.dylib - # Add custom lib for testing. This should be added to testing.jar - cp $BINARY_DIR/$BUILD_CONFIG/libcustom_op_library.dylib $BINARY_DIR/$ARTIFACT_NAME -elif [[ $LIB_NAME == *.so ]] -then - cp $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.so - cp $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.so - # Add custom lib - cp $BINARY_DIR/$BUILD_CONFIG/libcustom_op_library.so $BINARY_DIR/$ARTIFACT_NAME - # Add cuda provider if it exists - if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_shared.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_shared.so - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_cuda.so - fi - # Add tensorrt provider if it exists - if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so" ]]; then - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_shared.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_shared.so - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_tensorrt.so - fi -fi - -find $BINARY_DIR/$ARTIFACT_NAME -ls -rm -fr $BINARY_DIR/$ARTIFACT_NAME/jar - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/java_linux_final_test.sh b/tools/ci_build/github/linux/java_linux_final_test.sh index 2699d488acbb8..cdbfd2bad10a8 100755 --- a/tools/ci_build/github/linux/java_linux_final_test.sh +++ b/tools/ci_build/github/linux/java_linux_final_test.sh @@ -20,23 +20,24 @@ EXIT_CODE=1 uname -a -cd "$BINARY_DIR/final-jar" - -mkdir test +cd "$BINARY_DIR/onnxruntime-java" +rm -f *.asc +rm -f *.sha256 +rm -f *.sha1 +rm -f *.md5 +rm -f *.sha512 +rm -f *.pom +ls +cd .. +mkdir tests +cd tests +jar xf ../onnxruntime-java/testing.jar +rm -f ../onnxruntime-java/testing.jar +echo "Java Version" +java -version echo "Directories created" echo "Library path:" "$LD_LIBRARY_PATH" -pushd test -jar xf "$BINARY_DIR/final-jar/testing.jar" -popd - -curl -O -sSL https://oss.sonatype.org/service/local/repositories/releases/content/org/junit/platform/junit-platform-console-standalone/1.6.2/junit-platform-console-standalone-1.6.2.jar -curl -O -sSL https://oss.sonatype.org/service/local/repositories/releases/content/com/google/protobuf/protobuf-java/3.25.5/protobuf-java-3.25.5.jar -java -DUSE_CUDA=1 -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.25.5.jar:./onnxruntime_gpu-"${VERSION_NUMBER}".jar --scan-class-path --fail-if-no-tests --disable-banner - - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE +java -DUSE_CUDA=1 -cp "$BINARY_DIR/tests:$BINARY_DIR/onnxruntime-java/*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$BINARY_DIR/tests \ + --fail-if-no-tests --disable-banner diff --git a/tools/ci_build/github/windows/extract_nuget_files.ps1 b/tools/ci_build/github/windows/extract_nuget_files.ps1 index ff8f63a85b97a..20d6c1f2b63a5 100644 --- a/tools/ci_build/github/windows/extract_nuget_files.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files.ps1 @@ -1,105 +1,119 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline +# This file is used by Zip-Nuget-Java Packaging Pipeline -# Re-construct a build directory that contains binaries from all the different platforms we're including -# in the native ORT nuget package +# Define the directory for NuGet artifacts. $nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" -New-Item -Path $nuget_artifacts_dir -ItemType directory +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue ## .zip files -# unzip directly -# exclude the iOS xcframework as we need to leave that zipped up to preserve symlinks -Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\* -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | +# Unzip files directly, excluding the iOS xcframework to preserve its symlinks. +Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact\*" -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + # Directly call 7z.exe using the call operator '&' + & 7z.exe $arguments + # Check the exit code of the last command. A non-zero code indicates an error. + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } ## .tgz files -# first extract the tar file from the tgz -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +# First, extract the .tar file from the .tgz archive. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# now extract the actual folder structure from the tar file to the build dir -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +# Now, extract the contents from the .tar file into the final directory. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# process iOS xcframework -$xcframeworks = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter onnxruntime_ios_xcframework.*.zip +# Process iOS xcframework +$xcframeworks = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter onnxruntime_ios_xcframework.*.zip if ($xcframeworks.Count -eq 1) { - $xcframework = $xcframeworks[0] - $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" - # remove version info from filename and use required filename format - $target_file = "$target_dir\onnxruntime.xcframework.zip" - New-Item -Path $target_dir -ItemType directory + $xcframework = $xcframeworks[0] + $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" + # Use the required filename format, removing version info. + $target_file = "$target_dir\onnxruntime.xcframework.zip" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($xcframework.FullName) $target_file" - Copy-Item $xcframework.FullName $target_file + Write-Output "Copying $($xcframework.FullName) to $target_file" + Copy-Item $xcframework.FullName $target_file } elseif ($xcframeworks.Count -gt 1) { - Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" + Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" } - -# copy android AAR. -# for full build of onnxruntime Android AAR, there should only be one .aar file -# called onnxruntime-android-x.y.z.aar or onnxruntime-training-android-x.y.z.aar but sanity check that -$aars = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.aar +# Copy Android AAR file. +# There should only be one .aar file for a full build. +$aars = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.aar if ($aars.Count -eq 1) { - $aar = $aars[0] - $aar_prefix = "onnxruntime" - if ($aar -like "onnxruntime-training*") { - $aar_prefix = "onnxruntime-training" - } - $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" - $target_file = "$target_dir\onnxruntime.aar" # remove '-mobile' and version info from filename - New-Item -Path $target_dir -ItemType directory + $aar = $aars[0] + $aar_prefix = "onnxruntime" + if ($aar.Name -like "onnxruntime-training*") { + $aar_prefix = "onnxruntime-training" + } + $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" + # Remove version info from the filename for consistency. + $target_file = "$target_dir\onnxruntime.aar" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($aar.FullName) $target_file" - Copy-Item $aar.FullName $target_file + Write-Output "Copying $($aar.FullName) to $target_file" + Copy-Item $aar.FullName $target_file } elseif ($aars.Count -gt 1) { - Write-Error "Expected at most one Android .aar file but got: [$aars]" + Write-Error "Expected at most one Android .aar file but got: [$aars]" } -# Check whether this is a training pipeline -$is_training_pipeline = $false -if (Test-Path -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*) { - $is_training_pipeline = $true - Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." +# Check if this is a training pipeline by looking for a specific directory. +$is_training_pipeline = Test-Path -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*" +if ($is_training_pipeline) { + Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." } -# Copy onnxruntime and protoc binaries to the binaries dir as these are required -# by Microsoft.ML.OnnxRuntime.Tests.NetCoreApp +# Copy onnxruntime and protoc binaries required by tests. +$destinationDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" if ($is_training_pipeline) { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\*" -Destination $destinationDir -Recurse } else { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-win-x64-*\lib\*" -Destination $destinationDir -Recurse } -"Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-*" -$ort_dirs = Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-* -foreach ($ort_dir in $ort_dirs) -{ - # remove the last '-xxx' segment from the dir name. typically that's the architecture. - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $nuget_artifacts_dir\$dirname +# Rename directories to remove the architecture-specific suffix. +Write-Output "Renaming onnxruntime directories..." +Get-ChildItem -Directory -Path "$nuget_artifacts_dir\onnxruntime-*" | ForEach-Object { + $dirname = $_.Name + # Find the last hyphen and remove the suffix. + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $_.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($_.FullName)' to '$newPath'" + Rename-Item -Path $_.FullName -NewName $newName + } } -# List artifacts -"Post copy artifacts" -Get-ChildItem -Recurse $nuget_artifacts_dir\ +# List the final artifacts. +Write-Output "Post-copy artifacts:" +Get-ChildItem -Recurse $nuget_artifacts_dir \ No newline at end of file diff --git a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 index 01a8eebe75df2..29946dcb73f8a 100644 --- a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 @@ -2,47 +2,81 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget-Java Packaging Pipeline -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts -ItemType directory +# Define the directory for NuGet artifacts. +$nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.zip | +## .zip files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +## .tgz files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" # *.tar will be created after *.tgz is extracted - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + # *.tar will be created after *.tgz is extracted + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +## .tar files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } +# Create directory for protobuf build dependencies. +New-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" -ItemType directory -ErrorAction SilentlyContinue -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo -ItemType directory - -Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo +# Copy CUDA libraries. +Copy-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\*" -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" +# Install protoc via dotnet. $protocInstallDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build" dotnet new console dotnet add package Google.Protobuf.Tools --version 3.21.12 --package-directory $protocInstallDir +if ($LASTEXITCODE -ne 0) { + throw "Error adding Google.Protobuf.Tools package. Exit code: $LASTEXITCODE" +} + +# Find and copy the protoc executable. $protocDir = Get-ChildItem -Path $protocInstallDir -Recurse -Filter "protoc.exe" | Select-Object -ExpandProperty DirectoryName -First 1 -Write-Output $protocDir -Copy-Item -Path $protocDir -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo - -$ort_dirs = Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory -foreach ($ort_dir in $ort_dirs) -{ - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname +if ($protocDir) { + Write-Output "Found protoc directory: $protocDir" + Copy-Item -Path $protocDir -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" +} +else { + Write-Error "Could not find protoc.exe in $protocInstallDir" } +# Rename onnxruntime directories to a generic format. +$ort_dirs = Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-*" -Directory +foreach ($ort_dir in $ort_dirs) { + $dirname = Split-Path -Path $ort_dir -Leaf + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $ort_dir.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($ort_dir.FullName)' to '$newPath'" + Rename-Item -Path $ort_dir.FullName -NewName $newName + } +} diff --git a/tools/ci_build/github/windows/jar_esrp_dll.ps1 b/tools/ci_build/github/windows/jar_esrp_dll.ps1 index 8492d7591271b..2a53374d845a0 100644 --- a/tools/ci_build/github/windows/jar_esrp_dll.ps1 +++ b/tools/ci_build/github/windows/jar_esrp_dll.ps1 @@ -1,41 +1,70 @@ -$instruction = $args[0] # extract or repack -$original_jar_file_directory = $args[1] # The directory where the original jar file is located -$original_jar_file_name = $args[2] # The name of the original jar file +param( + [string]$instruction, # Should be 'extract' or 'repack' + [string]$jar_file_directory # The directory where the original jar file is located +) -$original_jar_file_full_path = "$original_jar_file_directory\$original_jar_file_name" -$extracted_file_directory = "$original_jar_file_directory\jar_extracted_full_files" +$extracted_file_directory = Join-Path $jar_file_directory "jar_extracted_full_files" +$state_file = Join-Path $jar_file_directory "repack_list.txt" if ($instruction -eq "extract") { - Write-Host "Extracting the jar file $original_jar_file_full_path..." - & 7z x $original_jar_file_full_path -o"$extracted_file_directory" - if ($lastExitCode -ne 0) { - Write-Host -Object "7z extracting the jar file command failed. Exitcode: $exitCode" - exit $lastExitCode + # Find the main jar file(s) by looking for names that start with 'onnxruntime' + # and excluding common suffixes for sources and javadocs. + $main_jar_files = Get-ChildItem -Path $jar_file_directory -Filter onnxruntime*.jar | Where-Object { $_.Name -notlike '*-sources.jar' -and $_.Name -notlike '*-javadoc.jar' } + + if ($main_jar_files.Count -eq 0) { + Write-Error "No main ONNX Runtime JAR file found in directory: $jar_file_directory" + exit 1 } - Write-Host "Extracted files directory: $extracted_file_directory" - Write-Host "Removing the original jar file..." - Remove-Item -Path "$original_jar_file_full_path" -Force - Write-Host "Removed the original jar file." -} -elseif ($instruction -eq "repack") { + # Clear any previous state file + if (Test-Path $state_file) { + Remove-Item $state_file + } + + foreach ($jar_file in $main_jar_files) { + Write-Host "Extracting the jar file $($jar_file.FullName)..." + & 7z x $jar_file.FullName -o"$extracted_file_directory" + if ($LASTEXITCODE -ne 0) { + Write-Error "7z failed to extract the jar file. Exitcode: $LASTEXITCODE" + exit $LASTEXITCODE + } + + # Save the original name for repacking, then remove the file + $jar_file.Name | Out-File -FilePath $state_file -Append + Write-Host "Removing the original jar file: $($jar_file.FullName)" + Remove-Item -Path $jar_file.FullName -Force + } + Write-Host "Extracted files to directory: $extracted_file_directory" + +} elseif ($instruction -eq "repack") { + if (-not (Test-Path $state_file)) { + Write-Error "State file '$state_file' not found. Cannot repack." + exit 1 + } + Write-Host "Removing ESRP's CodeSignSummary file..." - # It is the summary generated by ESRP tool. It is not needed in the jar file. - Remove-Item -Path "$extracted_file_directory/CodeSignSummary*.*" -Force + Remove-Item -Path "$extracted_file_directory/CodeSignSummary*.*" -Force -ErrorAction SilentlyContinue Write-Host "Removed ESRP's CodeSignSummary file." - Write-Host "Repacking the jar file from directory $extracted_file_directory..." - & 7z a "$original_jar_file_full_path" "$extracted_file_directory\*" - if ($lastExitCode -ne 0) { - Write-Host -Object "7z repacking the jar file command failed. Exitcode: $exitCode" - exit $lastExitCode + $jar_files_to_repack = Get-Content $state_file + + foreach ($jar_file_name in $jar_files_to_repack) { + $repacked_jar_file_path = Join-Path $jar_file_directory $jar_file_name + Write-Host "Repacking to $repacked_jar_file_path from directory $extracted_file_directory..." + & 7z a "$repacked_jar_file_path" "$extracted_file_directory\*" + if ($LASTEXITCODE -ne 0) { + Write-Error "7z failed to repack the jar file. Exitcode: $LASTEXITCODE" + exit $LASTEXITCODE + } + Write-Host "Repacked the jar file $repacked_jar_file_path." } - Write-Host "Repacked the jar file $original_jar_file_full_path." - Write-Host "Removing the extracted files..." + Write-Host "Removing the extracted files and state file..." Remove-Item -Path "$extracted_file_directory" -Recurse -Force - Write-Host "Removed the extracted files." -} -else { - Write-Host "Invalid instruction: $instruction" + Remove-Item -Path $state_file -Force + Write-Host "Cleaned up temporary files." + +} else { + Write-Error "Invalid instruction: '$instruction'. Must be 'extract' or 'repack'." + exit 1 } diff --git a/tools/ci_build/github/windows/jar_gpu_packaging.ps1 b/tools/ci_build/github/windows/jar_gpu_packaging.ps1 deleted file mode 100644 index 1c94f4678f988..0000000000000 --- a/tools/ci_build/github/windows/jar_gpu_packaging.ps1 +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -$ErrorActionPreference = "Stop" -Write-Output "Start" -dir -Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-linux-x64\ai\onnxruntime\native\linux-x64\libonnxruntime_providers_cuda.so -Destination $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-linux-x64-tensorrt\ai\onnxruntime\native\linux-x64 -pushd onnxruntime-java-linux-x64-tensorrt -Write-Output "Run 7z" -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.so -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-win-x64 -ren onnxruntime-$Env:ONNXRUNTIMEVERSION.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION-javadoc.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION-javadoc.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION-sources.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION-sources.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION.pom onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION.pom -popd diff --git a/tools/ci_build/github/windows/jar_packaging.ps1 b/tools/ci_build/github/windows/jar_packaging.ps1 deleted file mode 100644 index a132ba6b26e2a..0000000000000 --- a/tools/ci_build/github/windows/jar_packaging.ps1 +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -$ErrorActionPreference = "Stop" -Write-Output "Start" -dir -pushd onnxruntime-java-linux-x64 -Write-Output "Run 7z" -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.so -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-osx-x86_64 -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.dylib -Remove-Item -Path libcustom_op_library.dylib -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-linux-aarch64 -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-osx-arm64 -Remove-Item -Path libcustom_op_library.dylib -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd diff --git a/tools/ci_build/github/windows/jar_packaging.py b/tools/ci_build/github/windows/jar_packaging.py new file mode 100644 index 0000000000000..2354363610251 --- /dev/null +++ b/tools/ci_build/github/windows/jar_packaging.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Packages ONNX Runtime Java artifacts by combining native libraries from +various platform builds into final Java archive (JAR) files using 7z. +""" + +import argparse +import glob +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any + +# Add semver as a dependency +try: + import semver +except ImportError: + print("Error: The 'semver' package is not installed. Please add it to your requirements.txt.", file=sys.stderr) + sys.exit(1) + +# --- Helper Functions for Archiving --- + + +def find_7z_executable(): + """Finds the 7z executable, checking the system PATH and default installation locations.""" + # 1. Check if '7z' is in the PATH + seven_zip_exe = shutil.which("7z") + if seven_zip_exe: + return seven_zip_exe + + # 2. Check the default installation directory under Program Files + program_files = os.environ.get("ProgramFiles") # noqa: SIM112 + if program_files: + default_path = Path(program_files) / "7-Zip" / "7z.exe" + if default_path.is_file(): + return str(default_path) + + return None + + +SEVEN_ZIP_EXE = find_7z_executable() + + +def add_file_to_archive(archive_path: Path, file_to_add: Path, description: str): + """Appends a single file to a zip archive (JAR file) using 7z.""" + print(f" -> {description}...") + try: + if not SEVEN_ZIP_EXE: + raise FileNotFoundError + # Run 7z from the file's parent directory to ensure a clean archive path. + subprocess.run( + [SEVEN_ZIP_EXE, "a", str(archive_path), file_to_add.name], + check=True, + cwd=file_to_add.parent, + capture_output=True, + text=True, + ) + except FileNotFoundError: + print( + "Error: '7z' command not found. Please ensure 7-Zip is installed and in your PATH, or in the default location 'C:\\Program Files\\7-Zip'.", + file=sys.stderr, + ) + raise + except subprocess.CalledProcessError as e: + print(f"Error: 7z failed to archive '{file_to_add.name}' to '{archive_path.name}'.", file=sys.stderr) + print(f"Reason: {e.stderr}", file=sys.stderr) + raise + + +def archive_directory_contents(archive_path: Path, source_dir: Path, description: str): + """Archives a directory into a zip file (JAR file) using 7z, preserving its top-level name.""" + print(f" -> {description}...") + try: + if not SEVEN_ZIP_EXE: + raise FileNotFoundError + # Run 7z from the parent of the source directory to ensure the source directory + # itself is added to the archive, preserving the path structure (e.g., 'ai/...'). + subprocess.run( + [SEVEN_ZIP_EXE, "a", str(archive_path), source_dir.name], + check=True, + cwd=source_dir.parent, + capture_output=True, + text=True, + ) + except FileNotFoundError: + print( + "Error: '7z' command not found. Please ensure 7-Zip is installed and in your PATH, or in the default location 'C:\\Program Files\\7-Zip'.", + file=sys.stderr, + ) + raise + except subprocess.CalledProcessError as e: + print(f"Error: 7z failed to archive directory '{source_dir.name}' to '{archive_path.name}'.", file=sys.stderr) + print(f"Reason: {e.stderr}", file=sys.stderr) + raise + + +# --- Validation Helpers --- + + +def validate_version(version_string: str): + """Validates if the version string conforms to the project's format.""" + print(f"Validating version string: {version_string}...") + try: + version_info = semver.Version.parse(version_string) + if version_info.prerelease: + prerelease_tag = version_info.prerelease + allowed_tags_pattern = r"^(alpha|beta|rc)\d+$" + if not re.match(allowed_tags_pattern, str(prerelease_tag)): + raise ValueError(f"Pre-release tag '{prerelease_tag}' is not an allowed type.") + except ValueError as e: + print(f"Error: Version '{version_string}' is not valid. Reason: {e}", file=sys.stderr) + print("Expected format is 'X.Y.Z' or 'X.Y.Z-(alpha|beta|rc)N'.", file=sys.stderr) + sys.exit(1) + print("Version format is valid.") + + +def validate_companion_jars(base_jar_path: Path): + """Ensures that -sources.jar and -javadoc.jar files exist.""" + print("Validating presence of companion -sources.jar and -javadoc.jar...") + base_stem = base_jar_path.stem + directory = base_jar_path.parent + sources_jar_path = directory / f"{base_stem}-sources.jar" + + if not sources_jar_path.is_file(): + print(f"Error: Missing companion sources JAR. Expected: {sources_jar_path.name}", file=sys.stderr) + sys.exit(1) + + if not list(directory.glob(f"{base_stem}-javadoc*.jar")): + print(f"Error: Missing companion javadoc JAR. Expected file like: {base_stem}-javadoc.jar", file=sys.stderr) + sys.exit(1) + print("Companion JARs are present.") + + +# --- Core Logic Function --- + + +def process_platform_archive( + platform_path: Path, + main_archive_file: Path, + test_archive_file: Path, + custom_lib_file: str, + archive_custom_lib: bool, +): + """Processes a single platform directory, adding only the 'ai' subdirectory to the main JAR.""" + print(f"Processing platform: {platform_path}...") + + # 1. Handle the custom op library. + custom_lib_full_path = platform_path / custom_lib_file + if custom_lib_file and custom_lib_full_path.is_file(): + if archive_custom_lib: + add_file_to_archive(test_archive_file, custom_lib_full_path, f"Archiving '{custom_lib_file}' to test JAR") + # Always remove the lib after processing to prevent it from being in the main JAR. + print(f" -> Removing '{custom_lib_file}' from source directory...") + custom_lib_full_path.unlink() + elif archive_custom_lib: + # If we expected to archive the file but it wasn't there, it's a fatal error. + print(f"Error: Expected custom op library '{custom_lib_file}' not found in {platform_path}", file=sys.stderr) + sys.exit(1) + + # 2. Archive only the native library directory ('ai/...') to the main JAR. + # This explicitly excludes other files or folders like '_manifest'. + native_lib_root = platform_path / "ai" + if native_lib_root.is_dir(): + archive_directory_contents( + main_archive_file, native_lib_root, f"Archiving native libs from '{native_lib_root.name}' to main JAR" + ) + else: + print(f"Warning: Native library path 'ai/' not found in {platform_path}. Skipping main archive step.") + + print(f"Finished platform: {platform_path}") + print("--------------------------------") + + +def run_packaging(package_type: str, build_dir: str): + """The main logic for the packaging process, refactored to be callable.""" + artifacts_base_dir = Path(build_dir) / "java-artifact" + primary_package_dir = artifacts_base_dir / "onnxruntime-java-win-x64" + if not primary_package_dir.is_dir(): + print(f"Error: Primary package directory not found at '{primary_package_dir}'", file=sys.stderr) + sys.exit(1) + + # --- Version Discovery --- + print(f"Discovering version from JAR files in '{primary_package_dir}'...") + jar_pattern = str(primary_package_dir / "onnxruntime*-*.jar") + jar_files = [Path(f) for f in glob.glob(jar_pattern) if "-sources" not in f and "-javadoc" not in f] + if not jar_files: + print( + f"Error: Could not find a main JAR file in '{primary_package_dir}' to determine the version.", + file=sys.stderr, + ) + sys.exit(1) + + main_jar_file = jar_files[0] + validate_companion_jars(main_jar_file) + + version = "" + stem = main_jar_file.stem + try: + # Per user feedback, the version is everything after the first dash. + _, version = stem.split("-", 1) + except ValueError: + # This will happen if there is no dash in the filename, which is unexpected. + print( + f"Error: Could not parse version from JAR file '{main_jar_file.name}'. Expected format -.jar", + file=sys.stderr, + ) + sys.exit(1) + + if not version: + print( + f"Error: Could not parse version from JAR file '{main_jar_file.name}'. Version part is empty.", + file=sys.stderr, + ) + sys.exit(1) + + print(f"Version discovered: {version}") + validate_version(version) + + # --- Package Definitions --- + package_definitions: dict[str, dict[str, Any]] = { + "cpu": { + "platforms": [ + {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": True}, + {"path": "onnxruntime-java-osx-x86_64", "lib": "libcustom_op_library.dylib", "archive_lib": True}, + {"path": "onnxruntime-java-linux-aarch64", "lib": "libcustom_op_library.so", "archive_lib": False}, + {"path": "onnxruntime-java-osx-arm64", "lib": "libcustom_op_library.dylib", "archive_lib": False}, + ] + }, + "gpu": { + "platforms": [ + {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": False} + ] + }, + } + + # --- Processing Loop --- + print(f"\n## Configuring for {package_type.upper()} package build...") + + final_main_archive = main_jar_file + final_test_archive = primary_package_dir / "testing.jar" + + print(f"Using '{final_main_archive.name}' as the base for in-place packaging.") + + if not final_test_archive.is_file(): + print(f"Error: Base 'testing.jar' not found at '{final_test_archive}'.", file=sys.stderr) + sys.exit(1) + + platforms_to_process = package_definitions[package_type]["platforms"] + + for platform in platforms_to_process: + platform_full_path = artifacts_base_dir / platform["path"] + if not platform_full_path.is_dir(): + print(f"Error: Required platform artifact directory not found: {platform_full_path}", file=sys.stderr) + sys.exit(1) + + process_platform_archive( + platform_path=platform_full_path, + main_archive_file=final_main_archive, + test_archive_file=final_test_archive, + custom_lib_file=platform["lib"], + archive_custom_lib=platform["archive_lib"], + ) + + print("\nScript completed successfully.") + + +def main(): + """Main script entry point for command-line execution.""" + if sys.platform != "win32": + print("Error: This script is intended to be run on Windows.", file=sys.stderr) + sys.exit(1) + + parser = argparse.ArgumentParser(description="Package ONNX Runtime Java artifacts.") + parser.add_argument( + "--package_type", + type=str, + choices=["cpu", "gpu"], + default="cpu", + help="The type of package to build ('cpu' or 'gpu').", + ) + parser.add_argument( + "--build_dir", + type=str, + help="The build directory containing the java-artifact folder.", + ) + args = parser.parse_args() + + build_dir = args.build_dir + if not build_dir: + try: + build_dir = os.environ["BUILD_BINARIESDIRECTORY"] + except KeyError: + print( + "Error: Environment variable BUILD_BINARIESDIRECTORY is not set and --build_dir is not provided.", + file=sys.stderr, + ) + sys.exit(1) + + run_packaging(args.package_type, build_dir) + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"\nAn unhandled error occurred: {e}", file=sys.stderr) + sys.exit(1) diff --git a/tools/ci_build/github/windows/jar_packaging_test.py b/tools/ci_build/github/windows/jar_packaging_test.py new file mode 100644 index 0000000000000..91b68728dad15 --- /dev/null +++ b/tools/ci_build/github/windows/jar_packaging_test.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import zipfile +from pathlib import Path + +import jar_packaging # The refactored script +import pytest + + +# Helper to create an empty file +def create_empty_file(path): + Path(path).touch() + + +# Helper to create a dummy JAR file +def create_dummy_jar(path): + with zipfile.ZipFile(path, "w") as zf: + zf.writestr("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + + +@pytest.fixture +def directory_setup_factory(tmp_path): + """ + A factory fixture that returns a function to set up a test directory + for a given package type and version. + """ + + def _setup_test_directory(package_type: str, version_string: str): + """Sets up a temporary directory structure mimicking the build artifacts.""" + java_artifact_dir = tmp_path / "java-artifact" + win_dir = java_artifact_dir / "onnxruntime-java-win-x64" + linux_dir = java_artifact_dir / "onnxruntime-java-linux-x64" + osx_dir = java_artifact_dir / "onnxruntime-java-osx-x86_64" + + # --- Main artifact directory (Windows) --- + win_dir.mkdir(parents=True, exist_ok=True) + artifact_name = f"onnxruntime_{package_type}" if package_type == "gpu" else "onnxruntime" + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}.jar") + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}-sources.jar") + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}-javadoc.jar") + create_empty_file(win_dir / f"{artifact_name}-{version_string}.pom") + create_dummy_jar(win_dir / "testing.jar") + (win_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # --- Linux platform --- + linux_native_dir = linux_dir / "ai" / "onnxruntime" / "native" / "linux-x64" + linux_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(linux_dir / "libcustom_op_library.so") + create_empty_file(linux_native_dir / "libonnxruntime.so") + create_empty_file(linux_native_dir / "libonnxruntime4j_jni.so") + if package_type == "gpu": + create_empty_file(linux_native_dir / "libonnxruntime_providers_cuda.so") + (linux_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # --- macOS and other platforms (for CPU test) --- + if package_type == "cpu": + osx_native_dir = osx_dir / "ai" / "onnxruntime" / "native" / "osx-x86_64" + osx_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(osx_dir / "libcustom_op_library.dylib") + create_empty_file(osx_native_dir / "libonnxruntime.dylib") + create_empty_file(osx_native_dir / "libonnxruntime4j_jni.dylib") + (osx_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # Add linux-aarch64 and osx-arm64 for CPU test + linux_aarch64_dir = java_artifact_dir / "onnxruntime-java-linux-aarch64" + linux_aarch64_native_dir = linux_aarch64_dir / "ai" / "onnxruntime" / "native" / "linux-aarch64" + linux_aarch64_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(linux_aarch64_dir / "libcustom_op_library.so") + + osx_arm64_dir = java_artifact_dir / "onnxruntime-java-osx-arm64" + osx_arm64_native_dir = osx_arm64_dir / "ai" / "onnxruntime" / "native" / "osx-arm64" + osx_arm64_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(osx_arm64_dir / "libcustom_op_library.dylib") + + return tmp_path + + return _setup_test_directory + + +@pytest.mark.parametrize("version_string", ["1.23.0", "1.23.0-rc1"]) +def test_gpu_packaging(directory_setup_factory, version_string): + """ + Tests the GPU packaging logic for both release and pre-release versions + to ensure correct files are added to the JARs. + """ + temp_build_dir = directory_setup_factory("gpu", version_string) + + # Run the packaging script logic + jar_packaging.run_packaging("gpu", str(temp_build_dir)) + + # --- Verification --- + win_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-win-x64" + main_jar_path = win_dir / f"onnxruntime_gpu-{version_string}.jar" + testing_jar_path = win_dir / "testing.jar" + + # 1. Verify the main JAR contains the Linux native libraries + with zipfile.ZipFile(main_jar_path, "r") as zf: + jar_contents = zf.namelist() + assert "ai/onnxruntime/native/linux-x64/libonnxruntime.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime4j_jni.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime_providers_cuda.so" in jar_contents + + # 2. Verify the testing JAR does not contain the custom op library for GPU builds + with zipfile.ZipFile(testing_jar_path, "r") as zf: + jar_contents = zf.namelist() + # The custom op lib for linux is not archived for GPU builds. + # This checks that it's NOT in the test jar. + assert "libcustom_op_library.so" not in jar_contents + + # 3. Verify the custom op library was removed from the source linux directory + linux_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-x64" + assert not (linux_dir / "libcustom_op_library.so").exists() + + +@pytest.mark.parametrize("version_string", ["1.23.0", "1.23.0-rc1"]) +def test_cpu_packaging(directory_setup_factory, version_string): + """ + Tests the CPU packaging logic to ensure correct files are added to the JARs. + """ + temp_build_dir = directory_setup_factory("cpu", version_string) + + # Run the packaging script logic + jar_packaging.run_packaging("cpu", str(temp_build_dir)) + + # --- Verification --- + win_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-win-x64" + main_jar_path = win_dir / f"onnxruntime-{version_string}.jar" + testing_jar_path = win_dir / "testing.jar" + + # 1. Verify the main JAR contains native libraries from all relevant platforms + with zipfile.ZipFile(main_jar_path, "r") as zf: + jar_contents = zf.namelist() + # Linux libs + assert "ai/onnxruntime/native/linux-x64/libonnxruntime.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime4j_jni.so" in jar_contents + # macOS libs + assert "ai/onnxruntime/native/osx-x86_64/libonnxruntime.dylib" in jar_contents + assert "ai/onnxruntime/native/osx-x86_64/libonnxruntime4j_jni.dylib" in jar_contents + # GPU libs should NOT be present + assert "ai/onnxruntime/native/linux-x64/libonnxruntime_providers_cuda.so" not in jar_contents + + # 2. Verify the testing JAR contains the custom op libraries that should be archived + with zipfile.ZipFile(testing_jar_path, "r") as zf: + jar_contents = zf.namelist() + assert "libcustom_op_library.so" in jar_contents + assert "libcustom_op_library.dylib" in jar_contents + + # 3. Verify the custom op libraries were removed from the source directories + linux_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-x64" + osx_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-osx-x86_64" + linux_aarch64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-aarch64" + osx_arm64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-osx-arm64" + assert not (linux_dir / "libcustom_op_library.so").exists() + assert not (osx_dir / "libcustom_op_library.dylib").exists() + assert not (linux_aarch64_dir / "libcustom_op_library.so").exists() + assert not (osx_arm64_dir / "libcustom_op_library.dylib").exists() diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index b36f6045a5962..91c3a88aca464 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -11,3 +11,6 @@ psutil onnxscript==0.3.2 jinja2 markupsafe +semver +packaging +coloredlogs diff --git a/tools/ci_build/github/windows/setup_nodejs.ps1 b/tools/ci_build/github/windows/setup_nodejs.ps1 new file mode 100644 index 0000000000000..478bb35f010f8 --- /dev/null +++ b/tools/ci_build/github/windows/setup_nodejs.ps1 @@ -0,0 +1,59 @@ +[CmdletBinding()] +param ( + # The major version of Node.js to use. Example: '20' + [Parameter(Mandatory = $true)] + [string]$MajorVersion +) + +try { + # Get the processor architecture ID using CIM + # 9 = x64, 12 = arm64 + $architectureId = (Get-CimInstance -ClassName Win32_Processor).Architecture + + # Map the architecture ID to the string used in the tool path + $archString = switch ($architectureId) { + 9 { "x64" } + 12 { "arm64" } + default { throw "Unsupported CPU architecture: $architectureId. This script only supports x64 and arm64." } + } + + Write-Host "Detected Architecture: $archString" + + # --- New Logic to find the latest version --- + $nodeVersionsPath = Join-Path $env:AGENT_TOOLSDIRECTORY "node" + if (-not (Test-Path -Path $nodeVersionsPath)) { + throw "Node.js tool directory not found at '$nodeVersionsPath'." + } + + # Find all directory names matching the major version (e.g., "20.*") + $matchingVersions = Get-ChildItem -Path $nodeVersionsPath | + Where-Object { $_.PSIsContainer -and $_.Name -like "$MajorVersion.*" } | + Select-Object -ExpandProperty Name + + if ($null -eq $matchingVersions) { + throw "No installed Node.js versions found for major version '$MajorVersion' at '$nodeVersionsPath'." + } + + # Sort the versions to find the highest one and select it + $latestVersion = $matchingVersions | Sort-Object -Descending {[version]$_} | Select-Object -First 1 + Write-Host "Found latest matching version: $latestVersion" + # --- End of New Logic --- + + # Construct the full path using the discovered latest version + $nodeToolPath = Join-Path $nodeVersionsPath "$latestVersion\$archString" + + # Verify that the final directory exists + if (-not (Test-Path -Path $nodeToolPath -PathType Container)) { + throw "Node.js tool path not found. Please ensure version '$latestVersion' for '$archString' exists at: $nodeToolPath" + } + + # Use the Azure DevOps logging command to prepend the directory to the PATH + Write-Host "##vso[task.prependpath]$nodeToolPath" + Write-Host "Successfully added Node.js $latestVersion ($archString) to the PATH." + +} +catch { + # If any error occurs, log it as an error in the pipeline and fail the task + Write-Host "##vso[task.logissue type=error]$($_.Exception.Message)" + exit 1 +} \ No newline at end of file diff --git a/tools/ci_build/github/windows/sign_java_artifacts.py b/tools/ci_build/github/windows/sign_java_artifacts.py new file mode 100644 index 0000000000000..19d1a4af98799 --- /dev/null +++ b/tools/ci_build/github/windows/sign_java_artifacts.py @@ -0,0 +1,139 @@ +import argparse +import hashlib +import os +import platform +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + + +def get_gpg_path() -> Path: + """Finds the path to the GPG executable.""" + if platform.system() == "Windows": + program_files_x86 = os.environ.get("ProgramFiles(x86)") # noqa: SIM112 + if not program_files_x86: + raise OSError("ProgramFiles(x86) environment variable not found.") + return Path(program_files_x86) / "gnupg/bin/gpg.exe" + + gpg_path_str = shutil.which("gpg") + if gpg_path_str is None: + raise FileNotFoundError("gpg executable not found in system PATH.") + return Path(gpg_path_str) + + +def run_command(command: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Executes a command and raises an exception if it fails.""" + print(f"Running command: {' '.join(command)}") + result = subprocess.run(command, capture_output=True, text=True, check=False) + if check and result.returncode != 0: + print(f"Command failed with exit code {result.returncode}") + print(f"Stdout:\n{result.stdout}") + print(f"Stderr:\n{result.stderr}") + raise subprocess.CalledProcessError(result.returncode, command, result.stdout, result.stderr) + return result + + +def create_hash_file(file_path: Path, algorithm: str) -> None: + """Creates a checksum file for the given file using the specified algorithm.""" + print(f" - Generating {algorithm.upper()} checksum...") + try: + hasher = hashlib.new(algorithm) + with file_path.open("rb") as f: + # Read in chunks to handle large files efficiently + while chunk := f.read(8192): + hasher.update(chunk) + + hash_value = hasher.hexdigest() + # Create checksum file in 'sha1sum'/'md5sum' format. + # The '*' indicates to read the file in binary mode for verification tools. + Path(f"{file_path}.{algorithm}").write_text(hash_value.lower(), encoding="utf-8") + except Exception as e: + print(f"Error generating {algorithm} hash for {file_path}: {e}") + raise + + +def main() -> None: + """ + Signs files with GPG and generates checksums. + """ + parser = argparse.ArgumentParser(description="Signs files with GPG and generates checksums.") + parser.add_argument("jar_file_directory", help="The directory containing files to sign.") + args = parser.parse_args() + + jar_file_directory = Path(args.jar_file_directory) + if not jar_file_directory.is_dir(): + print(f"Error: Directory not found at '{jar_file_directory}'", file=sys.stderr) + sys.exit(1) + + print(f"\nListing files to be processed in '{jar_file_directory}':") + files_to_process = [p for p in jar_file_directory.rglob("*") if p.is_file()] + for file_path in files_to_process: + print(f" - {file_path}") + print(f"Found {len(files_to_process)} files.") + + print("\nGetting GnuPG signing keys from environment variables.") + gpg_passphrase = os.environ.get("JAVA_PGP_PWD") + gpg_private_key = os.environ.get("JAVA_PGP_KEY") + + if not gpg_passphrase or not gpg_private_key: + print( + "Error: GPG passphrase or private key not found in environment variables ('JAVA_PGP_PWD', 'JAVA_PGP_KEY').", + file=sys.stderr, + ) + sys.exit(1) + + gpg_exe_path = get_gpg_path() + if not gpg_exe_path.is_file(): + print(f"Error: GPG executable not found at '{gpg_exe_path}'.", file=sys.stderr) + sys.exit(1) + + agent_temp_dir = os.environ.get("AGENT_TEMPDIRECTORY") + + # Use a single temporary directory to manage all temporary files + with tempfile.TemporaryDirectory(dir=agent_temp_dir) as temp_dir: + temp_dir_path = Path(temp_dir) + print(f"Created temporary directory: {temp_dir_path}") + + private_key_file = temp_dir_path / "private.key" + passphrase_file = temp_dir_path / "passphrase.txt" + + print("Writing GnuPG key and passphrase to temporary files.") + private_key_file.write_text(gpg_private_key, encoding="utf-8") + passphrase_file.write_text(gpg_passphrase, encoding="utf-8") + + print("Importing GnuPG private key.") + run_command([str(gpg_exe_path), "--batch", "--import", str(private_key_file)]) + print("Successfully imported GnuPG private key.") + + print(f"\nProcessing {len(files_to_process)} files in '{jar_file_directory}'.") + + for file_path in files_to_process: + print(f"Processing file: {file_path}") + + # GPG Signing (.asc) + print(" - GnuPG signing...") + run_command( + [ + str(gpg_exe_path), + "--pinentry-mode", + "loopback", + "--passphrase-file", + str(passphrase_file), + "--detach-sign", + "--armor", + str(file_path), + ] + ) + + # SHA-1 and MD5 Checksums + create_hash_file(file_path, "sha1") + create_hash_file(file_path, "md5") + + print("\nFile signing and checksum generation completed.") + print("Temporary directory and its contents have been deleted.") + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/java/pom.xml b/tools/ci_build/java/pom.xml new file mode 100644 index 0000000000000..d0856db7cee70 --- /dev/null +++ b/tools/ci_build/java/pom.xml @@ -0,0 +1,32 @@ + + 4.0.0 + + com.example + dependency-downloader + 1.0.0 + pom + + + UTF-8 + 3.25.5 + 1.10.2 + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + org.junit.platform + junit-platform-console-standalone + ${junit.platform.version} + test + + + + \ No newline at end of file diff --git a/tools/ci_build/linux_java_copy_strip_binary.py b/tools/ci_build/linux_java_copy_strip_binary.py new file mode 100644 index 0000000000000..b9ca856d1c514 --- /dev/null +++ b/tools/ci_build/linux_java_copy_strip_binary.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Prepares native shared libraries for the ONNX Runtime Java package. + +This script is a build utility that run as part of a packaging pipeline and takes compiled C/C++ shared libraries +(.so, .dylib) and stages them for packaging into a Java JAR file. + +It expected the following inputs: +/ +└── / + ├── libonnxruntime.so (File from --lib-name) + ├── libonnxruntime4j_jni.so (File from --native-lib-name) + ├── libcustom_op_library.so + │ + ├── (Optional) libonnxruntime_providers_shared.so + ├── (Optional) libonnxruntime_providers_cuda.so + └── (Optional) libonnxruntime_providers_tensorrt.so + +It performs the following key operations: + +1. Validates the existence of all required source directories and libraries. +2. Creates the specific Java Native Interface (JNI) directory structure + (ai/onnxruntime/native/). +3. Copies the main, JNI, and custom op libraries to their destinations. +4. For macOS, extracts debug symbols into .dSYM files using `dsymutil`. +5. Strips all release binaries of their debug symbols to reduce file size. +6. Copies optional provider libraries (e.g., CUDA, TensorRT) for Linux builds. + +It is intended to be called from a CI/CD pipeline as part of the overall +build process for the onnxruntime-java package. +""" + +import argparse +import logging +import platform +import shutil +import subprocess +import sys +from pathlib import Path + +# --- Configuration --- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) + + +# --- Helper Functions --- +def run_command(command: list[str | Path]): + """Runs an external command and exits the script if the command fails.""" + str_command = " ".join(map(str, command)) + logging.info(f"Running command: '{str_command}'") + try: + proc = subprocess.run(command, check=True, text=True, capture_output=True) + logging.info(f"Successfully executed: {Path(command[0]).name}") + if proc.stdout: + logging.debug(f"STDOUT: {proc.stdout.strip()}") + except FileNotFoundError: + logging.error(f"Command not found: '{command[0]}'. Please ensure it is installed and in your PATH.") + raise + except subprocess.CalledProcessError as e: + logging.error(f"Command '{Path(e.cmd[0]).name}' failed with exit code {e.returncode}.") + if e.stdout: + logging.error(f"STDOUT: {e.stdout.strip()}") + if e.stderr: + logging.error(f"STDERR: {e.stderr.strip()}") + raise + + +# --- Main Execution --- +def main(): + """Main function to parse arguments and package the native libraries.""" + parser = argparse.ArgumentParser( + description="Packages ONNX Runtime native libraries for Java.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Arguments + parser.add_argument("--binary-dir", required=True, type=Path, help="Path to the build binaries directory.") + parser.add_argument("--artifact-name", required=True, help="Name of the final artifact directory.") + parser.add_argument("--build-config", required=True, help="CMake build configuration (e.g., Release).") + parser.add_argument("--lib-name", required=True, help="Filename of the main ONNX Runtime shared library.") + parser.add_argument("--native-lib-name", required=True, help="Filename of the JNI shared library.") + parser.add_argument("--arch", required=True, help="Architecture string (e.g., osx-x86_64).") + args = parser.parse_args() + + # --- Path Setup and Validation --- + logging.info(f"System Info: {' '.join(platform.uname())}") + + source_build_dir = args.binary_dir / args.build_config + target_artifact_dir = args.binary_dir / args.artifact_name + + # Validate that the source build directory exists. + if not source_build_dir.is_dir(): + logging.error(f"Source build directory not found: {source_build_dir}") + sys.exit(1) + + # Map architecture names for macOS to align with Java conventions + arch = args.arch + if args.lib_name.endswith(".dylib"): + if arch == "osx-x86_64": + arch = "osx-x64" + elif arch == "osx-arm64": + arch = "osx-aarch64" + + # --- Library Processing --- + native_folder = target_artifact_dir / "ai" / "onnxruntime" / "native" / arch + native_folder.mkdir(parents=True, exist_ok=True) + logging.info(f"Staging native libraries in: {native_folder}") + + # Validate that all required library files exist before processing. + main_lib_src = source_build_dir / args.lib_name + jni_lib_src = source_build_dir / args.native_lib_name + + required_files = [main_lib_src, jni_lib_src] + lib_suffix = ".dylib" if args.lib_name.endswith(".dylib") else ".so" + custom_op_lib_src = source_build_dir / f"libcustom_op_library{lib_suffix}" + required_files.append(custom_op_lib_src) + + for f in required_files: + if not f.is_file(): + logging.error(f"Required library file not found: {f}") + sys.exit(1) + logging.info("All required source library files found.") + + # Start processing now that checks have passed + if lib_suffix == ".dylib": # macOS + logging.info("Processing macOS libraries (.dylib)...") + run_command(["dsymutil", main_lib_src, "-o", native_folder / f"{args.lib_name}.dSYM"]) + shutil.copy2(main_lib_src, native_folder / "libonnxruntime.dylib") + run_command(["strip", "-S", native_folder / "libonnxruntime.dylib"]) + + run_command(["dsymutil", jni_lib_src, "-o", native_folder / f"{args.native_lib_name}.dSYM"]) + shutil.copy2(jni_lib_src, native_folder / "libonnxruntime4j_jni.dylib") + run_command(["strip", "-S", native_folder / "libonnxruntime4j_jni.dylib"]) + + shutil.copy2(custom_op_lib_src, target_artifact_dir) + + elif lib_suffix == ".so": # Linux + logging.info("Processing Linux libraries (.so)...") + + # Main library + main_lib_dest = native_folder / "libonnxruntime.so" + shutil.copy2(main_lib_src, main_lib_dest) + run_command(["strip", "-S", main_lib_dest]) + + # JNI library + jni_lib_dest = native_folder / "libonnxruntime4j_jni.so" + shutil.copy2(jni_lib_src, jni_lib_dest) + run_command(["strip", "-S", jni_lib_dest]) + + # Custom op library (not stripped as it's for testing) + shutil.copy2(custom_op_lib_src, target_artifact_dir) + + # Provider checks are optional, so we check for their existence here. + for provider in ["cuda", "tensorrt"]: + provider_lib_src = source_build_dir / f"libonnxruntime_providers_{provider}.so" + if provider_lib_src.exists(): + logging.info(f"Found optional {provider} provider library. Copying and stripping...") + + # Shared provider library + shared_provider_lib_src = source_build_dir / "libonnxruntime_providers_shared.so" + if shared_provider_lib_src.exists(): + shared_provider_dest = native_folder / shared_provider_lib_src.name + shutil.copy2(shared_provider_lib_src, shared_provider_dest) + run_command(["strip", "-S", shared_provider_dest]) + + # Specific provider library + provider_lib_dest = native_folder / provider_lib_src.name + shutil.copy2(provider_lib_src, provider_lib_dest) + run_command(["strip", "-S", provider_lib_dest]) + else: + logging.warning(f"Unsupported library type for '{args.lib_name}'. No special processing will occur.") + + # --- Finalization --- + logging.info(f"--- Final contents of '{target_artifact_dir}' ---") + for path in sorted(target_artifact_dir.rglob("*")): + logging.info(f" - {path.relative_to(target_artifact_dir)}") + logging.info("--- End of contents ---") + + jar_dir_to_remove = target_artifact_dir / "jar" + if jar_dir_to_remove.is_dir(): + logging.info(f"Removing temporary directory: {jar_dir_to_remove}") + shutil.rmtree(jar_dir_to_remove) + + logging.info("Script completed successfully.") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + logging.error(f"Script failed due to an unhandled error: {e}") + sys.exit(1) diff --git a/tools/ci_build/manage_java_artifacts.py b/tools/ci_build/manage_java_artifacts.py new file mode 100644 index 0000000000000..51521f651adec --- /dev/null +++ b/tools/ci_build/manage_java_artifacts.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# This script runs after ORT jars are built. It picks up the jars from ORT's build dir then repack them a bit. + +import argparse +import logging +import re +import shutil +import subprocess +import sys +import zipfile +from pathlib import Path + +# --- Configuration --- +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + +# --- Helper Functions --- +def run_command(command: list, working_dir: Path): + """Runs a command in a specified directory and checks for errors.""" + logging.info(f"Running command: '{' '.join(map(str, command))}' in '{working_dir}'") + try: + # On Windows, shell=True is required to correctly locate and execute .bat or .cmd files + # like gradlew.bat and mvn.cmd that may be in the system's PATH. + use_shell = sys.platform == "win32" + subprocess.run(command, cwd=working_dir, check=True, shell=use_shell) + logging.info("Command successful.") + except subprocess.CalledProcessError as e: + # Output will have been streamed, so we just need to log the failure. + logging.error(f"Command failed with exit code {e.returncode}") + raise + except FileNotFoundError: + logging.error( + f"Command failed: The executable '{command[0]}' was not found. " + "Please ensure it is installed and that its location is in the system's PATH environment variable." + ) + raise + + +def log_directory_contents(dir_path: Path, description: str): + """Logs the contents of a directory for debugging.""" + logging.info(f"--- Listing contents of {description} at '{dir_path}' ---") + if not dir_path.is_dir(): + logging.warning(f"Directory does not exist: {dir_path}") + return + contents = list(dir_path.rglob("*")) + if not contents: + logging.warning(f"Directory is empty: {dir_path}") + else: + for item in contents: + logging.info(f" - {item.relative_to(dir_path)}") + logging.info("--- End of directory listing ---") + + +def create_zip_from_directory(zip_file_path: Path, source_dir: Path): + """Creates a zip file from the contents of a source directory.""" + logging.info(f"Creating archive '{zip_file_path}' from directory '{source_dir}'...") + with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in source_dir.walk(): + for file in files: + file_path = root / file + archive_name = file_path.relative_to(source_dir) + zipf.write(file_path, archive_name) + logging.info("Archive created successfully.") + + +# --- New function for validation --- +def validate_artifacts( + platform_dir: Path, main_jar: Path, main_pom: Path, testing_jar: Path, version: str, artifact_id: str +): + """Uses Maven to validate the generated JAR and POM files.""" + logging.info("--- Starting Maven Artifact Validation ---") + maven_executable = "mvn.cmd" if sys.platform == "win32" else "mvn" + group_id = "com.microsoft.onnxruntime" # Assuming this is constant + + # 1. Validate the main ONNX Runtime JAR and its POM + logging.info(f"Validating main artifact: {main_jar.name}") + install_main_cmd = [ + maven_executable, + "install:install-file", + f"-Dfile={main_jar.resolve()}", + f"-DpomFile={main_pom.resolve()}", + # Adding these makes the command more robust and less prone to errors + f"-DgroupId={group_id}", + f"-DartifactId={artifact_id}", + f"-Dversion={version}", + "-Dpackaging=jar", + ] + run_command(install_main_cmd, working_dir=platform_dir) + logging.info("Main artifact validated successfully.") + + # 2. Validate the testing JAR (it has no POM, so we supply all info) + logging.info(f"Validating testing artifact: {testing_jar.name}") + install_testing_cmd = [ + maven_executable, + "install:install-file", + f"-Dfile={testing_jar.resolve()}", + f"-DgroupId={group_id}", + f"-DartifactId={artifact_id}-testing", + f"-Dversion={version}", + "-Dpackaging=jar", + ] + run_command(install_testing_cmd, working_dir=platform_dir) + logging.info("Testing artifact validated successfully.") + logging.info("--- Maven Artifact Validation Complete ---") + + +def main(): + """Main script execution.""" + parser = argparse.ArgumentParser(description="Builds and packages Java artifacts, PDBs, and notice files.") + parser.add_argument("--sources-dir", required=True, type=Path, help="Path to the build sources directory.") + parser.add_argument("--binaries-dir", required=True, type=Path, help="Path to the build binaries directory.") + parser.add_argument("--platform", required=True, help="Platform string (e.g., x64).") + parser.add_argument( + "--java-artifact-id", required=True, help="The Java artifact ID (e.g., onnxruntime or onnxruntime_gpu)." + ) + parser.add_argument( + "--build-config", + choices=["Debug", "Release", "RelWithDebInfo", "MinSizeRel"], + default="RelWithDebInfo", + help="The CMake build configuration type.", + ) + parser.add_argument( + "--pre-release-version-suffix-string", + choices=["alpha", "beta", "rc", "none"], + default="none", + help="The pre-release version suffix string.", + ) + parser.add_argument( + "--pre-release-version-suffix-number", type=int, default=0, help="The pre-release version suffix number." + ) + parser.add_argument("--commit-hash", required=True, help="The git commit hash.") + parser.add_argument("--build-only", action="store_true", help="Flag to indicate if this is a build-only run.") + args = parser.parse_args() + + # --- 1. Version and Build Logic --- + # Determine the repository root from the script's location + repo_root = Path(__file__).resolve().parent.parent.parent + version_file_path = repo_root / "VERSION_NUMBER" + + logging.info(f"Reading base version from {version_file_path}") + if not version_file_path.is_file(): + raise FileNotFoundError(f"Version file not found at {version_file_path}") + + base_version = version_file_path.read_text(encoding="utf-8").strip() + + # Validate the version format + if not re.match(r"^\d+\.\d+\.\d+$", base_version): + raise ValueError(f"Version '{base_version}' from {version_file_path} is not in the required x.y.z format.") + + logging.info(f"Successfully read and validated base version: {base_version}") + + # Start with the base version and conditionally append the pre-release suffix. + full_version = base_version + if args.pre_release_version_suffix_string != "none": + if args.pre_release_version_suffix_number <= 0: + raise ValueError( + "Pre-release version suffix number must be a positive integer if a suffix string is provided." + ) + # Append the suffix, conforming to Maven standards (e.g., 1.2.3-rc1) + full_version += f"-{args.pre_release_version_suffix_string}{args.pre_release_version_suffix_number}" + + logging.info(f"Using full version: {full_version}") + + # Use the java subdirectory of the repository root as the working directory for Gradle + java_working_dir = repo_root / "java" + + build_config_dir = args.binaries_dir / args.build_config + cmake_build_dir_arg = f"-DcmakeBuildDir={build_config_dir}" + version_property_arg = f"-Dorg.gradle.project.version={full_version}" + + # Construct the absolute path to the Gradle wrapper + gradle_executable_name = "gradlew.bat" if sys.platform == "win32" else "gradlew" + gradle_executable_path = java_working_dir / gradle_executable_name + + # Rebuild the jar so that we can change the version + gradle_args = [cmake_build_dir_arg, version_property_arg] + if args.java_artifact_id == "onnxruntime_gpu": + gradle_args.append("-DUSE_CUDA") + gradle_args.append("-DUSE_TENSORRT") + run_command([str(gradle_executable_path), "cmakeBuild", *gradle_args], working_dir=java_working_dir) + if args.build_only: + run_command( + [ + str(gradle_executable_path), + "testClasses", + "--warning-mode", + "all", + *gradle_args, + ], + working_dir=java_working_dir, + ) + else: + run_command( + [ + str(gradle_executable_path), + "cmakeCheck", + "--warning-mode", + "all", + *gradle_args, + ], + working_dir=java_working_dir, + ) + + # --- 2. Path Definitions --- + platform_dir = args.binaries_dir / f"onnxruntime-java-win-{args.platform}" + stage_dir = platform_dir / "stage" + native_folder = stage_dir / "ai" / "onnxruntime" / "native" / f"win-{args.platform}" + main_jar_name = f"{args.java_artifact_id}-{full_version}.jar" + main_jar_path = platform_dir / main_jar_name + final_pom_path = platform_dir / f"{args.java_artifact_id}-{full_version}.pom" + testing_jar_path = platform_dir / "testing.jar" + + # --- 3. Packaging Logic --- + try: + stage_dir.mkdir(parents=True, exist_ok=True) + native_folder.mkdir(parents=True, exist_ok=True) + + gradle_libs_dir = java_working_dir / "build" / "libs" + log_directory_contents(gradle_libs_dir, "Gradle build output libs") + + # FIX: Filter glob results to find the main artifact JAR, excluding sources and javadoc. + main_jars = [ + p + for p in gradle_libs_dir.glob("*.jar") + if not p.name.endswith("-sources.jar") and not p.name.endswith("-javadoc.jar") + ] + + if not main_jars: + raise FileNotFoundError(f"Gradle build finished, but no main artifact JAR was found in {gradle_libs_dir}") + if len(main_jars) > 1: + logging.warning(f"Found multiple potential main JARs: {[p.name for p in main_jars]}. Using the first one.") + + source_jar_path = main_jars[0] + logging.info(f"Found source JAR to copy: {source_jar_path.name}") + + # The main JAR file is copied to its final name directly. + shutil.copy2(source_jar_path, main_jar_path) + + # Now, find and copy the associated sources and javadoc JARs, renaming them to match. + source_basename = source_jar_path.stem # e.g., 'onnxruntime-1.23.0' + dest_basename = main_jar_path.stem # e.g., 'onnxruntime_gpu-1.23.0' + + for classifier in ["sources", "javadoc"]: + source_classified_jar = gradle_libs_dir / f"{source_basename}-{classifier}.jar" + if source_classified_jar.is_file(): + dest_classified_jar = platform_dir / f"{dest_basename}-{classifier}.jar" + logging.info(f"Copying classified artifact: {source_classified_jar.name} -> {dest_classified_jar.name}") + shutil.copy2(source_classified_jar, dest_classified_jar) + else: + logging.warning(f"Optional artifact '{source_classified_jar.name}' not found, skipping.") + + log_directory_contents(platform_dir, "final platform directory before JAR processing") + + pom_archive_path = f"META-INF/maven/com.microsoft.onnxruntime/{args.java_artifact_id}/pom.xml" + with zipfile.ZipFile(main_jar_path, "r") as jar: + jar.extract(pom_archive_path, path=platform_dir) + + shutil.move(str(platform_dir / pom_archive_path), str(final_pom_path)) + shutil.rmtree(platform_dir / "META-INF") + + shutil.copy2(args.sources_dir / "docs" / "Privacy.md", stage_dir) + shutil.copy2(args.sources_dir / "ThirdPartyNotices.txt", stage_dir) + (stage_dir / "GIT_COMMIT_ID").write_text(args.commit_hash, encoding="utf-8") + + with zipfile.ZipFile(main_jar_path, "a") as jar: + for root, _, files in stage_dir.walk(): + for file in files: + file_path = root / file + jar.write(file_path, file_path.relative_to(stage_dir)) + + test_classes_dir = args.sources_dir / "java" / "build" / "classes" / "java" / "test" + test_resources_dir = args.sources_dir / "java" / "build" / "resources" / "test" + + create_zip_from_directory(testing_jar_path, test_classes_dir) + + native_resource_path = test_resources_dir / "ai" / "onnxruntime" / "native" + if native_resource_path.exists(): + shutil.rmtree(native_resource_path) + + with zipfile.ZipFile(testing_jar_path, "a") as jar: + for root, _, files in test_resources_dir.walk(): + for file in files: + file_path = root / file + jar.write(file_path, file_path.relative_to(test_resources_dir)) + + logging.info("Java artifact packaging complete.") + + # --- 4. Validation Step --- + validate_artifacts( + platform_dir=platform_dir, + main_jar=main_jar_path, + main_pom=final_pom_path, + testing_jar=testing_jar_path, + version=full_version, + artifact_id=args.java_artifact_id, + ) + + finally: + # 5. Clean up stage directory + if stage_dir.exists(): + logging.info(f"Cleaning up stage directory: {stage_dir}") + shutil.rmtree(stage_dir) + + logging.info(f"\nFinal contents of '{platform_dir}':") + for item in platform_dir.iterdir(): + print(item) + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/prepare_macos_package.py b/tools/ci_build/prepare_macos_package.py new file mode 100644 index 0000000000000..b92e81663c776 --- /dev/null +++ b/tools/ci_build/prepare_macos_package.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import os +import pathlib +import shutil +import stat as stat_module +import subprocess +import sys +import tarfile +from datetime import datetime + + +def run_command(command: list[str | pathlib.Path], check: bool = True) -> subprocess.CompletedProcess: + """Helper to run a command, stream its output, and check for errors.""" + print(f"Executing: {' '.join(map(str, command))}", flush=True) + try: + return subprocess.run(command, check=check, text=True, capture_output=True) + except subprocess.CalledProcessError as e: + print(f"ERROR: Command failed with exit code {e.returncode}", file=sys.stderr) + print(f"--- STDOUT ---\n{e.stdout}", file=sys.stderr) + print(f"--- STDERR ---\n{e.stderr}", file=sys.stderr) + raise + + +def get_relative_file_paths(root_dir: pathlib.Path) -> set[pathlib.Path]: + """ + Returns a set of all relative file paths within a directory, + ignoring any files inside .dSYM directories. + """ + paths = set() + for p in root_dir.rglob("*"): + # Check if any part of the path is a .dSYM directory. + if any(part.endswith(".dSYM") for part in p.relative_to(root_dir).parts): + continue + if p.is_file(): + paths.add(p.relative_to(root_dir)) + return paths + + +def is_macho_binary(file_path: pathlib.Path) -> bool: + """Checks if a file is a Mach-O binary using the 'file' command.""" + if not file_path.is_file(): + return False + try: + result = run_command(["file", file_path]) + return "Mach-O" in result.stdout + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +def main(): + """Main function to prepare macOS packages for signing.""" + # 1. Setup paths and parse arguments + parser = argparse.ArgumentParser(description="Prepares macOS packages for signing.") + parser.add_argument( + "--staging_dir", + type=pathlib.Path, + required=True, + help="The directory where artifacts are staged and processed.", + ) + args = parser.parse_args() + staging_dir = args.staging_dir.resolve() + + if not staging_dir.is_dir(): + raise FileNotFoundError(f"Staging directory not found: {staging_dir}") + + os.chdir(staging_dir) + print(f"##[group]Working in directory: {staging_dir}") + print(f"Initial contents: {[p.name for p in staging_dir.iterdir()]}") + print("##[endgroup]") + + # 2. Unpack all .tgz archives + print("##[group]Unpacking downloaded archives...") + tgz_files = list(staging_dir.glob("*.tgz")) + if not tgz_files: + raise FileNotFoundError("Build Error: No .tgz files found to process.") + + for tgz in tgz_files: + print(f"Extracting {tgz.name}...") + with tarfile.open(tgz) as tar: + tar.extractall(path=".") + tgz.unlink() # Delete the archive + print("##[endgroup]") + + # 3. Locate architecture-specific directories + print("##[group]Locating architecture directories...") + arm64_dirs = list(staging_dir.glob("onnxruntime-osx-arm64*")) + x64_dirs = list(staging_dir.glob("onnxruntime-osx-x86_64*")) + + if len(arm64_dirs) != 1 or len(x64_dirs) != 1: + raise FileNotFoundError( + f"Build Error: Expected 1 arm64 and 1 x64 directory, but found: arm64={len(arm64_dirs)}, x64={len(x64_dirs)}" + ) + + arm64_dir, x64_dir = arm64_dirs[0], x64_dirs[0] + print(f"Found ARM64 source: {arm64_dir.name}") + print(f"Found x86_64 source: {x64_dir.name}") + print("##[endgroup]") + + # **NEW**: Remove _manifest directories before comparison or processing. + print("##[group]Removing _manifest directories...") + for package_dir in (arm64_dir, x64_dir): + manifest_path = package_dir / "_manifest" + if manifest_path.is_dir(): + print(f"Removing manifest directory: {manifest_path.relative_to(staging_dir)}") + shutil.rmtree(manifest_path) + print("##[endgroup]") + + # 4. Error Check: Verify file tree structures are identical + print("##[group]Verifying file tree structures...") + arm64_files = get_relative_file_paths(arm64_dir) + x64_files = get_relative_file_paths(x64_dir) + + if arm64_files != x64_files: + difference = arm64_files.symmetric_difference(x64_files) + print(f"ERROR: File tree structures do not match. Found {len(difference)} differing files:", file=sys.stderr) + for f in sorted(difference): + print(f"- {f}", file=sys.stderr) + sys.exit(1) + + print("✅ File tree structures match.") + print("##[endgroup]") + + # 5. Create the universal binary package + print("##[group]Creating universal2 package with lipo...") + universal_dir = staging_dir / arm64_dir.name.replace("arm64", "universal2") + + print(f"Copying {arm64_dir.name} to {universal_dir.name} as a template.") + shutil.copytree(arm64_dir, universal_dir, symlinks=True, ignore=shutil.ignore_patterns("*.dSYM")) + + for relative_path in arm64_files: + arm64_file = arm64_dir / relative_path + x64_file = x64_dir / relative_path + universal_file = universal_dir / relative_path + + if is_macho_binary(arm64_file) and is_macho_binary(x64_file): + print(f"Combining {relative_path}...") + run_command(["lipo", "-create", arm64_file, x64_file, "-output", universal_file]) + run_command(["lipo", "-info", universal_file]) + print("##[endgroup]") + + # Remove .dSYM folders from source packages before zipping. + print("##[group]Removing .dSYM folders from source packages...") + for package_dir in (arm64_dir, x64_dir): + for dsym_dir in package_dir.rglob("*.dSYM"): + if dsym_dir.is_dir(): + print(f"Removing {dsym_dir.relative_to(staging_dir)}") + shutil.rmtree(dsym_dir) + print("##[endgroup]") + + # 6. Zip all packages for signing and clean up + print("##[group]Zipping all packages for signing...") + for dir_path in (arm64_dir, x64_dir, universal_dir): + # Create a zip file in the staging directory. + zip_file_path = staging_dir / f"{dir_path.name}.zip" + print(f"Zipping {dir_path.name} to {zip_file_path}") + # The source directory path (dir_path.name) is relative to the current working directory (staging_dir). + run_command(["zip", "-FSr", "--symlinks", zip_file_path, dir_path.name]) + + print(f"Removing directory {dir_path.name}") + shutil.rmtree(dir_path) + + print("Final contents of staging directory:") + for item in sorted(staging_dir.iterdir()): + try: + stat = item.stat() + size = stat.st_size + mode_str = stat_module.filemode(stat.st_mode) + mtime = datetime.fromtimestamp(stat.st_mtime).strftime("%b %d %H:%M") + print(f"{mode_str} {size:>10} {mtime} {item.name}") + except FileNotFoundError: + # Handle cases where a file might be a broken symlink + print(f"l????????? {'?':>10} ? ? {item.name} (broken link)") + + print("##[endgroup]") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"##[error]A critical error occurred: {e}", file=sys.stderr) + sys.exit(1) diff --git a/tools/ci_build/run_gh_action.py b/tools/ci_build/run_gh_action.py new file mode 100644 index 0000000000000..3980324cdd7a1 --- /dev/null +++ b/tools/ci_build/run_gh_action.py @@ -0,0 +1,162 @@ +import os +import platform +import shutil +import sys +import tempfile +import zipfile +from pathlib import Path + +import requests + +SCRIPT_DIR = Path(__file__).resolve().parent +REPO_DIR = (SCRIPT_DIR / ".." / "..").resolve() + +sys.path.insert(0, str(REPO_DIR / "tools" / "python")) + +from util import run # noqa: E402 + +# Hash structure for platform-specific binaries +CMAKE_HASHES = { + "windows": { + "x64": "807b774fcb12defff8ce869e602fc5b6279d5b7bf7229ebcf3f7490da3f887d516b9c49a00d50f9179e552ed8737d19835a19ef8f366d1ffda1ad6f3352a90c2", + "arm64": "86937dc89deabe0ff2a08fe198fcfc70764476b865cca4c6dc3bfc7fb9f7d44d4929af919e26e84aaedef17ad01ffb9683e42c39cb38b409100f723bc5ef1cc0", + }, + "linux": { + "x64": "7939260931098c3f00d2b36de3bee6a0ee3bcae2dba001598c492ed5c82d295c9aa9969654f1ff937fec4d71679541238baaa648c5246f36e14f28f0a62337a0", + "arm64": "8eeb07e966a5340c122979dd2e371708a78adccc85200b22bc7e66028e65513bce5ced6c37fe65aedb94000d970186c5c7562d1ab3dbda911061de46b75345d9", + }, + "macos": "99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8", +} + + +def get_platform_keys() -> tuple[str | None, str | None]: + """Detects the OS and CPU architecture and returns normalized keys.""" + os_key: str | None = None + match sys.platform: + case "win32": + os_key = "windows" + case "linux": + os_key = "linux" + case "darwin": + os_key = "macos" + + arch_key: str | None = None + match platform.machine().lower(): + case "amd64" | "x86_64": + arch_key = "x64" + case "arm64" | "aarch64": + arch_key = "arm64" + + return os_key, arch_key + + +def main() -> None: + if len(sys.argv) < 2: + print("::error::Action version argument was not provided.") + sys.exit(1) + + action_version = sys.argv[1] + + # --- Platform Detection and Variable Setup --- + os_key, arch_key = get_platform_keys() + if not os_key or not arch_key: + print( + f"::error::Could not determine a supported platform from OS '{sys.platform}' and Arch '{platform.machine()}'." + ) + sys.exit(1) + + print(f"Detected Platform: OS='{os_key}', Architecture='{arch_key}'") + + try: + if os_key == "macos": + cmake_hash = CMAKE_HASHES[os_key] + else: + cmake_hash = CMAKE_HASHES[os_key][arch_key] + + print(f"Selected CMake hash for '{os_key}'.") + except KeyError: + print(f"::error::Unsupported platform or missing hash for OS='{os_key}' and Arch='{arch_key}'.") + sys.exit(1) + + # --- Conditionally set Terrapin and define action inputs --- + disable_terrapin_value = "true" + terrapin_tool_path_str = "C:\\local\\Terrapin\\TerrapinRetrievalTool.exe" + + action_inputs = { + "INPUT_CMAKE-VERSION": "3.31.8", + "INPUT_CMAKE-HASH": cmake_hash, + "INPUT_VCPKG-VERSION": "2025.06.13", + "INPUT_VCPKG-HASH": "735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc", + "INPUT_ADD-CMAKE-TO-PATH": "true", + } + + if os_key == "windows" and Path(terrapin_tool_path_str).exists(): + disable_terrapin_value = "false" + action_inputs["INPUT_TERRAPIN-TOOL-PATH"] = terrapin_tool_path_str + print("Terrapin tool found. Setting INPUT_DISABLE-TERRAPIN to 'false' and providing tool path.") + + action_inputs["INPUT_DISABLE-TERRAPIN"] = disable_terrapin_value + + # --- Download and Extract the Action to a Temporary Directory --- + zip_url = f"https://github.com/microsoft/onnxruntime-github-actions/archive/refs/tags/{action_version}.zip" + + # Use AGENT_TEMPDIRECTORY, with a fallback to the system's default temp directory. + temp_dir = Path(os.environ.get("AGENT_TEMPDIRECTORY", tempfile.gettempdir())).resolve() + zip_path = temp_dir / "action.zip" + extract_dir = temp_dir / "action-unzipped" + + print(f"Using temporary directory: {temp_dir}") + + # --- Locate, Run, and Cleanup the Action Script --- + try: + print(f"Downloading action source from: {zip_url}") + response = requests.get(zip_url, stream=True) + response.raise_for_status() + with open(zip_path, "wb") as f: + shutil.copyfileobj(response.raw, f) + + print(f"Extracting {zip_path} to {extract_dir}") + if extract_dir.exists(): + shutil.rmtree(extract_dir) + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_dir) + + try: + action_base_path = next(extract_dir.glob("onnxruntime-github-actions-*")) + print(f"Found action base path: {action_base_path}") + except StopIteration as e: + raise FileNotFoundError(f"Could not find extracted action directory in '{extract_dir}'") from e + + action_script_path = action_base_path / "setup-build-tools" / "dist" / "index.js" + if not action_script_path.exists(): + raise FileNotFoundError(f"Action script not found at expected path: {action_script_path}") + + env = os.environ.copy() + env.update(action_inputs) + + if "AGENT_TOOLSDIRECTORY" in env: + env["RUNNER_TOOL_CACHE"] = env["AGENT_TOOLSDIRECTORY"] + print(f"Mapped RUNNER_TOOL_CACHE to AGENT_TOOLSDIRECTORY: {env['RUNNER_TOOL_CACHE']}") + if "AGENT_TEMPDIRECTORY" in env: + env["RUNNER_TEMP"] = env["AGENT_TEMPDIRECTORY"] + print(f"Mapped RUNNER_TEMP to AGENT_TEMPDIRECTORY: {env['RUNNER_TEMP']}") + + run("node", str(action_script_path), env=env) + + finally: + # --- Cleanup --- + # This block ensures the zip file and extracted directory are always removed. + print("\nStarting cleanup...") + if zip_path.exists(): + print(f"Removing temporary zip file: {zip_path}") + zip_path.unlink() + + if extract_dir.exists(): + print(f"Removing extracted action directory: {extract_dir}") + shutil.rmtree(extract_dir) + + print("Cleanup complete.") + + +if __name__ == "__main__": + main() diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 211cb7a2a8a75..ead240a7cef1b 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -22,6 +22,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "linux": pkg_name += "-linux-" pkg_name += cpu_arch @@ -31,6 +33,8 @@ def get_package_name(os, cpu_arch, ep, is_training_package): pkg_name += "-tensorrt" elif ep == "rocm": pkg_name += "-rocm" + elif ep == "migraphx": + pkg_name += "-migraphx" elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -44,7 +48,11 @@ def get_package_name(os, cpu_arch, ep, is_training_package): def is_this_file_needed(ep, filename, package_name): if package_name == "Microsoft.ML.OnnxRuntime.Gpu": return False - return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) + return ( + (ep != "cuda" or "cuda" in filename) + and (ep != "tensorrt" or "cuda" not in filename) + and (ep != "migraphx" or "migraphx" not in filename) + ) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -138,7 +146,7 @@ def parse_arguments(): required=False, default="None", type=str, - choices=["cuda", "dnnl", "openvino", "tensorrt", "snpe", "qnn", "None"], + choices=["cuda", "dnnl", "openvino", "migraphx", "tensorrt", "snpe", "qnn", "None"], help="The selected execution provider for this build.", ) parser.add_argument("--sdk_info", required=False, default="", type=str, help="dependency SDK information.") @@ -182,6 +190,8 @@ def generate_description(line_list, package_name): description = "This package contains Linux native shared library artifacts for ONNX Runtime with CUDA." elif "Microsoft.ML.OnnxRuntime.Gpu.Windows" in package_name: description = "This package contains Windows native shared library artifacts for ONNX Runtime with CUDA." + elif "Microsoft.ML.OnnxRuntime.MIGraphX" in package_name: + description = "This package contains native shared library artifacts for ONNX Runtime with MIGraphX." elif "Intel.ML.OnnxRuntime" in package_name: description = "This package contains native shared library artifacts for ONNX Runtime with OpenVINO." elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package @@ -359,6 +369,7 @@ def generate_files(line_list, args): is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" is_snpe_package = args.package_name == "Microsoft.ML.OnnxRuntime.Snpe" is_qnn_package = args.package_name == "Microsoft.ML.OnnxRuntime.QNN" + is_migraphx_package = args.package_name == "Microsoft.ML.OnnxRuntime.MIGraphX" is_training_package = args.package_name in [ "Microsoft.ML.OnnxRuntime.Training", "Microsoft.ML.OnnxRuntime.Training.Gpu", @@ -384,6 +395,24 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", "qnn_ep_shared_lib": "onnxruntime_providers_qnn.dll", + "migraphx_ep_shared_lib": "onnxruntime_providers_migraphx.dll", + "amd_comgr0602": "amd_comgr0602.dll", + "amd_comgr0604": "amd_comgr0604.dll", + "amd_comgr0700": "amd_comgr0700.dll", + "hiprtc0602": "hiprtc0602.dll", + "hiprtc0604": "hiprtc0604.dll", + "hiprtc0700": "hiprtc0700.dll", + "hiprtc-builtins0602": "hiprtc-builtins0602.dll", + "hiprtc-builtins0604": "hiprtc-builtins0604.dll", + "hiprtc-builtins0700": "hiprtc-builtins0700.dll", + "migraphx-hiprtc-driver": "migraphx-hiprtc-driver.exe", + "migraphx": "migraphx.dll", + "migraphx_c": "migraphx_c.dll", + "migraphx_cpu": "migraphx_cpu.dll", + "migraphx_device": "migraphx_device.dll", + "migraphx_gpu": "migraphx_gpu.dll", + "migraphx_onnx": "migraphx_onnx.dll", + "migraphx_tf": "migraphx_tf", "onnxruntime_perf_test": "onnxruntime_perf_test.exe", "onnx_test_runner": "onnx_test_runner.exe", } @@ -402,6 +431,7 @@ def generate_files(line_list, args): "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", "rocm_ep_shared_lib": "libonnxruntime_providers_rocm.so", + "migraphx_ep_shared_lib": "libonnxruntime_providers_migraphx.so", "onnxruntime_perf_test": "onnxruntime_perf_test", "onnx_test_runner": "onnx_test_runner", } @@ -421,7 +451,7 @@ def generate_files(line_list, args): include_dir = f"{build_dir}\\native\\include" # Sub.Gpu packages do not include the onnxruntime headers - if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu": + if args.package_name != "Microsoft.ML.OnnxRuntime.Gpu" and args.package_name != "Microsoft.ML.OnnxRuntime.MIGraphX": files_list.append( "' ) + if args.execution_provider == "migraphx": + files_list.append( + "' + ) + files_list.append( + "' + ) + + if is_windows_build: + native_build_path = Path(args.native_build_path) + + def _files_list_append(key: str): + path = native_build_path / nuget_dependencies[key] + if path.exists(): + files_list.append( + "' + ) + + _files_list_append("amd_comgr0602") + _files_list_append("amd_comgr0604") + _files_list_append("amd_comgr0700") + _files_list_append("hiprtc0602") + _files_list_append("hiprtc0604") + _files_list_append("hiprtc0700") + _files_list_append("hiprtc-builtins0602") + _files_list_append("hiprtc-builtins0604") + _files_list_append("hiprtc-builtins0700") + _files_list_append("migraphx-hiprtc-driver") + _files_list_append("migraphx") + _files_list_append("migraphx_c") + _files_list_append("migraphx_cpu") + _files_list_append("migraphx_device") + _files_list_append("migraphx_gpu") + _files_list_append("migraphx_onnx") + _files_list_append("migraphx_tf") + if is_dml_package: files_list.append( " int | None: """Triggers a pipeline and returns the new build ID.""" print(f"\n--- Triggering Pipeline ID: {pipeline_id} on branch '{branch}' in project '{project}' ---") @@ -313,13 +335,26 @@ def trigger_pipeline( run_url = f"https://dev.azure.com/{ADO_ORGANIZATION}/{project}/_apis/pipelines/{pipeline_id}/runs?api-version=7.1-preview.1" payload: dict[str, any] = {"resources": {"repositories": {"self": {"refName": branch}}}} + template_params: dict[str, any] = {} if nightly_override is not None and packaging_type == "nightly": print(f"Overriding NIGHTLY_BUILD variable to '{nightly_override}'.") payload["variables"] = {"NIGHTLY_BUILD": {"value": nightly_override}} elif release_override is not None and packaging_type == "release": print(f"Overriding IsReleaseBuild parameter to '{release_override}'.") - payload["templateParameters"] = {"IsReleaseBuild": release_override} + template_params["IsReleaseBuild"] = release_override + + # Add pre-release parameters if the pipeline supports them and user provided them + if has_pre_release_params: + if pre_release_suffix_string is not None: + print(f"Setting PreReleaseVersionSuffixString parameter to '{pre_release_suffix_string}'.") + template_params["PreReleaseVersionSuffixString"] = pre_release_suffix_string + if pre_release_suffix_number is not None: + print(f"Setting PreReleaseVersionSuffixNumber parameter to {pre_release_suffix_number}.") + template_params["PreReleaseVersionSuffixNumber"] = pre_release_suffix_number + + if template_params: + payload["templateParameters"] = template_params try: response = requests.post(run_url, headers=headers, data=json.dumps(payload)) @@ -327,12 +362,12 @@ def trigger_pipeline( build_info = response.json() build_id = build_info.get("id") print(f"Successfully triggered build. Build ID: {build_id}") - print(f" Web URL: {build_info.get('_links', {}).get('web', {}).get('href')}") + print(f" Web URL: {build_info.get('_links', {}).get('web', {}).get('href')}") return build_id except requests.exceptions.RequestException as e: print(f"ERROR triggering pipeline: {e}") if e.response is not None: - print(f" Response: {e.response.text}") + print(f" Response: {e.response.text}") return None @@ -351,6 +386,24 @@ def main(): choices=["nightly", "release"], help="Specify the build mode for packaging pipelines (nightly or release). This sets NIGHTLY_BUILD and IsReleaseBuild accordingly.", ) + parser.add_argument( + "--no-cancel", + action="store_true", + dest="no_cancel_builds", + help="Do not cancel existing running builds for the pipeline before triggering a new one.", + ) + # New arguments for pre-release versioning + parser.add_argument( + "--pre-release-suffix-string", + choices=["alpha", "beta", "rc", "none"], + help="Suffix for pre-release versions (e.g., 'rc'). Requires the pipeline to have the parameter.", + ) + parser.add_argument( + "--pre-release-suffix-number", + type=int, + help="Number for pre-release versions (e.g., '1'). Requires the pipeline to have the parameter.", + ) + args = parser.parse_args() project = "Lotus" @@ -358,6 +411,11 @@ def main(): branch_for_yaml_fetch = args.branch is_pr_build = False + if (args.pre_release_suffix_string and not args.pre_release_suffix_number) or ( + not args.pre_release_suffix_string and args.pre_release_suffix_number + ): + parser.error("--pre-release-suffix-string and --pre-release-suffix-number must be used together.") + if args.pr: print(f"--- Pull Request Mode Activated for PR #{args.pr} ---") project = "PublicPackages" @@ -388,21 +446,58 @@ def main(): print(f" - {result['pipeline']['name']} (ID: {result['pipeline']['id']})") else: print(f"\n--- Triggering {len(pipelines_to_trigger)} Pipelines on branch '{branch_for_trigger}' ---") - nightly_override = None - release_override = None - if args.build_mode == "nightly": - nightly_override = "1" - release_override = "false" - elif args.build_mode == "release": - nightly_override = "0" - release_override = "true" + + # If pre-release flags are used, it implies a release build. + if args.pre_release_suffix_string: + print("Pre-release suffix provided. Forcing 'release' build mode.") + if args.build_mode and args.build_mode != "release": + print(f"Warning: --build-mode={args.build_mode} is overridden by pre-release flags.") + + # If pre-release flags are used, it implies a release build. + if args.pre_release_suffix_string: + print("Pre-release suffix provided. Forcing 'release' build mode.") + if args.build_mode and args.build_mode != "release": + print(f"Warning: --build-mode={args.build_mode} is overridden by pre-release flags.") for result in pipelines_to_trigger: pipeline = result["pipeline"] packaging_type = result["packaging_type"] - cancel_running_builds(pipeline["id"], branch_for_trigger, token, project) + has_pre_release_params = result["has_pre_release_params"] + + # Determine build mode based on flags + nightly_override = None + release_override = None + if args.build_mode == "nightly": + nightly_override = "1" + release_override = "false" + elif args.build_mode == "release": + nightly_override = "0" + release_override = "true" + + # If pre-release flags are used AND the pipeline supports them, it implies a release build. + if args.pre_release_suffix_string and has_pre_release_params: + print(f"Pre-release flags used and supported by '{pipeline['name']}'. Forcing 'release' mode.") + if args.build_mode and args.build_mode != "release": + print(f" - Warning: --build-mode={args.build_mode} is overridden for this pipeline.") + nightly_override = "0" + release_override = "true" + + if not args.no_cancel_builds: + cancel_running_builds(pipeline["id"], branch_for_trigger, token, project) + else: + print(f"\nSkipping cancellation for Pipeline ID: {pipeline['id']} as per --no-cancel flag.") + trigger_pipeline( - pipeline["id"], token, branch_for_trigger, project, nightly_override, release_override, packaging_type + pipeline_id=pipeline["id"], + token=token, + branch=branch_for_trigger, + project=project, + nightly_override=nightly_override, + release_override=release_override, + packaging_type=packaging_type, + has_pre_release_params=has_pre_release_params, + pre_release_suffix_string=args.pre_release_suffix_string, + pre_release_suffix_number=args.pre_release_suffix_number, )