From e3bb40a51e222c4c880ff85f56cd55cb43d385c9 Mon Sep 17 00:00:00 2001 From: firedoil Date: Sat, 20 Jun 2026 02:35:01 +0800 Subject: [PATCH] Add ROCm PQC TrustFlow KEM and signature example --- Applications/pqc_trustflow_rocm/.gitignore | 17 + .../README.md | 47 + .../docs/api_notes.md | 167 + .../docs/quick_start.md | 163 + .../kem_api/batch_kem.cuh | 797 +++++ .../kem_api/batch_ntt.cuh | 481 +++ .../kem_api/batch_ops.cuh | 243 ++ .../kem_api/build_all.sh | 44 + .../kem_api/build_hip.sh | 102 + .../kem_api/cbd.cuh | 352 ++ .../kem_api/config.h | 41 + .../kem_api/fips202.cuh | 287 ++ .../kem_api/kem.cuh | 526 +++ .../kem_api/main.cu | 815 +++++ .../kem_api/ntt.cuh | 364 ++ .../kem_api/params.h | 183 + .../kem_api/poly.cuh | 252 ++ .../kem_api/polyvec.cuh | 259 ++ .../kem_api/reduce.cuh | 80 + .../kem_api/rocm_compat.h | 54 + .../kem_api/run_kem_smoke_amd.sh | 42 + .../sig_api/batch_keygen.cuh | 3095 +++++++++++++++++ .../sig_api/batch_ntt.cuh | 285 ++ .../sig_api/batch_ops.cuh | 207 ++ .../sig_api/batch_sign.cuh | 900 +++++ .../sig_api/batch_sign_warp.cuh | 835 +++++ .../sig_api/batch_verify.cuh | 793 +++++ .../sig_api/build_sig_amd.sh | 43 + .../sig_api/config.h | 31 + .../sig_api/fips202.cuh | 354 ++ .../sig_api/main.cu | 2579 ++++++++++++++ .../sig_api/ntt.cuh | 299 ++ .../sig_api/packing.cuh | 224 ++ .../sig_api/params.h | 307 ++ .../sig_api/poly.cuh | 817 +++++ .../sig_api/polyvec.cuh | 155 + .../sig_api/reduce.cuh | 136 + .../sig_api/rounding.cuh | 134 + .../sig_api/run_sig_policy_smoke.sh | 44 + .../sig_api/sign.cuh | 790 +++++ .../sig_api/symmetric.cuh | 105 + .../trustflow_frontend/FRONTEND_MANIFEST.md | 19 + .../trustflow_frontend/README.md | 48 + .../trustflow_frontend/__init__.py | 13 + .../trustflow_frontend/app.py | 380 ++ .../trustflow_frontend/backends.py | 814 +++++ .../sample_docs/Untitled.ipynb | 33 + .../sample_docs/lab_panel.csv | 4 + .../sample_docs/medical_report.txt | 3 + .../sample_docs/risk_features.json | 8 + 
.../trustflow_frontend/state.py | 49 + .../README.md | 50 + .../evidence/kem_final_extract.txt | 28 + ...kyber_amd_first_run_bottleneck_analysis.md | 1460 ++++++++ .../evidence/sig_large_best.csv | 37 + .../evidence/sig_optimization_claims.md | 107 + ...six_parameter_final_decision_2026-06-16.md | 81 + .../table_sig_final_large_sweep_2026-06-16.md | 26 + ...local_winners_feature_matrix_2026-06-16.md | 51 + .../kem_optimization/AMD_RUNBOOK.md | 292 ++ .../KEM_AMD_OPTIMIZATION_LOG.md | 523 +++ .../kem_optimization/batch_kem.cuh | 797 +++++ .../kem_optimization/build_hip.sh | 102 + .../kem_optimization/main.cu | 815 +++++ .../kem_optimization/parse_kem_results.py | 149 + .../kem_optimization/profile_kem_one_amd.sh | 36 + .../run_kem_all_bounds_probe_amd.sh | 207 ++ .../run_kem_all_profile_compare_amd.sh | 144 + .../kem_optimization/run_kem_confirm_amd.sh | 94 + .../run_kem_final_report_amd.sh | 72 + .../run_kem_resource_profile_amd.sh | 62 + .../kem_optimization/run_kem_tune_amd.sh | 123 + .../run_rocm_toolbox_kem_amd.sh | 245 ++ .../kem_optimization/summarize_kem_best.py | 68 + .../summarize_profile_compare.py | 253 ++ .../kem_optimization/summarize_rocm_pmc.py | 86 + .../summarize_rocprofv3_trace.py | 123 + .../sig_optimization/COMPETITION_RUNBOOK.md | 206 ++ .../sig_optimization/amd_tools/README.md | 175 + .../amd_tools/SIG_DEBUG_PLAN.md | 112 + .../amd_tools/build_sig_amd.sh | 43 + .../amd_tools/build_sig_amd_selected.sh | 108 + .../amd_tools/check_competition_evidence.py | 147 + .../amd_tools/compare_amd_4090.py | 132 + .../amd_tools/parse_sig_results.py | 82 + .../amd_tools/profile_sig_one.sh | 39 + .../amd_tools/run_sig_amd_feature_matrix.sh | 153 + .../amd_tools/run_sig_debug_matrix.sh | 40 + .../amd_tools/run_sig_large_sweep.sh | 54 + .../amd_tools/run_sig_policy_smoke.sh | 44 + .../amd_tools/run_sig_sweep.sh | 34 + .../amd_tools/select_sig_amd_variants.py | 257 ++ .../amd_tools/summarize_amd_feature_matrix.py | 189 + .../summarize_rocm_kernel_profile.py | 
170 + .../amd_tools/summarize_sig_best.py | 75 + .../amd_tools/write_optimization_claims.py | 123 + Applications/pqc_trustflow_rocm/README.md | 41 + 97 files changed, 27075 insertions(+) create mode 100644 Applications/pqc_trustflow_rocm/.gitignore create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/README.md create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/api_notes.md create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/quick_start.md create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_kem.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ntt.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ops.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_all.sh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_hip.sh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/cbd.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/config.h create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/fips202.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/kem.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/main.cu create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/ntt.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/params.h create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/poly.cuh create mode 100644 
Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/polyvec.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/reduce.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/rocm_compat.h create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/run_kem_smoke_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_keygen.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ntt.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ops.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign_warp.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_verify.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/build_sig_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/config.h create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/fips202.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/main.cu create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/ntt.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/packing.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/params.h create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/poly.cuh create mode 100644 
Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/polyvec.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/reduce.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/rounding.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/run_sig_policy_smoke.sh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/sign.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/symmetric.cuh create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/FRONTEND_MANIFEST.md create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/README.md create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/__init__.py create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/app.py create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/backends.py create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/Untitled.ipynb create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/lab_panel.csv create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/medical_report.txt create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/risk_features.json create mode 100644 Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/state.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/README.md create mode 
100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kem_final_extract.txt create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kyber_amd_first_run_bottleneck_analysis.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_large_best.csv create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_optimization_claims.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_six_parameter_final_decision_2026-06-16.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_final_large_sweep_2026-06-16.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_local_winners_feature_matrix_2026-06-16.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/AMD_RUNBOOK.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/KEM_AMD_OPTIMIZATION_LOG.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/batch_kem.cuh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/build_hip.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/main.cu create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/parse_kem_results.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/profile_kem_one_amd.sh create mode 100644 
Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_all_bounds_probe_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_all_profile_compare_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_confirm_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_final_report_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_resource_profile_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_tune_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_rocm_toolbox_kem_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_kem_best.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_profile_compare.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_rocm_pmc.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_rocprofv3_trace.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/COMPETITION_RUNBOOK.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/README.md create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/SIG_DEBUG_PLAN.md create mode 100644 
Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd_selected.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/check_competition_evidence.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/compare_amd_4090.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/parse_sig_results.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/profile_sig_one.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_amd_feature_matrix.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_debug_matrix.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_large_sweep.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_policy_smoke.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_sweep.sh create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/select_sig_amd_variants.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_amd_feature_matrix.py create mode 100644 
Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_rocm_kernel_profile.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_sig_best.py create mode 100644 Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/write_optimization_claims.py create mode 100644 Applications/pqc_trustflow_rocm/README.md diff --git a/Applications/pqc_trustflow_rocm/.gitignore b/Applications/pqc_trustflow_rocm/.gitignore new file mode 100644 index 000000000..7caee082b --- /dev/null +++ b/Applications/pqc_trustflow_rocm/.gitignore @@ -0,0 +1,17 @@ +*.exe +*.dll +*.o +*.obj +*.pdb +*.pyc +__pycache__/ +*.demo_secret +*receiver_sk* +*ss_sender* +*ss_receiver* +outputs/ +logs/ +*.zip +*.tar +*.tar.gz +*.bak* diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/README.md b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/README.md new file mode 100644 index 000000000..d1ab9dbf4 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/README.md @@ -0,0 +1,47 @@ +# Unsupported Function Development: ROCm PQC API + +This folder is prepared for the competition item: + +```text +(1) Development of currently unsupported functions +``` + +## What This Adds + +The contribution adds a ROCm/HIP post-quantum cryptography backend and an upper-layer file workflow: + +- `kem_api/`: Kyber/Aigis-enc batch KEM backend with file-level keygen, encaps, and decaps API paths. +- `sig_api/`: ML-DSA/Aigis-sig batch signature backend with file-level sign and verify API paths. +- `trustflow_frontend/`: a multi-file secure packaging frontend that calls the ROCm KEM/SIG backends. +- `docs/`: quick-start and API notes for reproducing the workflow. 
+ +## Key API Examples + +```bash +./kyber768_amd --api-kem-keygen --batch 128 --pk-out kem_pk.bin --sk-out receiver_sk.demo_secret +./kyber768_amd --api-kem-encaps --batch 128 --pk-in kem_pk.bin --ct-out kem_ct.bin --ss-out ss_sender.demo_secret +./kyber768_amd --api-kem-decaps --batch 128 --sk-in receiver_sk.demo_secret --ct-in kem_ct.bin --ss-out ss_receiver.demo_secret +``` + +```bash +./mldsa65_amd --api-sig-sign --batch 128 --msg-in manifest.payload.json --pk-out sig_pk.bin --sk-out sig_sk.demo_secret --sig-out manifest.sig +./mldsa65_amd --api-sig-verify --batch 128 --msg-in manifest.payload.json --pk-in sig_pk.bin --sig-in manifest.sig +``` + +## Build And Smoke Tests + +```bash +cd kem_api +bash build_hip.sh kyber768 +bash run_kem_smoke_amd.sh +``` + +```bash +cd sig_api +bash build_sig_amd.sh +bash run_sig_policy_smoke.sh 128 +``` + +## Why It Fits The Scoring Item + +This folder shows a previously unsupported ROCm application path: post-quantum KEM and signature workloads are not only ported to HIP, but also exposed as reusable file-level APIs and connected to a complete TrustFlow packaging workflow. diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/api_notes.md b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/api_notes.md new file mode 100644 index 000000000..c7ff2e851 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/api_notes.md @@ -0,0 +1,167 @@ +# 文件级 KEM/SIG 接口说明 + +本文档说明 PQC TrustFlow ROCm 前端实际调用的后端文件级接口。项目目标不是只展示单项 benchmark,而是提供可被前端和应用流程调用的后量子密码组件。 + +## 1. 
总体数据流 + +发送端: + +```text +输入文件 + -> 计算 SHA-256 摘要 + -> Kyber/Aigis-enc KEM encaps 得到 shared secret 和 KEM ciphertext + -> SHA-256(shared secret) 派生 AES-256-GCM key + -> AES-256-GCM 加密每个文件 + -> 生成 manifest + -> ML-DSA/Aigis-sig 对 manifest payload 签名 + -> 输出 pqcpack 安全包 +``` + +接收端: + +```text +pqcpack 安全包 + -> Kyber/Aigis-enc KEM decaps 恢复 shared secret + -> SHA-256(shared secret) 派生 AES-256-GCM key + -> 验证 manifest 签名 + -> AES-256-GCM 解密文件 + -> 校验 SHA-256 摘要 + -> 输出恢复目录 +``` + +## 2. KEM 文件级接口 + +可执行文件: + +```text +kyberandaigis-enc/kyber768_amd +``` + +密钥生成: + +```bash +./kyber768_amd \ + --api-kem-keygen \ + --batch 128 \ + --pk-out kem_pk.bin \ + --sk-out receiver_sk.demo_secret +``` + +封装: + +```bash +./kyber768_amd \ + --api-kem-encaps \ + --batch 128 \ + --pk-in kem_pk.bin \ + --ct-out kem_ct.bin \ + --ss-out ss_sender.demo_secret +``` + +解封装: + +```bash +./kyber768_amd \ + --api-kem-decaps \ + --batch 128 \ + --sk-in receiver_sk.demo_secret \ + --ct-in kem_ct.bin \ + --ss-out ss_receiver.demo_secret +``` + +正确性判断: + +```bash +cmp ss_sender.demo_secret ss_receiver.demo_secret +``` + +如果两端 shared secret 一致,则 KEM 文件级接口正确。前端不会直接把 shared secret 当明文密钥使用,而是执行: + +```text +AES-256-GCM key = SHA-256(shared secret) +``` + +随后用该 AES key 对文件内容进行加密和解密。 + +## 3. SIG 文件级接口 + +可执行文件: + +```text +mldsaandaigis-sig/mldsa65_amd +``` + +当前开发包中如果目录仍为 `amd_sig_anchor_results_20260605_031411`,最终改名为 `mldsaandaigis-sig` 后,需要同步更新前端后端路径。 + +签名: + +```bash +./mldsa65_amd \ + --api-sig-sign \ + --batch 128 \ + --msg-in manifest.payload.json \ + --pk-out sig_pk.bin \ + --sk-out sig_sk.demo_secret \ + --sig-out manifest.sig +``` + +验签: + +```bash +./mldsa65_amd \ + --api-sig-verify \ + --batch 128 \ + --msg-in manifest.payload.json \ + --pk-in sig_pk.bin \ + --sig-in manifest.sig +``` + +前端中被签名的对象不是单个文件本身,而是 `manifest.payload.json`。该 payload 包含文件名、密文路径、nonce、tag、SHA-256 摘要、KEM ciphertext 路径和算法配置等信息。这样可以一次性保护整个传输包的结构和文件完整性。 + +## 4. 
安全包关键文件 + +一次成功运行会生成类似结构: + +```text +pack_xxx/ + manifest.json + kem/ + kem_pk.bin + kem_ct.bin + sig/ + manifest.payload.json + manifest.sig + sig_pk.bin + encrypted/ + *.enc + recovered/ + ... +``` + +关键含义: + +`manifest.json`:安全包主清单,记录算法配置、文件摘要、密文位置、KEM/SIG 后端信息和验证所需元数据。 + +`kem/kem_ct.bin`:KEM ciphertext,接收端使用私钥 decaps 后恢复 shared secret。 + +`sig/manifest.payload.json`:被 ML-DSA/Aigis-sig 签名的清单载荷。 + +`sig/manifest.sig`:manifest payload 的签名。 + +`encrypted/*.enc`:AES-256-GCM 加密后的文件密文。 + +`recovered/`:验证通过后恢复出的明文文件目录。 + +## 5. Batch/decomp 设计说明 + +本项目不强行使用单实例签名 CLI 作为主路径。原因是 ML-DSA/Aigis-sig 在 AMD ROCm 平台上存在更明显的资源压力,单实例或过重 kernel 容易受到 private segment、scratch、occupancy 等因素影响。 + +因此前端采用 batch/decomp 文件级接口作为实际应用路径: + +```text +batch 提供 GPU 并行吞吐 +decomp pipeline 降低单个签名路径的资源压力 +文件级 API 让前端能够真实调用 KEM/SIG 能力 +``` + +这也是项目的主要工程贡献之一:不是只给出 isolated benchmark,而是把 Kyber/Aigis-enc 和 ML-DSA/Aigis-sig 接成可演示、可验证、可扩展的 ROCm 后量子安全传输流程。 diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/quick_start.md b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/quick_start.md new file mode 100644 index 000000000..0cfbab17d --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/docs/quick_start.md @@ -0,0 +1,163 @@ +# PQC TrustFlow ROCm 快速运行说明 + +本文档用于评审或复现实验时快速启动前端、执行完整流程,并确认输出结果是否正确。 + +## 1. 进入项目目录 + +在 AMD JupyterLab 服务器终端中执行: + +```bash +cd /app/PQC_TrustFlow_ROCm +export LD_LIBRARY_PATH=/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:$LD_LIBRARY_PATH +``` + +如果项目目录尚未改名,也可以先在当前解压目录中运行;最终提交版本建议统一使用 `/app/PQC_TrustFlow_ROCm`。 + +## 2. 
启动 Notebook 前端 + +打开 `pqc_trustflow_widgets_demo.ipynb`,执行: + +```python +%cd /app/PQC_TrustFlow_ROCm + +import os +os.environ["LD_LIBRARY_PATH"] = "/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:" + os.environ.get("LD_LIBRARY_PATH", "") + +from pqc_trustflow_frontend import launch_app +launch_app() +``` + +启动后会显示 PQC TrustFlow 前端界面。推荐配置: + +```text +KEM: Kyber-768 +SIG: ML-DSA-65 +batch: 128 +mode: paper +``` + +## 3. 前端按钮含义 + +`准备`:生成或检查演示输入文件,并初始化流程状态。 + +`生成安全包`:调用 ROCm KEM 文件级接口生成共享密钥材料,使用 AES-256-GCM 加密文件,再调用 ROCm ML-DSA/Aigis-sig 文件级接口签名 manifest。 + +`查看安全包`:显示本次生成的安全包目录、zip 包、密文文件、KEM ciphertext 和 manifest 信息。 + +`查看证明`:显示签名载荷、签名文件、ROCm 后端日志、KEM/SIG API 调用结果。 + +`解包并验证`:执行 KEM decaps,恢复 AES 密钥,解密文件,校验 SHA-256 摘要,并验证 manifest 签名。 + +`篡改测试`:复制安全包,自动篡改一个密文或摘要相关文件,再重新验证,确认系统能检测异常。 + +`查看恢复目录`:查看解密后恢复出的文件。 + +`一键运行`:自动执行准备、生成安全包、解包并验证,适合快速演示。 + +`重置`:清空当前前端状态,重新开始一次流程。 + +## 4. 期望前端结果 + +正常流程中,`流程` 标签页应显示: + +```text +准备: PASS +生成安全包: PASS +解包验证: PASS +``` + +`结果与证据` 标签页应包含: + +```text +正常包验证: PASS +KEM 后端: ROCm KEM batch file API +签名后端: ROCm ML-DSA/Aigis-sig batch file API +KEM ciphertext: kem/kem_ct.bin +签名载荷: sig/manifest.payload.json +签名文件: sig/manifest.sig +``` + +执行 `篡改测试` 后,期望结果为: + +```text +篡改检测: PASS +篡改包验证结果: FAIL +``` + +这表示正常包可以通过解密、验签和摘要校验;被篡改后的包无法通过验证。 + +## 5. 
终端一键验证 + +如果需要在终端生成一份可归档的 smoke test 输出,可执行: + +```bash +cd /app/PQC_TrustFlow_ROCm +mkdir -p results/smoke_tests results/logs results/screenshots +export LD_LIBRARY_PATH=/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:$LD_LIBRARY_PATH + +python3 - <<'PY' | tee results/smoke_tests/trustflow_smoke_$(date +%Y%m%d_%H%M%S).txt +from pqc_trustflow_frontend.backends import ensure_sample_docs, create_secure_pack, create_tampered_copy_and_verify +import json +from pathlib import Path + +src = ensure_sample_docs() +r = create_secure_pack(src, "Kyber-768", "ML-DSA-65", 128, "paper", run_rocm=True) + +print("pack:", r.pack_dir) +print("zip:", r.pack_zip) +print("verified:", r.verified) +print("logs:", json.dumps(r.rocm_logs, ensure_ascii=False, indent=2)) +print("notes:", r.notes) + +m = json.loads(Path(r.manifest_path).read_text()) +print("kem_backend:", m.get("kem_backend")) +print("signature_backend:", m.get("signature_backend")) +print("kem_ciphertext_file:", m.get("kem_ciphertext_file")) +print("sig_payload:", m.get("sig_payload")) +print("manifest_signature:", m.get("manifest_signature")) + +t = create_tampered_copy_and_verify(r.pack_dir) +print("tamper_detected:", t["tamper_detected"]) +print("tamper_verified:", t["verified"]) +print("file_errors:", t["file_errors"]) +print("kem_ok:", t.get("kem_ok")) +print("sig_api_ok:", t.get("sig_api_ok")) +PY +``` + +期望关键输出: + +```text +verified: True +notes: [] +kem_backend: ROCm KEM batch file API +signature_backend: ROCm ML-DSA/Aigis-sig batch file API +tamper_detected: True +tamper_verified: False +``` + +## 6. 
结果文件归档建议 + +运行完成后,建议保留以下证据: + +```text +results/ + screenshots/ + 01_frontend_full_ui.png + 02_pack_encrypt_sign.png + 03_decrypt_verify_digest.png + 04_tamper_detection.png + 05_repository_layout.png + 06_one_click_test.png + 07_generated_artifacts.png + smoke_tests/ + trustflow_smoke_*.txt + logs/ + kemapi_keygen_sample.log + kemapi_encaps_sample.log + kemapi_decaps_sample.log + sigapi_sign_sample.log + sigapi_verify_sample.log +``` + +其中 `screenshots/` 放人工截图,`smoke_tests/` 放终端一键测试输出,`logs/` 放关键 ROCm KEM/SIG 文件级 API 日志样例。 diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_kem.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_kem.cuh new file mode 100644 index 000000000..6d72d73f9 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_kem.cuh @@ -0,0 +1,797 @@ +/* + * batch_kem.cuh — GPU 批量 KEM 流水线 + * + * 参考 mldsa和aigis-sig/batch_keygen.cuh 的优化架构: + * - Warp 协同采样: 1 warp = 1 实例 (并行矩阵展开 + 噪声采样) + * - 共享内存批量 NTT (batch_ntt_kernel, 1 block/poly) + * - 2D grid 矩阵向量乘 (batch_polyvec_matvec_kernel) + * - SoA 内存布局: data[poly_idx * batch_count * N + inst * N + coeff] + * + * 性能要点 (RTX 3050 Ti): + * - 最优 batch size: Keygen/Encaps=16K, Decaps=8K-16K + * - VRAM 限制: K^2 * B * N * sizeof(int16_t) ≤ 可用显存 + */ + +#ifndef BATCH_KEM_CUH +#define BATCH_KEM_CUH + +#include "rocm_compat.h" +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "fips202.cuh" +#include "ntt.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "cbd.cuh" +#include "kem.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" + +/* ================================================================ + * 缓冲区结构体 + * ================================================================ */ +struct BatchKemBuffers { + /* 批量 keygen/encaps 工作缓冲区 — SoA 布局 [poly_idx][inst][coeff] */ + int16_t *d_mat; /* K*K * B * N — 矩阵 A (NTT 域) */ + int16_t *d_skpv; /* K * B * N — 私钥 s 
(NTT 域) */ + int16_t *d_pkpv; /* K * B * N — public-key polynomials b */ + int16_t *d_e; /* K * B * N — keygen error e */ + + /* KEM byte buffers */ + uint8_t *d_pk_bytes; /* B * PARAM_PUBLICKEYBYTES */ + uint8_t *d_sk_bytes; /* B * PARAM_SECRETKEYBYTES */ + uint8_t *d_ct_bytes; /* B * PARAM_CIPHERTEXTBYTES */ + uint8_t *d_ss_bytes; /* B * PARAM_SSBYTES */ + + /* Random seeds */ + uint8_t *d_coins_kg; /* B * 2*SYMBYTES — keygen seeds */ + uint8_t *d_coins_enc;/* B * SYMBYTES — encaps seeds */ + + uint8_t *d_publicseed_kg; /* B * PARAM_SYMBYTES (size set in batch_kem_alloc) */ + uint8_t *d_noiseseed_kg; /* B * PARAM_SYMBYTES (size set in batch_kem_alloc) */ + + int max_batch; /* capacity B passed to batch_kem_alloc */ +}; + +static inline void batch_kem_alloc(BatchKemBuffers *buf, int max_batch) /* Allocates every device buffer for max_batch instances. NOTE(review): the cudaMalloc return codes are discarded, so a failed allocation is not reported here. */ +{ + buf->max_batch = max_batch; + cudaMalloc(&buf->d_mat, (size_t)PARAM_K * PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_skpv, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_pkpv, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_e, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_pk_bytes, (size_t)max_batch * PARAM_PUBLICKEYBYTES); + cudaMalloc(&buf->d_sk_bytes, (size_t)max_batch * PARAM_SECRETKEYBYTES); + cudaMalloc(&buf->d_ct_bytes, (size_t)max_batch * PARAM_CIPHERTEXTBYTES); + cudaMalloc(&buf->d_ss_bytes, (size_t)max_batch * PARAM_SSBYTES); + cudaMalloc(&buf->d_coins_kg, (size_t)max_batch * 2 * PARAM_SYMBYTES); + cudaMalloc(&buf->d_coins_enc, (size_t)max_batch * PARAM_SYMBYTES); + cudaMalloc(&buf->d_publicseed_kg, (size_t)max_batch * PARAM_SYMBYTES); + cudaMalloc(&buf->d_noiseseed_kg, (size_t)max_batch * PARAM_SYMBYTES); +} + +static inline void batch_kem_free(BatchKemBuffers *buf) /* Frees every buffer allocated by batch_kem_alloc; the pointers in *buf are not reset afterwards. */ +{ + cudaFree(buf->d_mat); + cudaFree(buf->d_skpv); + cudaFree(buf->d_pkpv); + cudaFree(buf->d_e); + cudaFree(buf->d_pk_bytes); + cudaFree(buf->d_sk_bytes); + cudaFree(buf->d_ct_bytes); + cudaFree(buf->d_ss_bytes); + cudaFree(buf->d_coins_kg); + cudaFree(buf->d_coins_enc); + cudaFree(buf->d_publicseed_kg); + cudaFree(buf->d_noiseseed_kg); +} + +/* 
================================================================ + * Warp-cooperative sampling kernel (KEM key generation) + * 1 warp (WP_KG_WARP_SIZE threads, 32 by default) = 1 instance + * Lane 0: SHA3-512 seed expansion → (publicseed, noiseseed) + * All lanes: expand matrix A and noise polynomials s, e; each lane samples whole polynomials in a lane-strided loop + * + * SoA output: + * d_mat[row*K*B*N + col*B*N + inst*N + c] = A[inst][row][col][c] + * d_skpv[i*B*N + inst*N + c] = s[inst][i][c] (not yet NTT) + * d_e[i*B*N + inst*N + c] = e[inst][i][c] (not yet NTT) + * ================================================================ */ + +#ifndef WP_KG_WARP_SIZE +#define WP_KG_WARP_SIZE 32 +#endif + +#ifndef WP_KG_WARPS_BLOCK +#define WP_KG_WARPS_BLOCK 4 +#endif + +#define WP_KG_TPB (WP_KG_WARP_SIZE * WP_KG_WARPS_BLOCK) + +#ifndef KEM_SPLIT_KEYGEN_SAMPLE +#define KEM_SPLIT_KEYGEN_SAMPLE 0 +#endif + +#ifndef KEM_SERIAL_TPB +#ifdef USE_HIP +#define KEM_SERIAL_TPB 64 +#else +#define KEM_SERIAL_TPB 64 /* NOTE(review): same value as the USE_HIP branch above */ +#endif +#endif + +#ifndef KEM_KEYGEN_TPB +#define KEM_KEYGEN_TPB KEM_SERIAL_TPB +#endif + +#ifndef KEM_ENCAPS_TPB +#define KEM_ENCAPS_TPB KEM_SERIAL_TPB +#endif + +#ifndef KEM_DECAPS_TPB +#define KEM_DECAPS_TPB KEM_SERIAL_TPB +#endif + +__global__ void batch_keygen_warp_sample_kernel( + int16_t * __restrict__ d_mat, /* K*K * B * N */ + int16_t * __restrict__ d_skpv, /* K * B * N */ + int16_t * __restrict__ d_e, /* K * B * N */ + uint8_t * __restrict__ d_publicseed, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * WP_KG_WARPS_BLOCK + (threadIdx.x / WP_KG_WARP_SIZE); + int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (inst >= batch_count) return; + + /* Per-warp shared-memory slots for publicseed and noiseseed */ + __shared__ uint8_t ws_pub[WP_KG_WARPS_BLOCK][PARAM_SYMBYTES]; + __shared__ uint8_t ws_noise[WP_KG_WARPS_BLOCK][PARAM_SYMBYTES]; + + int warp_id = threadIdx.x / WP_KG_WARP_SIZE; + uint8_t *publicseed = ws_pub[warp_id]; + uint8_t *noiseseed = ws_noise[warp_id]; + + if (lane == 0) { + /* Expand seed: SHA3-512 over the first PARAM_SYMBYTES bytes of this instance's coins → publicseed followed by noiseseed */ + uint8_t buf[2 * 
PARAM_SYMBYTES]; + sha3_512(buf, d_coins + inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); /* NOTE(review): this offset is computed in int; the other offsets in this kernel cast to size_t */ + for (int i = 0; i < PARAM_SYMBYTES; i++) { + publicseed[i] = buf[i]; + d_publicseed[(size_t)inst * PARAM_SYMBYTES + i] = buf[i]; + } + for (int i = 0; i < PARAM_SYMBYTES; i++) noiseseed[i] = buf[PARAM_SYMBYTES + i]; + } + __syncwarp(); /* lanes below read the seeds lane 0 just wrote — NOTE(review): confirm __syncwarp semantics under the ROCm compat header */ + + /* Matrix expansion: each lane handles a lane-strided subset of the K*K polynomials A[row][col] */ + int total_mat_polys = PARAM_K * PARAM_K; + for (int p = lane; p < total_mat_polys; p += WP_KG_WARP_SIZE) { + int row = p / PARAM_K; + int col = p % PARAM_K; + + /* Destination address in the SoA layout */ + int16_t *dst = d_mat + ((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N; + + uint8_t extseed[PARAM_SYMBYTES + 2]; + for (int i = 0; i < PARAM_SYMBYTES; i++) extseed[i] = publicseed[i]; + +#if ALGORITHM == ALGO_KYBER + extseed[PARAM_SYMBYTES] = (uint8_t)col; /* j */ + extseed[PARAM_SYMBYTES+1] = (uint8_t)row; /* i */ +#elif ALGORITHM == ALGO_AIGIS_ENC + extseed[PARAM_SYMBYTES] = (uint8_t)row; /* i */ + extseed[PARAM_SYMBYTES+1] = (uint8_t)col; /* j */ +#endif + +#if KEM_DIRECT_REJ_UNIFORM + rej_uniform_xof(dst, publicseed, extseed[PARAM_SYMBYTES], extseed[PARAM_SYMBYTES + 1]); +#else + keccak_state state; + shake128_absorb_once(&state, extseed, PARAM_SYMBYTES + 2); + + unsigned int ctr = 0; + uint8_t buf[PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES]; + while (ctr < PARAM_N) { + shake128_squeezeblocks(buf, PARAM_GEN_MATRIX_NBLOCKS, &state); + ctr += rej_uniform(dst + ctr, PARAM_N - ctr, + buf, PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES); + } +#endif + } + + /* Noise sampling: s[0..K-1] (third argument i), e[0..K-1] (third argument K+i) */ + for (int i = lane; i < PARAM_K; i += WP_KG_WARP_SIZE) { + int16_t *dst_s = d_skpv + ((size_t)i * batch_count + inst) * PARAM_N; + poly_getnoise_s(dst_s, noiseseed, (uint8_t)i); + } + for (int i = lane; i < PARAM_K; i += WP_KG_WARP_SIZE) { + int16_t *dst_e = d_e + ((size_t)i * batch_count + inst) * PARAM_N; + poly_getnoise_e_kg(dst_e, noiseseed, (uint8_t)(PARAM_K + i)); + } +} + +__global__ void 
batch_keygen_seed_expand_kernel( + uint8_t * __restrict__ d_publicseed, + uint8_t * __restrict__ d_noiseseed, + const uint8_t * __restrict__ d_coins, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + uint8_t buf[2 * PARAM_SYMBYTES]; + sha3_512(buf, d_coins + (size_t)inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); + for (int i = 0; i < PARAM_SYMBYTES; i++) { + d_publicseed[(size_t)inst * PARAM_SYMBYTES + i] = buf[i]; + d_noiseseed[(size_t)inst * PARAM_SYMBYTES + i] = buf[PARAM_SYMBYTES + i]; + } +} + +__global__ void batch_keygen_mat_sample_kernel( + int16_t * __restrict__ d_mat, + const uint8_t * __restrict__ d_publicseed, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * PARAM_K * PARAM_K; + if (idx >= total) return; + + int inst = idx % batch_count; + int p = idx / batch_count; + int row = p / PARAM_K; + int col = p % PARAM_K; + +#if ALGORITHM == ALGO_KYBER + uint8_t x = (uint8_t)col; + uint8_t y = (uint8_t)row; +#elif ALGORITHM == ALGO_AIGIS_ENC + uint8_t x = (uint8_t)row; + uint8_t y = (uint8_t)col; +#endif + + int16_t *dst = d_mat + ((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N; + const uint8_t *seed = d_publicseed + (size_t)inst * PARAM_SYMBYTES; + rej_uniform_xof(dst, seed, x, y); +} + +__global__ void batch_keygen_noise_sample_kernel( + int16_t * __restrict__ d_skpv, + int16_t * __restrict__ d_e, + const uint8_t * __restrict__ d_noiseseed, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * PARAM_K * 2; + if (idx >= total) return; + + int inst = idx % batch_count; + int q = idx / batch_count; + int poly = q % PARAM_K; + const uint8_t *seed = d_noiseseed + (size_t)inst * PARAM_SYMBYTES; + + if (q < PARAM_K) { + int16_t *dst = d_skpv + ((size_t)poly * batch_count + inst) * PARAM_N; + poly_getnoise_s(dst, seed, (uint8_t)poly); + } else { + int16_t *dst = d_e + ((size_t)poly * batch_count 
+ inst) * PARAM_N; + poly_getnoise_e_kg(dst, seed, (uint8_t)(PARAM_K + poly)); + } +} + +/* ================================================================ + * 批量打包 PK/SK kernel (每 block 处理一个实例) + * 在所有 NTT 和 matvec 计算完成后调用 + * + * 输入: + * d_mat — 矩阵 A (unused for packing, publicseed stored in d_coins) + * d_skpv — NTT 域 s (已 caddq) + * d_pkpv — b = A*s + e (已 caddq), 以 SoA 格式 + * 输出: + * d_pk_bytes — PK 字节流 + * d_sk_bytes — SK 字节流 (indcpa_sk || pk || H(pk) || z) + * ================================================================ */ + +__global__ void batch_pack_keypair_kernel( + uint8_t * __restrict__ d_pk_bytes, + uint8_t * __restrict__ d_sk_bytes, + const int16_t * __restrict__ d_skpv, + const int16_t * __restrict__ d_pkpv, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES: publicseed 在位置[inst*2*32] */ + int batch_count) +{ + int inst = blockIdx.x; + if (inst >= batch_count) return; + + /* 构建 kem_polyvec 结构 (从 SoA 还原为 AoS) */ + kem_polyvec skpv_local, pkpv_local; + for (int i = 0; i < PARAM_K; i++) + for (int c = 0; c < PARAM_N; c++) { + skpv_local.vec[i].coeffs[c] = d_skpv[((size_t)i * batch_count + inst) * PARAM_N + c]; + pkpv_local.vec[i].coeffs[c] = d_pkpv[((size_t)i * batch_count + inst) * PARAM_N + c]; + } + + /* 从 d_coins 取出 publicseed (keygen 时, sha3_512 已展开, publicseed = 前 32 字节) */ + /* 实际上我们在 warp 采样时已用 sha3_512 展开, 这里需要重新计算 publicseed */ + uint8_t seeds[2 * PARAM_SYMBYTES]; + sha3_512(seeds, d_coins + (size_t)inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); + const uint8_t *publicseed = seeds; + + /* PK = pk_poly_compress(pkpv) || publicseed */ + uint8_t *pk = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES; + pack_pk(pk, &pkpv_local, publicseed); + + /* SK = polyvec_tobytes(skpv) || pk || H(pk) || z */ + uint8_t *sk = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES; + pack_sk(sk, &skpv_local); + + /* sk[indcpa_sk_bytes:] = pk */ + for (int i = 0; i < (int)PARAM_PUBLICKEYBYTES; i++) + sk[PARAM_INDCPA_SECRETKEYBYTES + i] = pk[i]; + + /* H(pk) 
*/ + sha3_256(sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES, pk, PARAM_PUBLICKEYBYTES); + + /* z = coins[32:64] (第二个 32 字节作为随机 z) */ + const uint8_t *z_src = d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + PARAM_SYMBYTES; + uint8_t *z_dst = sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES + PARAM_SYMBYTES; + for (int i = 0; i < PARAM_SYMBYTES; i++) z_dst[i] = z_src[i]; +} + +#ifndef KEM_PACK_TPB +#define KEM_PACK_TPB 128 +#endif + +__global__ void batch_pack_sk_polyvec_kernel( + uint8_t * __restrict__ d_sk_bytes, + const int16_t * __restrict__ d_skpv, + int batch_count) +{ + int inst = blockIdx.x; + int poly = blockIdx.y; + int tid = threadIdx.x; + if (inst >= batch_count || poly >= PARAM_K) return; + + const int16_t *src = d_skpv + ((size_t)poly * batch_count + inst) * PARAM_N; + uint8_t *out = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES + + (size_t)poly * PARAM_POLYBYTES; + +#if ALGORITHM == ALGO_KYBER + for (int i = tid; i < PARAM_N / 2; i += blockDim.x) { + int16_t t0 = caddq(src[2 * i]); + int16_t t1 = caddq(src[2 * i + 1]); + out[3 * i + 0] = (uint8_t)t0; + out[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + out[3 * i + 2] = (uint8_t)(t1 >> 4); + } +#elif ALGORITHM == ALGO_AIGIS_ENC + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + int16_t t0 = caddq(src[8 * i + 0]); + int16_t t1 = caddq(src[8 * i + 1]); + int16_t t2 = caddq(src[8 * i + 2]); + int16_t t3 = caddq(src[8 * i + 3]); + int16_t t4 = caddq(src[8 * i + 4]); + int16_t t5 = caddq(src[8 * i + 5]); + int16_t t6 = caddq(src[8 * i + 6]); + int16_t t7 = caddq(src[8 * i + 7]); + out[13 * i + 0] = (uint8_t)t0; + out[13 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 5)); + out[13 * i + 2] = (uint8_t)(t1 >> 3); + out[13 * i + 3] = (uint8_t)((t1 >> 11) | (t2 << 2)); + out[13 * i + 4] = (uint8_t)((t2 >> 6) | (t3 << 7)); + out[13 * i + 5] = (uint8_t)(t3 >> 1); + out[13 * i + 6] = (uint8_t)((t3 >> 9) | (t4 << 4)); + out[13 * i + 7] = (uint8_t)(t4 >> 4); + out[13 * i + 8] = (uint8_t)((t4 >> 12) 
| (t5 << 1)); + out[13 * i + 9] = (uint8_t)((t5 >> 7) | (t6 << 6)); + out[13 * i + 10] = (uint8_t)(t6 >> 2); + out[13 * i + 11] = (uint8_t)((t6 >> 10) | (t7 << 3)); + out[13 * i + 12] = (uint8_t)(t7 >> 5); + } +#endif +} + +__global__ void batch_pack_pk_polyvec_kernel( + uint8_t * __restrict__ d_pk_bytes, + const int16_t * __restrict__ d_pkpv, + int batch_count) +{ + int inst = blockIdx.x; + int poly = blockIdx.y; + int tid = threadIdx.x; + if (inst >= batch_count || poly >= PARAM_K) return; + + const int16_t *src = d_pkpv + ((size_t)poly * batch_count + inst) * PARAM_N; + uint8_t *out = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES + + (size_t)poly * (PARAM_BITS_PK * PARAM_N / 8); + +#if ALGORITHM == ALGO_KYBER + for (int i = tid; i < PARAM_N / 2; i += blockDim.x) { + int16_t t0 = caddq(src[2 * i]); + int16_t t1 = caddq(src[2 * i + 1]); + out[3 * i + 0] = (uint8_t)t0; + out[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + out[3 * i + 2] = (uint8_t)(t1 >> 4); + } +#elif PARAM_BITS_PK == 9 + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[8*i+0]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c1 = (uint16_t)((((int32_t)caddq(src[8*i+1]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[8*i+2]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[8*i+3]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c4 = (uint16_t)((((int32_t)caddq(src[8*i+4]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c5 = (uint16_t)((((int32_t)caddq(src[8*i+5]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c6 = (uint16_t)((((int32_t)caddq(src[8*i+6]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c7 = (uint16_t)((((int32_t)caddq(src[8*i+7]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + out[9*i+0] = (uint8_t)c0; + out[9*i+1] = (uint8_t)((c0 >> 8) | (c1 << 1)); + out[9*i+2] = (uint8_t)((c1 >> 7) | (c2 << 2)); + out[9*i+3] = (uint8_t)((c2 >> 6) | (c3 << 3)); + 
out[9*i+4] = (uint8_t)((c3 >> 5) | (c4 << 4)); + out[9*i+5] = (uint8_t)((c4 >> 4) | (c5 << 5)); + out[9*i+6] = (uint8_t)((c5 >> 3) | (c6 << 6)); + out[9*i+7] = (uint8_t)((c6 >> 2) | (c7 << 7)); + out[9*i+8] = (uint8_t)(c7 >> 1); + } +#elif PARAM_BITS_PK == 10 + for (int i = tid; i < PARAM_N / 4; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[4*i+0]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c1 = (uint16_t)((((int32_t)caddq(src[4*i+1]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[4*i+2]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[4*i+3]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + out[5*i+0] = (uint8_t)c0; + out[5*i+1] = (uint8_t)((c0 >> 8) | (c1 << 2)); + out[5*i+2] = (uint8_t)((c1 >> 6) | (c2 << 4)); + out[5*i+3] = (uint8_t)((c2 >> 4) | (c3 << 6)); + out[5*i+4] = (uint8_t)(c3 >> 2); + } +#elif PARAM_BITS_PK == 11 + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[8*i+0]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c1 = (uint16_t)((((int32_t)caddq(src[8*i+1]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[8*i+2]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[8*i+3]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c4 = (uint16_t)((((int32_t)caddq(src[8*i+4]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c5 = (uint16_t)((((int32_t)caddq(src[8*i+5]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c6 = (uint16_t)((((int32_t)caddq(src[8*i+6]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c7 = (uint16_t)((((int32_t)caddq(src[8*i+7]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + out[11*i+ 0] = (uint8_t)c0; + out[11*i+ 1] = (uint8_t)((c0 >> 8) | (c1 << 3)); + out[11*i+ 2] = (uint8_t)((c1 >> 5) | (c2 << 6)); + out[11*i+ 3] = (uint8_t)(c2 >> 2); + out[11*i+ 4] = (uint8_t)((c2 >> 10) | (c3 << 1)); + out[11*i+ 
5] = (uint8_t)((c3 >> 7) | (c4 << 4)); + out[11*i+ 6] = (uint8_t)((c4 >> 4) | (c5 << 7)); + out[11*i+ 7] = (uint8_t)(c5 >> 1); + out[11*i+ 8] = (uint8_t)((c5 >> 9) | (c6 << 2)); + out[11*i+ 9] = (uint8_t)((c6 >> 6) | (c7 << 5)); + out[11*i+10] = (uint8_t)(c7 >> 3); + } +#endif +} + +__global__ void batch_pack_keypair_finalize_kernel( + uint8_t * __restrict__ d_pk_bytes, + uint8_t * __restrict__ d_sk_bytes, + const uint8_t * __restrict__ d_publicseed, + const uint8_t * __restrict__ d_coins, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + uint8_t *pk = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES; + uint8_t *sk = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES; + const uint8_t *rho = d_publicseed + (size_t)inst * PARAM_SYMBYTES; + + for (int i = 0; i < PARAM_SYMBYTES; i++) + pk[PARAM_PK_POLYVEC_BYTES + i] = rho[i]; + + for (int i = 0; i < (int)PARAM_PUBLICKEYBYTES; i++) + sk[PARAM_INDCPA_SECRETKEYBYTES + i] = pk[i]; + + sha3_256(sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES, + pk, PARAM_PUBLICKEYBYTES); + + const uint8_t *z_src = d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + PARAM_SYMBYTES; + uint8_t *z_dst = sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES + PARAM_SYMBYTES; + for (int i = 0; i < PARAM_SYMBYTES; i++) z_dst[i] = z_src[i]; +} + +/* ================================================================ + * 批量单实例 keygen kernel (完整流水线, 单线程 fallback) + * 用于 batch 较小时, 直接调用 kem_keypair 设备函数 + * ================================================================ */ +#ifndef KEM_KEYPAIR_LAUNCH_BOUNDS +#define KEM_KEYPAIR_LAUNCH_BOUNDS 1 +#endif + +#ifndef KEM_ENCAPS_LAUNCH_BOUNDS +#if ALGORITHM == ALGO_AIGIS_ENC +#define KEM_ENCAPS_LAUNCH_BOUNDS 1 +#else +#define KEM_ENCAPS_LAUNCH_BOUNDS 0 +#endif +#endif + +#ifndef KEM_DECAPS_LAUNCH_BOUNDS +#if ALGORITHM == ALGO_AIGIS_ENC +#define KEM_DECAPS_LAUNCH_BOUNDS 1 +#else +#define KEM_DECAPS_LAUNCH_BOUNDS 0 +#endif +#endif + +#if 
KEM_KEYPAIR_LAUNCH_BOUNDS +#define KEM_KEYPAIR_KERNEL_BOUNDS __launch_bounds__(KEM_KEYGEN_TPB, 1) +#else +#define KEM_KEYPAIR_KERNEL_BOUNDS +#endif + +#if KEM_ENCAPS_LAUNCH_BOUNDS +#define KEM_ENCAPS_KERNEL_BOUNDS __launch_bounds__(KEM_ENCAPS_TPB, 1) +#else +#define KEM_ENCAPS_KERNEL_BOUNDS +#endif + +#if KEM_DECAPS_LAUNCH_BOUNDS +#define KEM_DECAPS_KERNEL_BOUNDS __launch_bounds__(KEM_DECAPS_TPB, 1) +#else +#define KEM_DECAPS_KERNEL_BOUNDS +#endif + +__global__ KEM_KEYPAIR_KERNEL_BOUNDS void batch_kem_keypair_serial_kernel( + uint8_t * __restrict__ d_pk, + uint8_t * __restrict__ d_sk, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_keypair( + d_pk + (size_t)inst * PARAM_PUBLICKEYBYTES, + d_sk + (size_t)inst * PARAM_SECRETKEYBYTES, + d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + ); +} + +/* ================================================================ + * 批量单实例 encaps kernel + * ================================================================ */ +__global__ KEM_ENCAPS_KERNEL_BOUNDS void batch_kem_encaps_serial_kernel( + uint8_t * __restrict__ d_ct, + uint8_t * __restrict__ d_ss, + const uint8_t * __restrict__ d_pk, + const uint8_t * __restrict__ d_coins, /* B * SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_encaps( + d_ct + (size_t)inst * PARAM_CIPHERTEXTBYTES, + d_ss + (size_t)inst * PARAM_SSBYTES, + d_pk + (size_t)inst * PARAM_PUBLICKEYBYTES, + d_coins + (size_t)inst * PARAM_SYMBYTES + ); +} + +/* ================================================================ + * 批量单实例 decaps kernel + * ================================================================ */ +__global__ KEM_DECAPS_KERNEL_BOUNDS void batch_kem_decaps_serial_kernel( + uint8_t * __restrict__ d_ss, + const uint8_t * __restrict__ d_ct, + const uint8_t * __restrict__ d_sk, + int batch_count) +{ 
+ int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_decaps( + d_ss + (size_t)inst * PARAM_SSBYTES, + d_ct + (size_t)inst * PARAM_CIPHERTEXTBYTES, + d_sk + (size_t)inst * PARAM_SECRETKEYBYTES + ); +} + +/* ================================================================ + * 批量 KEM 高性能流水线 + * + * batch_keygen_pipelined: + * 1. Warp 采样 (矩阵 A + s + e) + * 2. 批量 NTT(s) + * 3. 2D grid 矩阵向量乘 (A*s → pkpv) + * 4. 批量 INVNTT(pkpv) + 加 e, caddq + * 5. 打包 pk/sk + * ================================================================ */ +static inline cudaError_t batch_keygen_pipelined( + uint8_t *d_pk_out, uint8_t *d_sk_out, + BatchKemBuffers *buf, + int batch_count, + cudaStream_t stream = 0) +{ + cudaError_t err; + + /* 生成随机种子 (device side — 在 host 侧用 cudaMemcpy 传入 d_coins_kg) */ + + /* Step 1: Warp 采样 */ + int blocks = (batch_count + WP_KG_WARPS_BLOCK - 1) / WP_KG_WARPS_BLOCK; +#if KEM_SPLIT_KEYGEN_SAMPLE + batch_keygen_seed_expand_kernel<<>>( + buf->d_publicseed_kg, buf->d_noiseseed_kg, buf->d_coins_kg, batch_count); + batch_keygen_mat_sample_kernel<<>>( + buf->d_mat, buf->d_publicseed_kg, batch_count); + batch_keygen_noise_sample_kernel<<>>( + buf->d_skpv, buf->d_e, buf->d_noiseseed_kg, batch_count); +#else + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_skpv, buf->d_e, + buf->d_publicseed_kg, buf->d_coins_kg, batch_count); +#endif + + /* Step 2: 批量 NTT(s) — d_skpv 中 K 个 poly 组 */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + batch_ntt_kernel<<>>(ptr, batch_count); + } + + /* Step 2b: caddq(s) */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + launch_batch_caddq(ptr, batch_count, stream); + } + + /* Step 3: 矩阵向量乘 A * s_hat → pkpv */ + launch_batch_matvec(buf->d_pkpv, buf->d_mat, buf->d_skpv, batch_count, stream); + + /* Step 4: INVNTT(pkpv) */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = 
buf->d_pkpv + (size_t)i * batch_count * PARAM_N; + batch_invntt_kernel<<>>(ptr, batch_count); + } + + /* pkpv += e */ + for (int i = 0; i < PARAM_K; i++) { + launch_batch_add( + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_e + (size_t)i * batch_count * PARAM_N, + batch_count, stream); + } + + /* caddq(pkpv) */ + for (int i = 0; i < PARAM_K; i++) { + launch_batch_caddq(buf->d_pkpv + (size_t)i * batch_count * PARAM_N, batch_count, stream); + } + + /* Step 5: 打包 PK/SK */ + dim3 pack_grid(batch_count, PARAM_K); + batch_pack_sk_polyvec_kernel<<>>( + d_sk_out, buf->d_skpv, batch_count); + batch_pack_pk_polyvec_kernel<<>>( + d_pk_out, buf->d_pkpv, batch_count); + batch_pack_keypair_finalize_kernel<<>>( + d_pk_out, d_sk_out, buf->d_publicseed_kg, buf->d_coins_kg, batch_count); + + err = cudaGetLastError(); + return err; +} + +/* ================================================================ + * 简化批量 encaps/decaps (串行 kernel, 可进一步并行化) + * ================================================================ */ +static inline cudaError_t batch_keygen_pipelined_profile( + uint8_t *d_pk_out, uint8_t *d_sk_out, + BatchKemBuffers *buf, + int batch_count, + cudaStream_t stream = 0) +{ + cudaEvent_t ev0, ev1, ev2, ev3, ev4, ev5, ev6; + cudaEventCreate(&ev0); cudaEventCreate(&ev1); cudaEventCreate(&ev2); + cudaEventCreate(&ev3); cudaEventCreate(&ev4); cudaEventCreate(&ev5); cudaEventCreate(&ev6); + + cudaEventRecord(ev0, stream); + int blocks = (batch_count + WP_KG_WARPS_BLOCK - 1) / WP_KG_WARPS_BLOCK; +#if KEM_SPLIT_KEYGEN_SAMPLE + batch_keygen_seed_expand_kernel<<>>( + buf->d_publicseed_kg, buf->d_noiseseed_kg, buf->d_coins_kg, batch_count); + batch_keygen_mat_sample_kernel<<>>( + buf->d_mat, buf->d_publicseed_kg, batch_count); + batch_keygen_noise_sample_kernel<<>>( + buf->d_skpv, buf->d_e, buf->d_noiseseed_kg, batch_count); +#else + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_skpv, buf->d_e, + 
buf->d_publicseed_kg, buf->d_coins_kg, batch_count); +#endif + cudaEventRecord(ev1, stream); + + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + batch_ntt_kernel<<>>(ptr, batch_count); + } + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + launch_batch_caddq(ptr, batch_count, stream); + } + cudaEventRecord(ev2, stream); + + launch_batch_matvec(buf->d_pkpv, buf->d_mat, buf->d_skpv, batch_count, stream); + cudaEventRecord(ev3, stream); + + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_pkpv + (size_t)i * batch_count * PARAM_N; + batch_invntt_kernel<<>>(ptr, batch_count); + } + cudaEventRecord(ev4, stream); + + for (int i = 0; i < PARAM_K; i++) { + launch_batch_add( + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_e + (size_t)i * batch_count * PARAM_N, + batch_count, stream); + } + for (int i = 0; i < PARAM_K; i++) + launch_batch_caddq(buf->d_pkpv + (size_t)i * batch_count * PARAM_N, batch_count, stream); + cudaEventRecord(ev5, stream); + + dim3 pack_grid(batch_count, PARAM_K); + batch_pack_sk_polyvec_kernel<<>>( + d_sk_out, buf->d_skpv, batch_count); + batch_pack_pk_polyvec_kernel<<>>( + d_pk_out, buf->d_pkpv, batch_count); + batch_pack_keypair_finalize_kernel<<>>( + d_pk_out, d_sk_out, buf->d_publicseed_kg, buf->d_coins_kg, batch_count); + cudaEventRecord(ev6, stream); + cudaEventSynchronize(ev6); + + float sample_ms, ntt_ms, matvec_ms, invntt_ms, add_ms, pack_ms, total_ms; + cudaEventElapsedTime(&sample_ms, ev0, ev1); + cudaEventElapsedTime(&ntt_ms, ev1, ev2); + cudaEventElapsedTime(&matvec_ms, ev2, ev3); + cudaEventElapsedTime(&invntt_ms, ev3, ev4); + cudaEventElapsedTime(&add_ms, ev4, ev5); + cudaEventElapsedTime(&pack_ms, ev5, ev6); + cudaEventElapsedTime(&total_ms, ev0, ev6); + printf(" Pipeline profile: sample=%.3f ntt=%.3f matvec=%.3f invntt=%.3f add=%.3f pack=%.3f total=%.3f 
ms\n", + sample_ms, ntt_ms, matvec_ms, invntt_ms, add_ms, pack_ms, total_ms); + + cudaEventDestroy(ev0); cudaEventDestroy(ev1); cudaEventDestroy(ev2); + cudaEventDestroy(ev3); cudaEventDestroy(ev4); cudaEventDestroy(ev5); cudaEventDestroy(ev6); + return cudaGetLastError(); +} + +static inline cudaError_t batch_encaps_serial( + uint8_t *d_ct, uint8_t *d_ss, + const uint8_t *d_pk, + BatchKemBuffers *buf, + int batch_count, + cudaStream_t stream = 0) +{ + int tpb = KEM_ENCAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_encaps_serial_kernel<<>>( + d_ct, d_ss, d_pk, buf->d_coins_enc, batch_count); + return cudaGetLastError(); +} + +static inline cudaError_t batch_decaps_serial( + uint8_t *d_ss, + const uint8_t *d_ct, const uint8_t *d_sk, + int batch_count, + cudaStream_t stream = 0) +{ + int tpb = KEM_DECAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_decaps_serial_kernel<<>>( + d_ss, d_ct, d_sk, batch_count); + return cudaGetLastError(); +} + +#endif /* BATCH_KEM_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ntt.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ntt.cuh new file mode 100644 index 000000000..8beeb744c --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ntt.cuh @@ -0,0 +1,481 @@ +/* + * batch_ntt.cuh — GPU 批量 NTT/INVNTT + * + * 每个 block 处理一个多项式 (128 threads/block) + * 使用共享内存避免全局内存延迟 + * + * Kyber: 7 级蝶形 (len=128→2),常数表 ntt_zetas[128] + * Aigis-enc: 8 级蝶形 (len=128→1),常数表 ntt_zetas[256] 和 ntt_zetas_inv[256] + * + * 批量 NTT 格式: + * polys[poly_idx * N + coeff_idx] — AoS 布局 (poly 连续, 每 poly N 个 int16_t) + * SoA: polys[poly_idx][batch_inst][coeff_idx] 由 batch_kem.cuh 处理 + * + * 优化: + * 共享内存 bank 填充: SP(i) = i + (i >> 5) 避免 bank 冲突 (32-bit bank) + * 注: int16_t 使用 SP(i) = i + (i >> 4) 可能更优, 但参考 mldsa 实现 + */ + +#ifndef BATCH_NTT_CUH +#define BATCH_NTT_CUH + +#include "rocm_compat.h" 
+#include +#include "params.h" +#include "reduce.cuh" +#include "ntt.cuh" + +/* bank 填充宏: 将 256 int16_t 扩展为 264 元素以避免 bank 冲突 + * 实际: 每 32 个 int16_t 后插入 1 个 padding → S[i + (i>>5)] */ +#define SP(i) ((i) + ((i) >> 5)) +#define SPAD (PARAM_N + (PARAM_N >> 5)) /* 264 */ + +/* ================================================================ + * 批量 NTT kernel + * polys: int16_t 数组, 每 poly N 个连续系数 + * batch_count: poly 个数 + * ================================================================ */ +__global__ void batch_ntt_kernel(int16_t * __restrict__ polys, int batch_count) +{ + int poly_idx = blockIdx.x; + if (poly_idx >= batch_count) return; + + int tid = (int)threadIdx.x; /* 0..127 */ + + __shared__ int16_t s[SPAD]; + + /* 加载 poly 到共享内存 */ + int16_t *base = polys + poly_idx * PARAM_N; + s[SP(tid)] = base[tid]; + s[SP(tid + 128)] = base[tid + 128]; + __syncthreads(); + +#if ALGORITHM == ALGO_KYBER + + /* Kyber: 7 级, len=128→2, zeta 索引从 1 开始 */ + /* Level 7: len=128, 1 group */ + { + int16_t zeta = ntt_zetas[1]; + int j = tid; /* 0..127 */ + int16_t t = fqmul(zeta, s[SP(j + 128)]); + s[SP(j + 128)] = s[SP(j)] - t; + s[SP(j)] = s[SP(j)] + t; + } + __syncthreads(); + + /* Level 6: len=64, 2 groups, zeta[2,3] */ + { + int group = tid >> 6; /* 0 or 1 */ + int lane = tid & 0x3F; /* 0..63 */ + int16_t zeta = ntt_zetas[2 + group]; + int base_idx = group * 128 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 64)]); + s[SP(base_idx + 64)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 5: len=32, 4 groups, zeta[4..7] */ + { + int group = tid >> 5; + int lane = tid & 0x1F; + int16_t zeta = ntt_zetas[4 + group]; + int base_idx = group * 64 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 32)]); + s[SP(base_idx + 32)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 4: len=16, 8 groups, zeta[8..15] */ + { + int group = tid >> 4; + int lane = tid & 0x0F; + int16_t zeta = 
ntt_zetas[8 + group]; + int base_idx = group * 32 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 16)]); + s[SP(base_idx + 16)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 3: len=8, 16 groups, zeta[16..31] */ + { + int group = tid >> 3; + int lane = tid & 0x07; + int16_t zeta = ntt_zetas[16 + group]; + int base_idx = group * 16 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 8)]); + s[SP(base_idx + 8)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 2: len=4, 32 groups, zeta[32..63] */ + { + int group = tid >> 2; + int lane = tid & 0x03; + int16_t zeta = ntt_zetas[32 + group]; + int base_idx = group * 8 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 4)]); + s[SP(base_idx + 4)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 1: len=2, 64 groups, zeta[64..127] */ + { + int group = tid >> 1; + int lane = tid & 0x01; + int16_t zeta = ntt_zetas[64 + group]; + int base_idx = group * 4 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 2)]); + s[SP(base_idx + 2)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + +#elif ALGORITHM == ALGO_AIGIS_ENC + + /* Aigis: 8 级, len=128→1, zeta 索引从 1 开始 */ + /* Level 7: len=128, 1 group, zeta[1] */ + { + int16_t zeta = ntt_zetas[1]; + int j = tid; + int16_t t = fqmul(zeta, s[SP(j + 128)]); + s[SP(j + 128)] = s[SP(j)] - t; + s[SP(j)] = s[SP(j)] + t; + } + __syncthreads(); + + /* Level 6: len=64, 2 groups */ + { + int group = tid >> 6; + int lane = tid & 0x3F; + int16_t zeta = ntt_zetas[2 + group]; + int base_idx = group * 128 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 64)]); + s[SP(base_idx + 64)] = barrett_reduce((int16_t)(s[SP(base_idx)] - t)); + s[SP(base_idx)] = barrett_reduce((int16_t)(s[SP(base_idx)] + t)); + } + __syncthreads(); + + /* Level 5: len=32, 4 groups */ + { + int group = tid >> 5; + int lane = tid & 0x1F; + 
int16_t zeta = ntt_zetas[4 + group]; + int base_idx = group * 64 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 32)]); + s[SP(base_idx + 32)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 4: len=16, 8 groups */ + { + int group = tid >> 4; + int lane = tid & 0x0F; + int16_t zeta = ntt_zetas[8 + group]; + int base_idx = group * 32 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 16)]); + s[SP(base_idx + 16)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 3: len=8, 16 groups */ + { + int group = tid >> 3; + int lane = tid & 0x07; + int16_t zeta = ntt_zetas[16 + group]; + int base_idx = group * 16 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 8)]); + s[SP(base_idx + 8)] = barrett_reduce((int16_t)(s[SP(base_idx)] - t)); + s[SP(base_idx)] = barrett_reduce((int16_t)(s[SP(base_idx)] + t)); + } + __syncthreads(); + + /* Level 2: len=4, 32 groups */ + { + int group = tid >> 2; + int lane = tid & 0x03; + int16_t zeta = ntt_zetas[32 + group]; + int base_idx = group * 8 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 4)]); + s[SP(base_idx + 4)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 1: len=2, 64 groups */ + { + int group = tid >> 1; + int lane = tid & 0x01; + int16_t zeta = ntt_zetas[64 + group]; + int base_idx = group * 4 + lane; + int16_t t = fqmul(zeta, s[SP(base_idx + 2)]); + s[SP(base_idx + 2)] = s[SP(base_idx)] - t; + s[SP(base_idx)] = s[SP(base_idx)] + t; + } + __syncthreads(); + + /* Level 0: len=1, 128 groups */ + { + int group = tid; + int16_t zeta = ntt_zetas[128 + group]; + int base_idx = group * 2; + int16_t t = fqmul(zeta, s[SP(base_idx + 1)]); + s[SP(base_idx + 1)] = barrett_reduce((int16_t)(s[SP(base_idx)] - t)); + s[SP(base_idx)] = barrett_reduce((int16_t)(s[SP(base_idx)] + t)); + } + __syncthreads(); + +#endif /* ALGORITHM for NTT levels */ + + /* 写回 */ + base[tid] = s[SP(tid)]; + 
base[tid + 128] = s[SP(tid + 128)]; +} + +/* ================================================================ + * 批量 INVNTT kernel + * ================================================================ */ +__global__ void batch_invntt_kernel(int16_t * __restrict__ polys, int batch_count) +{ + int poly_idx = blockIdx.x; + if (poly_idx >= batch_count) return; + + int tid = (int)threadIdx.x; + + __shared__ int16_t s[SPAD]; + + int16_t *base = polys + poly_idx * PARAM_N; + s[SP(tid)] = base[tid]; + s[SP(tid + 128)] = base[tid + 128]; + __syncthreads(); + +#if ALGORITHM == ALGO_KYBER + + /* Kyber INVNTT: 从 len=2 反向, 使用 +zetas[k--] */ + /* Level 1: len=2 → 64 groups, zeta[64..127] */ + { + int group = tid >> 1; + int lane = tid & 0x01; + int16_t zeta = ntt_zetas[64 + group]; + int base_idx = group * 4 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 2)])); + s[SP(base_idx + 2)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 2)] - t)); + } + __syncthreads(); + + /* Level 2: len=4 → 32 groups, zeta[32..63] */ + { + int group = tid >> 2; + int lane = tid & 0x03; + int16_t zeta = ntt_zetas[32 + group]; + int base_idx = group * 8 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 4)])); + s[SP(base_idx + 4)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 4)] - t)); + } + __syncthreads(); + + /* Level 3: len=8 → 16 groups, zeta[16..31] */ + { + int group = tid >> 3; + int lane = tid & 0x07; + int16_t zeta = ntt_zetas[16 + group]; + int base_idx = group * 16 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 8)])); + s[SP(base_idx + 8)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 8)] - t)); + } + __syncthreads(); + + /* Level 4: len=16 → 8 groups, zeta[8..15] */ + { + int group = tid >> 4; + int lane = tid & 0x0F; + int16_t zeta = ntt_zetas[8 + group]; + int base_idx = group * 32 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = 
barrett_reduce((int16_t)(t + s[SP(base_idx + 16)])); + s[SP(base_idx + 16)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 16)] - t)); + } + __syncthreads(); + + /* Level 5: len=32 → 4 groups, zeta[4..7] */ + { + int group = tid >> 5; + int lane = tid & 0x1F; + int16_t zeta = ntt_zetas[4 + group]; + int base_idx = group * 64 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 32)])); + s[SP(base_idx + 32)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 32)] - t)); + } + __syncthreads(); + + /* Level 6: len=64 → 2 groups, zeta[2,3] */ + { + int group = tid >> 6; + int lane = tid & 0x3F; + int16_t zeta = ntt_zetas[2 + group]; + int base_idx = group * 128 + lane; + int16_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 64)])); + s[SP(base_idx + 64)] = fqmul(zeta, (int16_t)(s[SP(base_idx + 64)] - t)); + } + __syncthreads(); + + /* Level 7: len=128 → 1 group, zeta[1] */ + { + int16_t zeta = ntt_zetas[1]; + int j = tid; + int16_t t = s[SP(j)]; + s[SP(j)] = barrett_reduce((int16_t)(t + s[SP(j + 128)])); + s[SP(j + 128)] = fqmul(zeta, (int16_t)(s[SP(j + 128)] - t)); + } + __syncthreads(); + + /* 归一化 f = 1441 */ + { + const int16_t f = 1441; + s[SP(tid)] = fqmul(s[SP(tid)], f); + s[SP(tid + 128)] = fqmul(s[SP(tid + 128)], f); + } + __syncthreads(); + +#elif ALGORITHM == ALGO_AIGIS_ENC + + /* Aigis INVNTT: 从 len=1 开始, 使用 ntt_zetas_inv */ + /* Level 0: len=1, 128 groups */ + { + int group = tid; + int32_t zeta = ntt_zetas_inv[group]; + int base_idx = group * 2; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = (int16_t)(t + s[SP(base_idx + 1)]); + t -= s[SP(base_idx + 1)]; + s[SP(base_idx + 1)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 1: len=2, 64 groups, Barrett */ + { + int group = tid >> 1; + int lane = tid & 0x01; + int32_t zeta = ntt_zetas_inv[128 + group]; + int base_idx = group * 4 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = 
barrett_reduce((int16_t)(t + s[SP(base_idx + 2)])); + t -= s[SP(base_idx + 2)]; + s[SP(base_idx + 2)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 2: len=4, 32 groups */ + { + int group = tid >> 2; + int lane = tid & 0x03; + int32_t zeta = ntt_zetas_inv[192 + group]; + int base_idx = group * 8 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = (int16_t)(t + s[SP(base_idx + 4)]); + t -= s[SP(base_idx + 4)]; + s[SP(base_idx + 4)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 3: len=8, 16 groups, Barrett */ + { + int group = tid >> 3; + int lane = tid & 0x07; + int32_t zeta = ntt_zetas_inv[224 + group]; + int base_idx = group * 16 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 8)])); + t -= s[SP(base_idx + 8)]; + s[SP(base_idx + 8)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 4: len=16, 8 groups */ + { + int group = tid >> 4; + int lane = tid & 0x0F; + int32_t zeta = ntt_zetas_inv[240 + group]; + int base_idx = group * 32 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = (int16_t)(t + s[SP(base_idx + 16)]); + t -= s[SP(base_idx + 16)]; + s[SP(base_idx + 16)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 5: len=32, 4 groups, Barrett */ + { + int group = tid >> 5; + int lane = tid & 0x1F; + int32_t zeta = ntt_zetas_inv[248 + group]; + int base_idx = group * 64 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = barrett_reduce((int16_t)(t + s[SP(base_idx + 32)])); + t -= s[SP(base_idx + 32)]; + s[SP(base_idx + 32)] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + __syncthreads(); + + /* Level 6: len=64, 2 groups */ + { + int group = tid >> 6; + int lane = tid & 0x3F; + int32_t zeta = ntt_zetas_inv[252 + group]; + int base_idx = group * 128 + lane; + int32_t t = s[SP(base_idx)]; + s[SP(base_idx)] = (int16_t)(t + s[SP(base_idx + 
64)]);
+        t -= s[SP(base_idx + 64)];
+        s[SP(base_idx + 64)] = montgomery_reduce((int32_t)zeta * (int16_t)t);
+    }
+    __syncthreads();
+
+    /* Level 7: len=128, 1 group; also applies the N^{-1} normalisation */
+    {
+        int32_t zeta = ntt_zetas_inv[254];
+        int j = tid;
+        int32_t t = s[SP(j)];
+        /* r[j] = (r[j] + r[j+128]) * N^{-1} mod Q */
+        s[SP(j)] = montgomery_reduce(256 * (t + s[SP(j + 128)]));
+        t -= s[SP(j + 128)];
+        s[SP(j + 128)] = montgomery_reduce((int32_t)zeta * (int16_t)t);
+    }
+    __syncthreads();
+
+#endif /* ALGORITHM for INVNTT */
+
+    /* Write the two coefficients handled by this thread back from s to base. */
+    base[tid] = s[SP(tid)];
+    base[tid + 128] = s[SP(tid + 128)];
+}
+
+/* ================================================================
+ * Host launch wrappers
+ *
+ * Both wrappers forward d_polys and batch_count to the kernel on the
+ * given stream (default stream when omitted).
+ * NOTE(review): the launch configuration between the chevrons is
+ * missing in this copy of the patch; restore it from the upstream file.
+ * ================================================================ */
+
+static inline void launch_batch_ntt(int16_t *d_polys, int batch_count,
+                                    cudaStream_t stream = 0)
+{
+    batch_ntt_kernel<<>>(d_polys, batch_count);
+}
+
+static inline void launch_batch_invntt(int16_t *d_polys, int batch_count,
+                                       cudaStream_t stream = 0)
+{
+    batch_invntt_kernel<<>>(d_polys, batch_count);
+}
+
+#undef SP
+#undef SPAD
+
+#endif /* BATCH_NTT_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ops.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ops.cuh
new file mode 100644
index 000000000..2e8316cc9
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/batch_ops.cuh
@@ -0,0 +1,243 @@
+/*
+ * batch_ops.cuh — batched polynomial arithmetic kernels
+ *
+ * 256 threads/block, one coefficient per thread
+ * SoA memory layout: data[poly_idx * batch_count * N + inst * N + coeff]
+ *
+ * Kernels provided:
+ *   batch_poly_add_kernel        — coefficient-wise addition
+ *   batch_poly_sub_kernel        — coefficient-wise subtraction
+ *   batch_poly_reduce_kernel     — Barrett reduction
+ *   batch_poly_caddq_kernel      — conditional add of Q, normalising to [0,Q)
+ *   batch_poly_caddq2_kernel     — double caddq normalisation
+ *   batch_polyvec_matvec_kernel  — matrix-vector product (basemul or pointwise multiply)
+ *   batch_poly_frommsg_kernel    — message bits -> polynomial
+ *   batch_poly_tomsg_kernel      — polynomial -> message bits
+ */
+
+#ifndef BATCH_OPS_CUH
+#define BATCH_OPS_CUH
+
+#include "rocm_compat.h" 
+#include <stdint.h>
+#include "params.h"
+#include "reduce.cuh"
+#include "ntt.cuh"
+
+#define BATCH_TPB 256 /* threads per block for batch poly ops */
+
+/* ================================================================
+ * Element-wise coefficient operations
+ * Layout: arrays[inst_idx * N + coeff_idx] (one polynomial per instance)
+ *
+ * One thread per coefficient. The flat index and the bound are computed
+ * in 64 bits so that batch_count * PARAM_N cannot overflow int for very
+ * large batches; a non-positive batch_count makes every thread a no-op.
+ * ================================================================ */
+
+/* c[i] = a[i] + b[i] for every coefficient (no modular reduction). */
+__global__ void batch_poly_add_kernel(
+    int16_t * __restrict__ c,        /* output */
+    const int16_t * __restrict__ a,  /* input 1 */
+    const int16_t * __restrict__ b,  /* input 2 */
+    int batch_count)
+{
+    long long idx = (long long)blockIdx.x * BATCH_TPB + threadIdx.x;
+    if (idx < (long long)batch_count * PARAM_N)
+        c[idx] = a[idx] + b[idx];
+}
+
+/* c[i] = a[i] - b[i] for every coefficient (no modular reduction). */
+__global__ void batch_poly_sub_kernel(
+    int16_t * __restrict__ c,
+    const int16_t * __restrict__ a,
+    const int16_t * __restrict__ b,
+    int batch_count)
+{
+    long long idx = (long long)blockIdx.x * BATCH_TPB + threadIdx.x;
+    if (idx < (long long)batch_count * PARAM_N)
+        c[idx] = a[idx] - b[idx];
+}
+
+/* In-place barrett_reduce() of every coefficient. */
+__global__ void batch_poly_reduce_kernel(
+    int16_t * __restrict__ r,
+    int batch_count)
+{
+    long long idx = (long long)blockIdx.x * BATCH_TPB + threadIdx.x;
+    if (idx < (long long)batch_count * PARAM_N)
+        r[idx] = barrett_reduce(r[idx]);
+}
+
+/* In-place caddq() of every coefficient. */
+__global__ void batch_poly_caddq_kernel(
+    int16_t * __restrict__ r,
+    int batch_count)
+{
+    long long idx = (long long)blockIdx.x * BATCH_TPB + threadIdx.x;
+    if (idx < (long long)batch_count * PARAM_N)
+        r[idx] = caddq(r[idx]);
+}
+
+/* In-place caddq2() of every coefficient. */
+__global__ void batch_poly_caddq2_kernel(
+    int16_t * __restrict__ r,
+    int batch_count)
+{
+    long long idx = (long long)blockIdx.x * BATCH_TPB + threadIdx.x;
+    if (idx < (long long)batch_count * PARAM_N)
+        r[idx] = caddq2(r[idx]);
+}
+
+/* ================================================================
+ * Batched matrix-vector product (2D grid: gridDim.x=batch_count, gridDim.y=K)
+ *
+ * SoA layout:
+ *   mat[row * K * batch_count * N + col * batch_count * N + inst * N + c]
+ *   vec[col * batch_count * N + inst * N + c]
+ *   out[row * batch_count * N + inst * N + c]
+ *
+ * Each block produces the output polynomial of one (instance, row) pair;
+ * threadIdx.x is the coefficient index (0..N-1), so the kernel must be
+ * launched with exactly PARAM_N threads per block (there is no bound
+ * check on threadIdx.x).
+ *
+ * Inner product:
+ *   Kyber:     basemul (4 coeffs at a time with ±zeta)
+ *   Aigis-enc: pointwise multiply, one Montgomery reduction at the end
+ * ================================================================ */
+
+__global__ void batch_polyvec_matvec_kernel(
+    int16_t * __restrict__ d_out,        /* K * B * N */
+    const int16_t * __restrict__ d_mat,  /* K * K * B * N, SoA */
+    const int16_t * __restrict__ d_vec,  /* K * B * N, SoA */
+    int batch_count)
+{
+    int inst = blockIdx.x;   /* batch instance index */
+    int row  = blockIdx.y;   /* output row (0..K-1) */
+    int c    = threadIdx.x;  /* coefficient index (0..N-1) */
+
+    if (inst >= batch_count) return;
+
+    /* Offset of the output polynomial; row is widened before the multiply. */
+    size_t out_base = ((size_t)row * batch_count + inst) * PARAM_N;
+
+#if ALGORITHM == ALGO_KYBER
+
+    /* Kyber inner product in the basemul domain.
+     * Coefficients are handled in 64 groups of four [4i,4i+1,4i+2,4i+3]:
+     *   pair0: indices [4i,  4i+1] in Z_q[x]/(x^2 - zeta[64+i])
+     *   pair1: indices [4i+2,4i+3] in Z_q[x]/(x^2 + zeta[64+i])
+     * basemul: r[0] = fqmul(fqmul(a1,b1),zeta) + fqmul(a0,b0)
+     *          r[1] = fqmul(a0,b1) + fqmul(a1,b0) */
+
+    int quad   = c >> 2;  /* group i = c/4, 0..63 */
+    int local  = c & 3;   /* 0,1,2,3 within the group */
+    int c_even = c & ~1;  /* even index of this thread's pair */
+    int c_odd  = c | 1;   /* odd index of this thread's pair */
+    /* zeta: +zeta[64+quad] for local=0,1; -zeta[64+quad] for local=2,3 */
+    int16_t zeta_raw = ntt_zetas[64 + quad];
+    int16_t zeta = (local < 2) ? zeta_raw : (int16_t)(-zeta_raw);
+
+    /* acc stays int16_t: each column adds two fqmul results. */
+    int16_t acc = 0;
+    for (int col = 0; col < PARAM_K; col++) {
+        size_t base_m = ((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N;
+        size_t base_v = ((size_t)col * batch_count + inst) * PARAM_N;
+        int16_t a0 = d_mat[base_m + c_even];
+        int16_t a1 = d_mat[base_m + c_odd];
+        int16_t b0 = d_vec[base_v + c_even];
+        int16_t b1 = d_vec[base_v + c_odd];
+
+        if (local & 1) {
+            /* r[c_odd] = a0*b1 + a1*b0 */
+            acc = (int16_t)(acc + fqmul(a0, b1) + fqmul(a1, b0));
+        } else {
+            /* r[c_even] = fqmul(fqmul(a1,b1), zeta) + fqmul(a0,b0) */
+            acc = (int16_t)(acc + fqmul(fqmul(a1, b1), zeta) + fqmul(a0, b0));
+        }
+    }
+    d_out[out_base + c] = barrett_reduce(acc);
+
+#elif ALGORITHM == ALGO_AIGIS_ENC
+
+    /* Aigis: pointwise multiply, accumulate K products, then one
+     * Montgomery reduction.
+     * Worst case: K * Q^2 = 4 * 7681^2 ≈ 236M < 2^31 (fits int32_t) */
+    int32_t acc = 0;
+    for (int col = 0; col < PARAM_K; col++) {
+        int16_t av = d_mat[((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N + c];
+        int16_t bv = d_vec[((size_t)col * batch_count + inst) * PARAM_N + c];
+        acc += (int32_t)av * bv;
+    }
+    d_out[out_base + c] = montgomery_reduce(acc);
+
+#endif
+}
+
+/* ================================================================
+ * Batched frommsg / tomsg kernels (layout: [inst * N + coeff])
+ * One block per instance, one thread per coefficient (0..N-1).
+ * ================================================================ */
+
+/* Expand each message bit into a coefficient: 0 or (Q+1)/2. */
+__global__ void batch_poly_frommsg_kernel(
+    int16_t * __restrict__ d_poly,        /* B * N */
+    const uint8_t * __restrict__ d_msgs,  /* B * N/8 */
+    int batch_count)
+{
+    int inst = blockIdx.x;
+    int c = threadIdx.x; /* 0..N-1 */
+    if (inst >= batch_count) return;
+
+    int byte_idx = c >> 3;
+    int bit_idx = c & 7;
+    uint8_t bit = (d_msgs[(size_t)inst * (PARAM_N / 8) + byte_idx] >> bit_idx) & 1;
+    int16_t mask = -(int16_t)bit; /* 0x0000 or 0xFFFF */
+    d_poly[(size_t)inst * PARAM_N + c] = mask & (int16_t)((PARAM_Q + 1) / 2);
+}
+
+/* Compress each coefficient to one bit and OR it into the message buffer.
+ *
+ * Preconditions (bits are only ever OR-ed in, through 32-bit atomics):
+ *   - d_msgs must be zeroed before the launch;
+ *   - d_msgs must be 4-byte aligned.
+ * NOTE(review): the shift places the bit in byte (byte_idx & 3) of the
+ * word, which assumes little-endian device memory — confirm. */
+__global__ void batch_poly_tomsg_kernel(
+    uint8_t * __restrict__ d_msgs,        /* B * N/8 */
+    const int16_t * __restrict__ d_poly,  /* B * N */
+    int batch_count)
+{
+    int inst = blockIdx.x;
+    int c = threadIdx.x; /* 0..N-1 */
+    if (inst >= batch_count) return;
+
+    int16_t t = d_poly[(size_t)inst * PARAM_N + c];
+    t = caddq(t);
+    /* bit = ((2*t + Q/2 + 1) / Q) mod 2 */
+    uint8_t bit = (uint8_t)(((((int32_t)t << 1) + PARAM_Q / 2 + 1) / PARAM_Q) & 1);
+
+    /* atomically OR the bit into the 32-bit word containing its byte */
+    int byte_idx = c >> 3;
+    int bit_idx = c & 7;
+    atomicOr((unsigned int *)(d_msgs + (size_t)inst * (PARAM_N / 8) + (byte_idx & ~3)),
+             (unsigned int)((unsigned int)bit << ((byte_idx & 3) * 8 + bit_idx)));
+}
+
+/* ================================================================
+ * Host launch wrappers
+ *
+ * NOTE(review): the chevron contents were missing in this copy of the
+ * patch; the configurations below are reconstructed from the indexing
+ * each kernel performs (BATCH_TPB threads per block over
+ * batch_count * PARAM_N coefficients; a (batch_count, K) grid of
+ * PARAM_N threads for the mat-vec kernel). Confirm against upstream.
+ * Every wrapper returns without launching when batch_count <= 0, which
+ * would otherwise request an empty grid.
+ * ================================================================ */
+
+static inline int ceil_div(int a, int b) { return (a + b - 1) / b; }
+
+/* Blocks of BATCH_TPB threads needed for batch_count * PARAM_N
+ * coefficients, computed in 64 bits to avoid int overflow. */
+static inline unsigned int batch_elem_blocks(int batch_count)
+{
+    return (unsigned int)(((long long)batch_count * PARAM_N + BATCH_TPB - 1) / BATCH_TPB);
+}
+
+static inline void launch_batch_add(int16_t *d_c, const int16_t *d_a, const int16_t *d_b,
+                                    int batch_count, cudaStream_t stream = 0) {
+    if (batch_count <= 0) return;
+    batch_poly_add_kernel<<<batch_elem_blocks(batch_count), BATCH_TPB, 0, stream>>>(
+        d_c, d_a, d_b, batch_count);
+}
+
+static inline void launch_batch_sub(int16_t *d_c, const int16_t *d_a, const int16_t *d_b,
+                                    int batch_count, cudaStream_t stream = 0) {
+    if (batch_count <= 0) return;
+    batch_poly_sub_kernel<<<batch_elem_blocks(batch_count), BATCH_TPB, 0, stream>>>(
+        d_c, d_a, d_b, batch_count);
+}
+
+static inline void launch_batch_reduce(int16_t *d_r, int batch_count, cudaStream_t stream = 0) {
+    if (batch_count <= 0) return;
+    batch_poly_reduce_kernel<<<batch_elem_blocks(batch_count), BATCH_TPB, 0, stream>>>(
+        d_r, batch_count);
+}
+
+static inline void launch_batch_caddq(int16_t *d_r, int batch_count, cudaStream_t stream = 0) {
+    if (batch_count <= 0) return;
+    batch_poly_caddq_kernel<<<batch_elem_blocks(batch_count), BATCH_TPB, 0, stream>>>(
+        d_r, batch_count);
+}
+
+static inline void launch_batch_caddq2(int16_t *d_r, int batch_count, cudaStream_t stream = 0) {
+    if (batch_count <= 0) return;
+    batch_poly_caddq2_kernel<<<batch_elem_blocks(batch_count), BATCH_TPB, 0, stream>>>(
+        d_r, batch_count);
+}
+
+static inline void launch_batch_matvec(
+    int16_t *d_out, const int16_t *d_mat, const int16_t *d_vec,
+    int batch_count, cudaStream_t stream = 0)
+{
+    if (batch_count <= 0) return;
+    /* 2D grid: (batch_count, K); one thread per coefficient */
+    dim3 grid(batch_count, PARAM_K);
+    batch_polyvec_matvec_kernel<<<grid, PARAM_N, 0, stream>>>(
+        d_out, d_mat, d_vec, batch_count);
+}
+
+#endif /* 
BATCH_OPS_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_all.sh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_all.sh
new file mode 100644
index 000000000..16e00b897
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_all.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+# Build the CUDA (nvcc) KEM binaries, one per algorithm/parameter set.
+# Usage: ./build_all.sh [target]   (e.g. kyber768; omit to build everything)
+set -euo pipefail
+
+# Resolve main.cu and the outputs relative to this script, like build_hip.sh.
+cd "$(dirname "$0")"
+
+NVCC="${NVCC:-nvcc}"
+SRC="main.cu"
+FILTER="${1:-}"
+
+# RTX 4090 = Ada Lovelace / sm_89. Override with:
+#   CUDA_ARCH=sm_86 ./build_all.sh
+CUDA_ARCH="${CUDA_ARCH:-sm_89}"
+KEM_SERIAL_TPB="${KEM_SERIAL_TPB:-64}"
+FLAGS=(-O3 -std=c++14 --expt-relaxed-constexpr -DKEM_SERIAL_TPB="${KEM_SERIAL_TPB}")
+ARCH=(-arch="${CUDA_ARCH}")
+
+if ! command -v "${NVCC}" >/dev/null 2>&1; then
+    echo "[error] nvcc not found in PATH" >&2
+    exit 1
+fi
+
+# "<binary name> <ALGORITHM> <PARAM_MODE>"
+targets=(
+    "kyber512 1 2"
+    "kyber768 1 3"
+    "kyber1024 1 4"
+    "aigisenc1 2 1"
+    "aigisenc2 2 2"
+    "aigisenc3 2 3"
+    "aigisenc4 2 4"
+)
+
+built=0
+for entry in "${targets[@]}"; do
+    read -r name alg mode <<<"${entry}"
+    if [[ -n "${FILTER}" && "${FILTER}" != "${name}" ]]; then
+        continue
+    fi
+
+    echo "[build] ${name} (ALGORITHM=${alg} PARAM_MODE=${mode} ARCH=${CUDA_ARCH} KEM_SERIAL_TPB=${KEM_SERIAL_TPB})"
+    "${NVCC}" "${FLAGS[@]}" "${ARCH[@]}" \
+        -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" \
+        -o "${name}" "${SRC}"
+    echo "[ok] ${name}"
+    built=$((built + 1))
+done
+
+# A filter that matches nothing (e.g. a typo) must not report success.
+if [[ "${built}" -eq 0 ]]; then
+    echo "[error] unknown target '${FILTER}'" >&2
+    exit 1
+fi
+
+echo
+echo "build complete"
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_hip.sh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_hip.sh
new file mode 100644
index 000000000..83495dc8a
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/build_hip.sh
@@ -0,0 +1,102 @@
+#!/usr/bin/env bash
+# Build the HIP (hipcc) KEM binaries, one per algorithm/parameter set.
+# Usage: ./build_hip.sh [target]   (e.g. kyber768; omit to build everything)
+# Every tunable below can be overridden from the environment.
+set -euo pipefail
+
+cd "$(dirname "$0")"
+
+export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"
+
+HIPCC=${HIPCC:-hipcc}
+ROCM_ARCH=${ROCM_ARCH:-gfx1100}
+KEM_SERIAL_TPB=${KEM_SERIAL_TPB:-64}
+KEM_KEYGEN_TPB=${KEM_KEYGEN_TPB:-${KEM_SERIAL_TPB}}
+KEM_ENCAPS_TPB=${KEM_ENCAPS_TPB:-${KEM_SERIAL_TPB}}
+KEM_DECAPS_TPB=${KEM_DECAPS_TPB:-${KEM_SERIAL_TPB}}
+KEM_KEYPAIR_LAUNCH_BOUNDS=${KEM_KEYPAIR_LAUNCH_BOUNDS:-1}
+KEM_ENCAPS_LAUNCH_BOUNDS=${KEM_ENCAPS_LAUNCH_BOUNDS:-}
+KEM_DECAPS_LAUNCH_BOUNDS=${KEM_DECAPS_LAUNCH_BOUNDS:-}
+WP_KG_WARPS_BLOCK=${WP_KG_WARPS_BLOCK:-4}
+KEM_PACK_TPB=${KEM_PACK_TPB:-128}
+BUILD_TYPE=${BUILD_TYPE:-Release}
+CXX_STD=${CXX_STD:-c++17}
+ROCM_WAVE32_FLAG=${ROCM_WAVE32_FLAG:-}
+OPT_LEVEL=${OPT_LEVEL:-}
+EXTRA_HIPCC_FLAGS=${EXTRA_HIPCC_FLAGS:-}
+
+# Debug: -O0 -g. Otherwise -${OPT_LEVEL} when set, else -O2.
+if [[ "${BUILD_TYPE}" == "Debug" ]]; then
+    OPT_FLAGS=(-O0 -g)
+else
+    if [[ -n "${OPT_LEVEL}" ]]; then
+        OPT_FLAGS=("-${OPT_LEVEL}")
+    else
+        OPT_FLAGS=(-O2)
+    fi
+fi
+
+COMMON_FLAGS=(
+    "${OPT_FLAGS[@]}"
+    -std="${CXX_STD}"
+    -x
+    hip
+    --offload-arch="${ROCM_ARCH}"
+    -DKEM_SERIAL_TPB="${KEM_SERIAL_TPB}"
+    -DKEM_KEYGEN_TPB="${KEM_KEYGEN_TPB}"
+    -DKEM_ENCAPS_TPB="${KEM_ENCAPS_TPB}"
+    -DKEM_DECAPS_TPB="${KEM_DECAPS_TPB}"
+    -DKEM_KEYPAIR_LAUNCH_BOUNDS="${KEM_KEYPAIR_LAUNCH_BOUNDS}"
+    -DWP_KG_WARPS_BLOCK="${WP_KG_WARPS_BLOCK}"
+    -DKEM_PACK_TPB="${KEM_PACK_TPB}"
+)
+
+# Optional defines/flags are only passed when explicitly set.
+if [[ -n "${KEM_ENCAPS_LAUNCH_BOUNDS}" ]]; then
+    COMMON_FLAGS+=(-DKEM_ENCAPS_LAUNCH_BOUNDS="${KEM_ENCAPS_LAUNCH_BOUNDS}")
+fi
+
+if [[ -n "${KEM_DECAPS_LAUNCH_BOUNDS}" ]]; then
+    COMMON_FLAGS+=(-DKEM_DECAPS_LAUNCH_BOUNDS="${KEM_DECAPS_LAUNCH_BOUNDS}")
+fi
+
+if [[ -n "${ROCM_WAVE32_FLAG}" ]]; then
+    COMMON_FLAGS+=("${ROCM_WAVE32_FLAG}")
+fi
+
+if [[ -n "${EXTRA_HIPCC_FLAGS}" ]]; then
+    # Intentionally word-split so several flags can be passed in one variable.
+    # shellcheck disable=SC2206
+    EXTRA_FLAGS_ARRAY=(${EXTRA_HIPCC_FLAGS})
+    COMMON_FLAGS+=("${EXTRA_FLAGS_ARRAY[@]}")
+fi
+
+# "<binary name>:<ALGORITHM>:<PARAM_MODE>"
+declare -a TARGETS=(
+    "kyber512:1:2"
+    "kyber768:1:3"
+    "kyber1024:1:4"
+    "aigisenc1:2:1"
+    "aigisenc2:2:2"
+    "aigisenc3:2:3"
+    "aigisenc4:2:4"
+)
+
+FILTER=${1:-}
+
+if ! command -v "${HIPCC}" >/dev/null 2>&1; then
+    echo "[错误] 未找到 hipcc,请先安装 ROCm 并把 hipcc 加入 PATH" >&2
+    exit 1
+fi
+
+mkdir -p amd_results/build
+
+built=0
+for spec in "${TARGETS[@]}"; do
+    IFS=':' read -r name alg mode <<<"${spec}"
+    if [[ -n "${FILTER}" && "${name}" != "${FILTER}" ]]; then
+        continue
+    fi
+
+    out="${name}_amd"
+    echo "[build] ${out} (ALGORITHM=${alg} PARAM_MODE=${mode}, arch=${ROCM_ARCH}, opt=${OPT_FLAGS[*]}, KEM_TPB=${KEM_KEYGEN_TPB}/${KEM_ENCAPS_TPB}/${KEM_DECAPS_TPB}, bounds=${KEM_KEYPAIR_LAUNCH_BOUNDS}/${KEM_ENCAPS_LAUNCH_BOUNDS:-default}/${KEM_DECAPS_LAUNCH_BOUNDS:-default}, wpkg=${WP_KG_WARPS_BLOCK}, pack=${KEM_PACK_TPB})"
+    # pipefail (set above) makes a hipcc failure abort the script despite tee.
+    "${HIPCC}" "${COMMON_FLAGS[@]}" \
+        -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" \
+        -o "${out}" main.cu \
+        2>&1 | tee "amd_results/build/${out}.log"
+    built=$((built + 1))
+done
+
+# A filter that matches nothing (e.g. a typo) must not report success.
+if [[ "${built}" -eq 0 ]]; then
+    echo "[error] unknown target '${FILTER}'" >&2
+    exit 1
+fi
+
+echo
+echo "HIP 构建完成"
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/cbd.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/cbd.cuh
new file mode 100644
index 000000000..f658bca51
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/cbd.cuh
@@ -0,0 +1,352 @@
+/*
+ * cbd.cuh — centered binomial distribution (CBD) noise sampling
+ *
+ * Supports eta = 1, 2, 3, 4, 8 (shared by Kyber and Aigis-enc)
+ *
+ * eta values used by each algorithm:
+ *   Kyber:     ETA1 = 2 (Kyber768/1024) or 3 (Kyber512); ETA2 = 2
+ *   Aigis-enc: ETA_S=1/2/3/4, ETA_E_KG=4/4/4/8, ETA_E_ENC=4/4/4/8, ETA_E2
+ *
+ * Method:
+ *   each coefficient is sum_{i=0}^{eta-1} (bit_a_i - bit_b_i),
+ *   where a and b are two independent groups of random bits
+ */
+
+#ifndef CBD_CUH
+#define CBD_CUH
+
+/* NOTE(review): the header name was missing in this copy of the patch;
+ * <stdint.h> is restored for the fixed-width types used below — confirm. */
+#include <stdint.h>
+#include "params.h"
+#include "fips202.cuh"
+
+/* ================================================================
+ * Low-level CBD functions (one per eta)
+ * ================================================================ */
+
+/* eta=1: each byte yields 4 coefficients (2 bits each) */
+static __device__ void cbd1(int16_t *r, const uint8_t *buf, unsigned int len)
+{
+    unsigned int pos = 0, i;
+    for (i = 0; i + 3 < len * 4 && pos < (unsigned)PARAM_N; i += 4) {
+ 
uint8_t b = buf[i / 4];
+        r[pos++] = (int16_t)(((b >> 0) & 1) - ((b >> 1) & 1));
+        r[pos++] = (int16_t)(((b >> 2) & 1) - ((b >> 3) & 1));
+        r[pos++] = (int16_t)(((b >> 4) & 1) - ((b >> 5) & 1));
+        r[pos++] = (int16_t)(((b >> 6) & 1) - ((b >> 7) & 1));
+    }
+}
+
+/* eta=2: every 4 bytes yield 8 coefficients */
+static __device__ void cbd2(int16_t *r, const uint8_t *buf)
+{
+    /* PARAM_N/8 iterations of 4 bytes each (128 bytes when N = 256) */
+    for (unsigned int i = 0; i < PARAM_N / 8; i++) {
+        uint32_t t = (uint32_t)buf[4*i+0] | ((uint32_t)buf[4*i+1] << 8) |
+                     ((uint32_t)buf[4*i+2] << 16) | ((uint32_t)buf[4*i+3] << 24);
+        /* d holds, in every 2-bit field, the sum of the two bits of t */
+        uint32_t d = t & 0x55555555;
+        d += (t >> 1) & 0x55555555;
+        r[8*i+0] = (int16_t)(((d >> 0) & 0x3) - ((d >> 2) & 0x3));
+        r[8*i+1] = (int16_t)(((d >> 4) & 0x3) - ((d >> 6) & 0x3));
+        r[8*i+2] = (int16_t)(((d >> 8) & 0x3) - ((d >> 10) & 0x3));
+        r[8*i+3] = (int16_t)(((d >> 12) & 0x3) - ((d >> 14) & 0x3));
+        r[8*i+4] = (int16_t)(((d >> 16) & 0x3) - ((d >> 18) & 0x3));
+        r[8*i+5] = (int16_t)(((d >> 20) & 0x3) - ((d >> 22) & 0x3));
+        r[8*i+6] = (int16_t)(((d >> 24) & 0x3) - ((d >> 26) & 0x3));
+        r[8*i+7] = (int16_t)(((d >> 28) & 0x3) - ((d >> 30) & 0x3));
+    }
+}
+
+/* eta=3: every 3 bytes yield 4 coefficients */
+static __device__ void cbd3(int16_t *r, const uint8_t *buf)
+{
+    /* PARAM_N/4 iterations of 3 bytes each (192 bytes when N = 256) */
+    for (unsigned int i = 0; i < PARAM_N / 4; i++) {
+        uint32_t a, b;
+        uint32_t t = (uint32_t)buf[3*i+0] | ((uint32_t)buf[3*i+1] << 8) | ((uint32_t)buf[3*i+2] << 16);
+        /* a: per-field sums of bits 0..2, b: sums of bits 3..5 of each 6-bit group */
+        a = t & 0x249249;
+        a += (t >> 1) & 0x249249;
+        a += (t >> 2) & 0x249249;
+        b = (t >> 3) & 0x249249;
+        b += (t >> 4) & 0x249249;
+        b += (t >> 5) & 0x249249;
+        r[4*i] = (int16_t)(((a >> 0) & 0x7) - ((b >> 0) & 0x7));
+        r[4*i+1] = (int16_t)(((a >> 6) & 0x7) - ((b >> 6) & 0x7));
+        r[4*i+2] = (int16_t)(((a >> 12) & 0x7) - ((b >> 12) & 0x7));
+        r[4*i+3] = (int16_t)(((a >> 18) & 0x7) - ((b >> 18) & 0x7));
+    }
+}
+
+/* eta=4: every 2 bytes yield 2 coefficients */
+static __device__ void cbd4(int16_t *r, const uint8_t *buf)
+{
+    for (unsigned int i = 0; i < PARAM_N / 2; i++) {
+        uint32_t t = (uint32_t)buf[2*i] | ((uint32_t)buf[2*i+1] << 8);
+        /* a: bit counts of the two nibbles of the low byte,
+         * b: bit counts of the two nibbles of the high byte */
+        uint32_t a = t & 0x1111;
+        a += (t >> 1) & 0x1111;
+        a += (t >> 2) & 0x1111;
+        a += (t >> 3) & 0x1111;
+        uint32_t b = (t >> 8) & 0x1111;
+        b += (t >> 9) & 0x1111;
+        b += (t >> 10) & 0x1111;
+        b += (t >> 11) & 0x1111;
+        r[2*i] = (int16_t)((a & 0xF) - (b & 0xF));
+        r[2*i+1] = (int16_t)(((a >> 4) & 0xF) - ((b >> 4) & 0xF));
+    }
+}
+
+/* eta=8: every 2 bytes yield 1 coefficient (difference of their popcounts) */
+static __device__ void cbd8(int16_t *r, const uint8_t *buf)
+{
+    for (unsigned int i = 0; i < PARAM_N; i++) {
+        uint8_t a = buf[2*i];
+        uint8_t b = buf[2*i+1];
+        r[i] = (int16_t)(__popc((unsigned)a) - __popc((unsigned)b));
+    }
+}
+
+/* ================================================================
+ * SHAKE256 PRF: out = SHAKE256(seed || nonce, outlen)
+ *
+ * The PARAM_SYMBYTES-byte seed, the nonce byte and the 0x1F / final-bit
+ * padding are XOR-ed directly into a zeroed state (single-block absorb).
+ * Whole blocks are produced by keccak_squeezeblocks(); a trailing
+ * partial block is extracted after one more permutation.
+ * ================================================================ */
+static __device__ __noinline__ void prf_shake256(uint8_t *out, size_t outlen,
+                                                 const uint8_t *seed, uint8_t nonce)
+{
+    uint64_t s[25];
+    for (unsigned int i = 0; i < 25; i++) s[i] = 0;
+    for (unsigned int i = 0; i < PARAM_SYMBYTES; i++)
+        s[i >> 3] ^= (uint64_t)seed[i] << (8 * (i & 7));
+    s[PARAM_SYMBYTES >> 3] ^= (uint64_t)nonce << (8 * (PARAM_SYMBYTES & 7));
+    s[(PARAM_SYMBYTES + 1) >> 3] ^= (uint64_t)0x1F << (8 * ((PARAM_SYMBYTES + 1) & 7));
+    s[(SHAKE256_RATE - 1) >> 3] ^= 1ULL << 63;
+
+    size_t nblocks = outlen / SHAKE256_RATE;
+    keccak_squeezeblocks(out, nblocks, s, SHAKE256_RATE);
+    outlen -= nblocks * SHAKE256_RATE;
+    out += nblocks * SHAKE256_RATE;
+    if (outlen) {
+        KeccakF1600_StatePermute(s);
+        for (size_t i = 0; i < outlen; i++)
+            out[i] = (uint8_t)(s[i >> 3] >> (8 * (i & 7)));
+    }
+}
+
+/* ================================================================
+ * Unified noise sampling interface
+ *   getnoise_etaX(r, seed, nonce):
+ *     fills r with a CBD(X) polynomial derived from SHAKE256(seed||nonce)
+ *
+ * KEM_DIRECT_CBD=1 (default): bytes are pulled straight from the Keccak
+ * state through prf_reader, without an intermediate buffer.
+ * KEM_DIRECT_CBD=0: the PRF output is written to a local buffer first
+ * and then passed to cbd1..cbd8.
+ * ================================================================ */
+#ifndef KEM_DIRECT_CBD
+#define KEM_DIRECT_CBD 1
+#endif
+
+#if KEM_DIRECT_CBD
+/* Streaming reader over the SHAKE256(seed||nonce) output. */
+typedef struct {
+    uint64_t s[25];    /* Keccak state */
+    unsigned int pos;  /* next byte offset within the current rate block */
+} prf_reader;
+
+/* Absorb seed||nonce plus padding; pos = RATE forces a permutation on the first read. */
+static __device__ __forceinline__ void prf_reader_init(prf_reader *rd,
+                                                       const uint8_t *seed, uint8_t nonce)
+{
+    for (unsigned int i = 0; i < 25; i++) rd->s[i] = 0;
+    for (unsigned int i = 0; i < PARAM_SYMBYTES; i++)
+        rd->s[i >> 3] ^= (uint64_t)seed[i] << (8 * (i & 7));
+    rd->s[PARAM_SYMBYTES >> 3] ^= (uint64_t)nonce << (8 * (PARAM_SYMBYTES & 7));
+    rd->s[(PARAM_SYMBYTES + 1) >> 3] ^= (uint64_t)0x1F << (8 * ((PARAM_SYMBYTES + 1) & 7));
+    rd->s[(SHAKE256_RATE - 1) >> 3] ^= 1ULL << 63;
+    rd->pos = SHAKE256_RATE;
+}
+
+/* Return the next output byte, permuting the state when the block is exhausted. */
+static __device__ __forceinline__ uint8_t prf_reader_u8(prf_reader *rd)
+{
+    if (rd->pos == SHAKE256_RATE) {
+        KeccakF1600_StatePermute(rd->s);
+        rd->pos = 0;
+    }
+    uint8_t v = (uint8_t)(rd->s[rd->pos >> 3] >> (8 * (rd->pos & 7)));
+    rd->pos++;
+    return v;
+}
+
+/* Next 2 bytes as a little-endian value. */
+static __device__ __forceinline__ uint16_t prf_reader_u16(prf_reader *rd)
+{
+    uint16_t b0 = prf_reader_u8(rd);
+    uint16_t b1 = prf_reader_u8(rd);
+    return (uint16_t)(b0 | (b1 << 8));
+}
+
+/* Next 3 bytes as a little-endian value. */
+static __device__ __forceinline__ uint32_t prf_reader_u24(prf_reader *rd)
+{
+    uint32_t b0 = prf_reader_u8(rd);
+    uint32_t b1 = prf_reader_u8(rd);
+    uint32_t b2 = prf_reader_u8(rd);
+    return b0 | (b1 << 8) | (b2 << 16);
+}
+
+/* Next 4 bytes as a little-endian value. */
+static __device__ __forceinline__ uint32_t prf_reader_u32(prf_reader *rd)
+{
+    uint32_t b0 = prf_reader_u8(rd);
+    uint32_t b1 = prf_reader_u8(rd);
+    uint32_t b2 = prf_reader_u8(rd);
+    uint32_t b3 = prf_reader_u8(rd);
+    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
+}
+
+/* CBD(1): one PRF byte per 4 coefficients (same bit layout as cbd1). */
+static __device__ __noinline__ void getnoise_eta1(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    prf_reader rd;
+    prf_reader_init(&rd, seed, nonce);
+    for (unsigned int i = 0; i < PARAM_N / 4; i++) {
+        uint8_t b = prf_reader_u8(&rd);
+        r[4*i+0] = (int16_t)(((b >> 0) & 1) - ((b >> 1) & 1));
+        r[4*i+1] = (int16_t)(((b >> 2) & 1) - ((b >> 3) & 1));
+        r[4*i+2] = (int16_t)(((b >> 4) & 1) - ((b >> 5) & 1));
+        r[4*i+3] = (int16_t)(((b >> 6) & 1) - ((b >> 7) & 1));
+    }
+}
+
+/* CBD(2): 4 PRF bytes per 8 coefficients (same bit layout as cbd2). */
+static __device__ __noinline__ void getnoise_eta2(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    prf_reader rd;
+    prf_reader_init(&rd, seed, nonce);
+    for (unsigned int i = 0; i < PARAM_N / 8; i++) {
+        uint32_t t = prf_reader_u32(&rd);
+        uint32_t d = t & 0x55555555;
+        d += (t >> 1) & 0x55555555;
+        r[8*i+0] = (int16_t)(((d >> 0) & 0x3) - ((d >> 2) & 0x3));
+        r[8*i+1] = (int16_t)(((d >> 4) & 0x3) - ((d >> 6) & 0x3));
+        r[8*i+2] = (int16_t)(((d >> 8) & 0x3) - ((d >> 10) & 0x3));
+        r[8*i+3] = (int16_t)(((d >> 12) & 0x3) - ((d >> 14) & 0x3));
+        r[8*i+4] = (int16_t)(((d >> 16) & 0x3) - ((d >> 18) & 0x3));
+        r[8*i+5] = (int16_t)(((d >> 20) & 0x3) - ((d >> 22) & 0x3));
+        r[8*i+6] = (int16_t)(((d >> 24) & 0x3) - ((d >> 26) & 0x3));
+        r[8*i+7] = (int16_t)(((d >> 28) & 0x3) - ((d >> 30) & 0x3));
+    }
+}
+
+/* CBD(3): 3 PRF bytes per 4 coefficients (same bit layout as cbd3). */
+static __device__ __noinline__ void getnoise_eta3(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    prf_reader rd;
+    prf_reader_init(&rd, seed, nonce);
+    for (unsigned int i = 0; i < PARAM_N / 4; i++) {
+        uint32_t t = prf_reader_u24(&rd);
+        uint32_t a = t & 0x249249;
+        a += (t >> 1) & 0x249249;
+        a += (t >> 2) & 0x249249;
+        uint32_t b = (t >> 3) & 0x249249;
+        b += (t >> 4) & 0x249249;
+        b += (t >> 5) & 0x249249;
+        r[4*i+0] = (int16_t)(((a >> 0) & 0x7) - ((b >> 0) & 0x7));
+        r[4*i+1] = (int16_t)(((a >> 6) & 0x7) - ((b >> 6) & 0x7));
+        r[4*i+2] = (int16_t)(((a >> 12) & 0x7) - ((b >> 12) & 0x7));
+        r[4*i+3] = (int16_t)(((a >> 18) & 0x7) - ((b >> 18) & 0x7));
+    }
+}
+
+/* CBD(4): 2 PRF bytes per 2 coefficients (same bit layout as cbd4). */
+static __device__ __noinline__ void getnoise_eta4(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    prf_reader rd;
+    prf_reader_init(&rd, seed, nonce);
+    for (unsigned int i = 0; i < PARAM_N / 2; i++) {
+        uint32_t t = prf_reader_u16(&rd);
+        uint32_t a = t & 0x1111;
+        a += (t >> 1) & 0x1111;
+        a += (t >> 2) & 0x1111;
+        a += (t >> 3) & 0x1111;
+        uint32_t b = (t >> 8) & 0x1111;
+        b += (t >> 9) & 0x1111;
+        b += (t >> 10) & 0x1111;
+        b += (t >> 11) & 0x1111;
+        r[2*i+0] = (int16_t)((a & 0xF) - (b & 0xF));
+        r[2*i+1] = (int16_t)(((a >> 4) & 0xF) - ((b >> 4) & 0xF));
+    }
+}
+
+/* CBD(8): 2 PRF bytes per coefficient (difference of their popcounts). */
+static __device__ __noinline__ void getnoise_eta8(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    prf_reader rd;
+    prf_reader_init(&rd, seed, nonce);
+    for (unsigned int i = 0; i < PARAM_N; i++) {
+        uint8_t a = prf_reader_u8(&rd);
+        uint8_t b = prf_reader_u8(&rd);
+        r[i] = (int16_t)(__popc((unsigned)a) - __popc((unsigned)b));
+    }
+}
+
+#else
+
+/* Buffered variants: eta * 64 bytes of PRF output, then cbdX().
+ * NOTE(review): the buffer sizes hard-code 64 = N/4, i.e. assume PARAM_N = 256. */
+
+static __device__ __noinline__ void getnoise_eta1(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    uint8_t buf[1 * 64];
+    prf_shake256(buf, sizeof(buf), seed, nonce);
+    cbd1(r, buf, (unsigned int)sizeof(buf));
+}
+
+static __device__ __noinline__ void getnoise_eta2(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    uint8_t buf[2 * 64];
+    prf_shake256(buf, sizeof(buf), seed, nonce);
+    cbd2(r, buf);
+}
+
+static __device__ __noinline__ void getnoise_eta3(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    uint8_t buf[3 * 64];
+    prf_shake256(buf, sizeof(buf), seed, nonce);
+    cbd3(r, buf);
+}
+
+static __device__ __noinline__ void getnoise_eta4(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    uint8_t buf[4 * 64];
+    prf_shake256(buf, sizeof(buf), seed, nonce);
+    cbd4(r, buf);
+}
+
+static __device__ __noinline__ void getnoise_eta8(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    uint8_t buf[8 * 64];
+    prf_shake256(buf, sizeof(buf), seed, nonce);
+    cbd8(r, buf);
+}
+
+#endif
+
+/* Call the getnoise_etaX matching ETA. An ETA other than 1, 2, 3, 4 or 8
+ * matches no branch, so R is left unwritten. */
+#define DISPATCH_GETNOISE_ETA(ETA, R, SEED, NONCE) do { \
+    if ((ETA) == 1) getnoise_eta1((R), (SEED), (NONCE)); \
+    else if ((ETA) == 2) getnoise_eta2((R), (SEED), (NONCE)); \
+    else if ((ETA) == 3) getnoise_eta3((R), (SEED), (NONCE)); \
+    else if ((ETA) == 4) getnoise_eta4((R), (SEED), (NONCE)); \
+    else if ((ETA) == 8) getnoise_eta8((R), (SEED), (NONCE)); \
+} while (0)
+
+/* ================================================================
+ * Algorithm-specific noise samplers (eta taken from the PARAM_ETA_* macros)
+ * ================================================================ */
+
+/* Key generation: noise for the secret vector s */
+static __device__ void poly_getnoise_s(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    DISPATCH_GETNOISE_ETA(PARAM_ETA_S, r, seed, nonce);
+}
+
+/* Key generation: noise for the error vector e */
+static __device__ void poly_getnoise_e_kg(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    DISPATCH_GETNOISE_ETA(PARAM_ETA_E_KG, r, seed, nonce);
+}
+
+/* Encryption: noise for the random vector r */
+static __device__ void poly_getnoise_e_enc(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    DISPATCH_GETNOISE_ETA(PARAM_ETA_E_ENC, r, seed, nonce);
+}
+
+/* Encryption: noise for the scalar error e2 */
+static __device__ void poly_getnoise_e2(int16_t *r, const uint8_t *seed, uint8_t nonce)
+{
+    DISPATCH_GETNOISE_ETA(PARAM_ETA_E2, r, seed, nonce);
+}
+
+#undef DISPATCH_GETNOISE_ETA
+
+#endif /* CBD_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/config.h b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/config.h
new file mode 100644
index 000000000..c19663b1b
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/config.h
@@ -0,0 +1,41 @@
+/*
+ * config.h — algorithm selection
+ *
+ * Usage: pass -DALGORITHM=ALGO_KYBER or -DALGORITHM=ALGO_AIGIS_ENC at
+ * compile time, together with -DPARAM_MODE=<mode>
+ *
+ * 
Kyber modes:
+ *   -DPARAM_MODE=2 -> Kyber512 (K=2)
+ *   -DPARAM_MODE=3 -> Kyber768 (K=3)
+ *   -DPARAM_MODE=4 -> Kyber1024 (K=4)
+ *
+ * Aigis-enc modes:
+ *   -DPARAM_MODE=1 -> Aigis-enc-1 (K=2)
+ *   -DPARAM_MODE=2 -> Aigis-enc-2 (K=3, low)
+ *   -DPARAM_MODE=3 -> Aigis-enc-3 (K=3, med)
+ *   -DPARAM_MODE=4 -> Aigis-enc-4 (K=4, high)
+ */
+
+#ifndef CONFIG_H
+#define CONFIG_H
+
+#define ALGO_KYBER 1
+#define ALGO_AIGIS_ENC 2
+
+/* Defaults when nothing is passed on the command line: Kyber768. */
+#ifndef ALGORITHM
+#define ALGORITHM ALGO_KYBER
+#endif
+
+#ifndef PARAM_MODE
+#if ALGORITHM == ALGO_KYBER
+#define PARAM_MODE 3 /* Kyber768 */
+#else
+#define PARAM_MODE 4 /* Aigis-enc-4 */
+#endif
+#endif
+
+#if ALGORITHM != ALGO_KYBER && ALGORITHM != ALGO_AIGIS_ENC
+#error "ALGORITHM must be ALGO_KYBER or ALGO_AIGIS_ENC"
+#endif
+
+#endif /* CONFIG_H */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/fips202.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/fips202.cuh
new file mode 100644
index 000000000..1eb72586e
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/fips202.cuh
@@ -0,0 +1,287 @@
+/*
+ * Aigis-enc GPU - FIPS 202 (SHA3/SHAKE) device functions
+ * Algorithm-independent; Kyber and Aigis-enc share this implementation
+ */
+#ifndef FIPS202_CUH
+#define FIPS202_CUH
+
+/* NOTE(review): the header name is missing in this copy of the patch
+ * (presumably <stdint.h>, given the uint64_t/uint8_t use below) — restore from upstream. */
+#include
+
+/* Sponge rates in bytes for each function. */
+#define SHAKE128_RATE 168
+#define SHAKE256_RATE 136
+#define SHA3_256_RATE 136
+#define SHA3_512_RATE 72
+#define NROUNDS 24
+/* 64-bit rotate left; offset must be in 1..63 (0 would shift by 64). */
+#define ROL(a, offset) (((a) << (offset)) ^ ((a) >> (64-(offset))))
+
+typedef struct {
+    uint64_t s[25];    /* Keccak state lanes */
+    unsigned int pos;  /* NOTE(review): presumably the byte offset within the current rate block — not used in the visible code */
+} keccak_state;
+
+/* Keccak-f[1600] round constants, one per round. */
+__constant__ uint64_t gpu_KeccakF_RoundConstants[NROUNDS] = {
+    0x0000000000000001ULL, 0x0000000000008082ULL,
+    0x800000000000808aULL, 0x8000000080008000ULL,
+    0x000000000000808bULL, 0x0000000080000001ULL,
+    0x8000000080008081ULL, 0x8000000000008009ULL,
+    0x000000000000008aULL, 0x0000000000000088ULL,
+    0x0000000080008009ULL, 0x000000008000000aULL,
+    0x000000008000808bULL, 0x800000000000008bULL,
+    0x8000000000008089ULL, 0x8000000000008003ULL,
+    0x8000000000008002ULL, 0x8000000000000080ULL,
+    0x000000000000800aULL, 0x800000008000000aULL,
+    0x8000000080008081ULL, 0x8000000000008080ULL,
+    0x0000000080000001ULL, 0x8000000080008008ULL
+};
+
+/* Load 8 bytes as a little-endian 64-bit value. */
+static __device__ __forceinline__ uint64_t gpu_load64(const uint8_t *x)
+{
+    uint64_t r = 0;
+    for (int i = 0; i < 8; ++i) r |= (uint64_t)x[i] << (8 * i);
+    return r;
+}
+
+/* Store a 64-bit value as 8 little-endian bytes. */
+static __device__ __forceinline__ void gpu_store64(uint8_t *x, uint64_t u)
+{
+    for (unsigned int i = 0; i < 8; ++i) { x[i] = (uint8_t)u; u >>= 8; }
+}
+
+/* In-place Keccak-f[1600] permutation of the 25-lane state.
+ * The 24 rounds are unrolled two at a time: the first half-iteration
+ * reads the A* lanes and writes E*, the second reads E* and writes A*,
+ * so no lane copying is needed between rounds. */
+static __device__ __noinline__ void KeccakF1600_StatePermute(uint64_t *state)
+{
+    int round;
+    uint64_t Aba, Abe, Abi, Abo, Abu;
+    uint64_t Aga, Age, Agi, Ago, Agu;
+    uint64_t Aka, Ake, Aki, Ako, Aku;
+    uint64_t Ama, Ame, Ami, Amo, Amu;
+    uint64_t Asa, Ase, Asi, Aso, Asu;
+    uint64_t BCa, BCe, BCi, BCo, BCu;
+    uint64_t Da, De, Di, Do, Du;
+    uint64_t Eba, Ebe, Ebi, Ebo, Ebu;
+    uint64_t Ega, Ege, Egi, Ego, Egu;
+    uint64_t Eka, Eke, Eki, Eko, Eku;
+    uint64_t Ema, Eme, Emi, Emo, Emu;
+    uint64_t Esa, Ese, Esi, Eso, Esu;
+
+    /* copy the state into local lanes */
+    Aba = state[ 0]; Abe = state[ 1]; Abi = state[ 2]; Abo = state[ 3]; Abu = state[ 4];
+    Aga = state[ 5]; Age = state[ 6]; Agi = state[ 7]; Ago = state[ 8]; Agu = state[ 9];
+    Aka = state[10]; Ake = state[11]; Aki = state[12]; Ako = state[13]; Aku = state[14];
+    Ama = state[15]; Ame = state[16]; Ami = state[17]; Amo = state[18]; Amu = state[19];
+    Asa = state[20]; Ase = state[21]; Asi = state[22]; Aso = state[23]; Asu = state[24];
+
+    for (round = 0; round < NROUNDS; round += 2) {
+        /* Round 1 (A -> E): column parities, then the D values XOR-ed into each lane */
+        BCa = Aba^Aga^Aka^Ama^Asa; BCe = Abe^Age^Ake^Ame^Ase;
+        BCi = Abi^Agi^Aki^Ami^Asi; BCo = Abo^Ago^Ako^Amo^Aso;
+        BCu = Abu^Agu^Aku^Amu^Asu;
+        Da = BCu^ROL(BCe, 1); De = BCa^ROL(BCi, 1);
+        Di = BCe^ROL(BCo, 1); Do = BCi^ROL(BCu, 1); Du = BCo^ROL(BCa, 1);
+
+        Aba ^= Da; BCa = Aba;
+        Age ^= De; BCe = ROL(Age, 44); Aki ^= Di; BCi = ROL(Aki, 43);
+        Amo ^= Do; BCo = ROL(Amo, 21); Asu ^= Du; BCu = ROL(Asu, 14);
+        Eba = BCa ^((~BCe)& BCi); Eba ^= gpu_KeccakF_RoundConstants[round];
+        Ebe = BCe ^((~BCi)& BCo); Ebi = BCi ^((~BCo)& BCu);
+        Ebo = BCo ^((~BCu)& BCa); Ebu = BCu ^((~BCa)& BCe);
+
+        Abo ^= Do; BCa = ROL(Abo, 28); Agu ^= Du; BCe = ROL(Agu, 20);
+        Aka ^= Da; BCi = ROL(Aka, 3); Ame ^= De; BCo = ROL(Ame, 45);
+        Asi ^= Di; BCu = ROL(Asi, 61);
+        Ega = BCa ^((~BCe)& BCi); Ege = BCe ^((~BCi)& BCo);
+        Egi = BCi ^((~BCo)& BCu); Ego = BCo ^((~BCu)& BCa); Egu = BCu ^((~BCa)& BCe);
+
+        Abe ^= De; BCa = ROL(Abe, 1); Agi ^= Di; BCe = ROL(Agi, 6);
+        Ako ^= Do; BCi = ROL(Ako, 25); Amu ^= Du; BCo = ROL(Amu, 8);
+        Asa ^= Da; BCu = ROL(Asa, 18);
+        Eka = BCa ^((~BCe)& BCi); Eke = BCe ^((~BCi)& BCo);
+        Eki = BCi ^((~BCo)& BCu); Eko = BCo ^((~BCu)& BCa); Eku = BCu ^((~BCa)& BCe);
+
+        Abu ^= Du; BCa = ROL(Abu, 27); Aga ^= Da; BCe = ROL(Aga, 36);
+        Ake ^= De; BCi = ROL(Ake, 10); Ami ^= Di; BCo = ROL(Ami, 15);
+        Aso ^= Do; BCu = ROL(Aso, 56);
+        Ema = BCa ^((~BCe)& BCi); Eme = BCe ^((~BCi)& BCo);
+        Emi = BCi ^((~BCo)& BCu); Emo = BCo ^((~BCu)& BCa); Emu = BCu ^((~BCa)& BCe);
+
+        Abi ^= Di; BCa = ROL(Abi, 62); Ago ^= Do; BCe = ROL(Ago, 55);
+        Aku ^= Du; BCi = ROL(Aku, 39); Ama ^= Da; BCo = ROL(Ama, 41);
+        Ase ^= De; BCu = ROL(Ase, 2);
+        Esa = BCa ^((~BCe)& BCi); Ese = BCe ^((~BCi)& BCo);
+        Esi = BCi ^((~BCo)& BCu); Eso = BCo ^((~BCu)& BCa); Esu = BCu ^((~BCa)& BCe);
+
+        /* Round 2 (E -> A) */
+        BCa = Eba^Ega^Eka^Ema^Esa; BCe = Ebe^Ege^Eke^Eme^Ese;
+        BCi = Ebi^Egi^Eki^Emi^Esi; BCo = Ebo^Ego^Eko^Emo^Eso;
+        BCu = Ebu^Egu^Eku^Emu^Esu;
+        Da = BCu^ROL(BCe, 1); De = BCa^ROL(BCi, 1);
+        Di = BCe^ROL(BCo, 1); Do = BCi^ROL(BCu, 1); Du = BCo^ROL(BCa, 1);
+
+        Eba ^= Da; BCa = Eba;
+        Ege ^= De; BCe = ROL(Ege, 44); Eki ^= Di; BCi = ROL(Eki, 43);
+        Emo ^= Do; BCo = ROL(Emo, 21); Esu ^= Du; BCu = ROL(Esu, 14);
+        Aba = BCa ^((~BCe)& BCi); Aba ^= gpu_KeccakF_RoundConstants[round+1];
+        Abe = BCe ^((~BCi)& BCo); Abi = BCi ^((~BCo)& BCu);
+        Abo = BCo ^((~BCu)& BCa); Abu = BCu ^((~BCa)& BCe);
+
+        Ebo ^= Do; BCa = ROL(Ebo, 28); Egu ^= Du; BCe = ROL(Egu, 20);
+        Eka ^= Da; BCi = ROL(Eka, 3); Eme ^= De; BCo = ROL(Eme, 45);
+        Esi ^= Di; BCu = ROL(Esi, 61);
+        Aga = BCa ^((~BCe)& BCi); Age = BCe ^((~BCi)& BCo);
+        Agi = BCi ^((~BCo)& BCu); Ago = BCo ^((~BCu)& BCa); Agu = BCu ^((~BCa)& BCe);
+
+        Ebe ^= De; BCa = ROL(Ebe, 1); Egi ^= Di; BCe = ROL(Egi, 6);
+        Eko ^= Do; BCi = ROL(Eko, 25); Emu ^= Du; BCo = ROL(Emu, 8);
+        Esa ^= Da; BCu = ROL(Esa, 18);
+        Aka = BCa ^((~BCe)& BCi); Ake = BCe ^((~BCi)& BCo);
+        Aki = BCi ^((~BCo)& BCu); Ako = BCo ^((~BCu)& BCa); Aku = BCu ^((~BCa)& BCe);
+
+        Ebu ^= Du; BCa = ROL(Ebu, 27); Ega ^= Da; BCe = ROL(Ega, 36);
+        Eke ^= De; BCi = ROL(Eke, 10); Emi ^= Di; BCo = ROL(Emi, 15);
+        Eso ^= Do; BCu = ROL(Eso, 56);
+        Ama = BCa ^((~BCe)& BCi); Ame = BCe ^((~BCi)& BCo);
+        Ami = BCi ^((~BCo)& BCu); Amo = BCo ^((~BCu)& BCa); Amu = BCu ^((~BCa)& BCe);
+
+        Ebi ^= Di; BCa = ROL(Ebi, 62); Ego ^= Do; BCe = ROL(Ego, 55);
+        Eku ^= Du; BCi = ROL(Eku, 39); Ema ^= Da; BCo = ROL(Ema, 41);
+        Ese ^= De; BCu = ROL(Ese, 2);
+        Asa = BCa ^((~BCe)& BCi); Ase = BCe ^((~BCi)& BCo);
+        Asi = BCi ^((~BCo)& BCu); Aso = BCo ^((~BCu)& BCa); Asu = BCu ^((~BCa)& BCe);
+    }
+
+    /* write the lanes back */
+    state[ 0] = Aba; state[ 1] = Abe; state[ 2] = Abi; state[ 3] = Abo; state[ 4] = Abu;
+    state[ 5] = Aga; state[ 6] = Age; state[ 7] = Agi; state[ 8] = Ago; state[ 9] = Agu;
+    state[10] = Aka; state[11] = Ake; state[12] = Aki; state[13] = Ako; state[14] = Aku;
+    state[15] = Ama; state[16] = Ame; state[17] = Ami; state[18] = Amo; state[19] = Amu;
+    state[20] = Asa; state[21] = Ase; state[22] = Asi; state[23] = Aso; state[24] = Asu;
+}
+
+/* Keccak core: zero the state */
+static __device__ void keccak_init(uint64_t s[25])
+{ for (unsigned int i = 0; i < 25; i++) s[i] = 0; }
+
+static __device__ unsigned int keccak_absorb(uint64_t s[25], unsigned int pos,
+                                             unsigned int r, const uint8_t *in, size_t inlen)
+{
+    unsigned int i;
+    while (pos + inlen >= r) {
+        for (i = pos; i < r; i++) s[i/8] ^= (uint64_t)*in++ << 8*(i%8);
+        inlen -= r - pos; KeccakF1600_StatePermute(s); pos = 0;
+ 
} + for (i = pos; i < pos + (unsigned int)inlen; i++) s[i/8] ^= (uint64_t)*in++ << 8*(i%8); + return i; +} + +static __device__ void keccak_finalize(uint64_t s[25], unsigned int pos, unsigned int r, uint8_t p) +{ s[pos/8] ^= (uint64_t)p << 8*(pos%8); s[r/8-1] ^= 1ULL << 63; } + +static __device__ unsigned int keccak_squeeze(uint8_t *out, size_t outlen, + uint64_t s[25], unsigned int pos, unsigned int r) +{ + unsigned int i; + while (outlen) { + if (pos == r) { KeccakF1600_StatePermute(s); pos = 0; } + for (i = pos; i < r && i < pos + (unsigned int)outlen; i++) *out++ = (uint8_t)(s[i/8] >> 8*(i%8)); + outlen -= i - pos; pos = i; + } + return pos; +} + +static __device__ void keccak_absorb_once(uint64_t s[25], unsigned int r, + const uint8_t *in, size_t inlen, uint8_t p) +{ + unsigned int i; + for (i = 0; i < 25; ++i) s[i] = 0; + while (inlen >= r) { + for (i = 0; i < r / 8; ++i) s[i] ^= gpu_load64(in + 8 * i); + KeccakF1600_StatePermute(s); inlen -= r; in += r; + } + for (i = 0; i < (unsigned int)inlen; ++i) + s[i >> 3] ^= (uint64_t)in[i] << (8 * (i & 7)); + s[inlen >> 3] ^= (uint64_t)p << (8 * (inlen & 7)); + s[(r - 1) >> 3] ^= 1ULL << 63; +} + +static __device__ void keccak_squeezeblocks(uint8_t *out, size_t nblocks, uint64_t s[25], unsigned int r) +{ + unsigned int i; + while (nblocks > 0) { + KeccakF1600_StatePermute(s); + for (i = 0; i < (r >> 3); i++) gpu_store64(out + 8 * i, s[i]); + out += r; nblocks--; + } +} + +/* SHAKE128 */ +static __device__ void shake128_init(keccak_state *state) +{ keccak_init(state->s); state->pos = 0; } + +static __device__ void shake128_absorb(keccak_state *state, const uint8_t *in, size_t inlen) +{ state->pos = keccak_absorb(state->s, state->pos, SHAKE128_RATE, in, inlen); } + +static __device__ void shake128_finalize(keccak_state *state) +{ keccak_finalize(state->s, state->pos, SHAKE128_RATE, 0x1F); state->pos = SHAKE128_RATE; } + +static __device__ void shake128_squeeze(uint8_t *out, size_t outlen, keccak_state *state) +{ 
state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE128_RATE); } + +static __device__ void shake128_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen) +{ keccak_absorb_once(state->s, SHAKE128_RATE, in, inlen, 0x1F); state->pos = SHAKE128_RATE; } + +static __device__ void shake128_squeezeblocks(uint8_t *output, size_t nblocks, keccak_state *state) +{ keccak_squeezeblocks(output, nblocks, state->s, SHAKE128_RATE); } + +/* SHAKE256 */ +static __device__ void shake256_init(keccak_state *state) +{ keccak_init(state->s); state->pos = 0; } + +static __device__ void shake256_absorb(keccak_state *state, const uint8_t *in, size_t inlen) +{ state->pos = keccak_absorb(state->s, state->pos, SHAKE256_RATE, in, inlen); } + +static __device__ void shake256_finalize(keccak_state *state) +{ keccak_finalize(state->s, state->pos, SHAKE256_RATE, 0x1F); state->pos = SHAKE256_RATE; } + +static __device__ void shake256_squeeze(uint8_t *out, size_t outlen, keccak_state *state) +{ state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE256_RATE); } + +static __device__ void shake256_absorb_once(keccak_state *state, const uint8_t *input, size_t inlen) +{ keccak_absorb_once(state->s, SHAKE256_RATE, input, inlen, 0x1F); state->pos = SHAKE256_RATE; } + +static __device__ void shake256_squeezeblocks(uint8_t *output, size_t nblocks, keccak_state *state) +{ keccak_squeezeblocks(output, nblocks, state->s, SHAKE256_RATE); } + +/* 一次性 SHAKE */ +static __device__ __noinline__ void shake128(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen) +{ + keccak_state state; + shake128_absorb_once(&state, in, inlen); + size_t nblocks = outlen / SHAKE128_RATE; + shake128_squeezeblocks(out, nblocks, &state); + outlen -= nblocks * SHAKE128_RATE; out += nblocks * SHAKE128_RATE; + shake128_squeeze(out, outlen, &state); +} + +static __device__ __noinline__ void shake256(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen) +{ + keccak_state state; + 
shake256_absorb_once(&state, in, inlen); + size_t nblocks = outlen / SHAKE256_RATE; + shake256_squeezeblocks(out, nblocks, &state); + outlen -= nblocks * SHAKE256_RATE; out += nblocks * SHAKE256_RATE; + shake256_squeeze(out, outlen, &state); +} + +/* SHA3 */ +static __device__ __noinline__ void sha3_256(uint8_t *output, const uint8_t *input, size_t inlen) +{ + uint64_t s[25]; + keccak_absorb_once(s, SHA3_256_RATE, input, inlen, 0x06); + KeccakF1600_StatePermute(s); + for (size_t i = 0; i < 4; i++) gpu_store64(output + 8 * i, s[i]); +} + +static __device__ __noinline__ void sha3_512(uint8_t *output, const uint8_t *input, size_t inlen) +{ + uint64_t s[25]; + keccak_absorb_once(s, SHA3_512_RATE, input, inlen, 0x06); + KeccakF1600_StatePermute(s); + for (size_t i = 0; i < 8; i++) gpu_store64(output + 8 * i, s[i]); +} + +#endif /* FIPS202_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/kem.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/kem.cuh new file mode 100644 index 000000000..b3eac6d06 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/kem.cuh @@ -0,0 +1,526 @@ +/* + * kem.cuh — 统一 INDCPA 和 KEM 设备函数 + * + * 支持 Kyber 和 Aigis-enc,通过 #if ALGORITHM 分支处理关键差异: + * + * 1. 矩阵展开: 拒绝采样位宽 + * Kyber: rej_uniform_12bit (val < 3329, 每3字节2个值) + * Aigis-enc: rej_uniform_13bit (val < 7681, 每13字节8个值) + * + * 2. 矩阵种子调用约定: + * Kyber: SHAKE128(seed || j || i) for a[i][j] (先转置种子) + * Aigis-enc: SHAKE128(seed || i || j) for a[i][j] + * + * 3. PK 打包: + * Kyber: poly_tobytes (12-bit 无损) + * Aigis-enc: polyvec_pk_compress (bits_pk 有损) + * + * 4. INDCPA 加密签号: + * Kyber: v = pk*r + e2 + msg (加法) + * Aigis-enc: v = pk*r + e2 - msg (减法) + * + * 5. 
INDCPA 解密: + * Kyber: mp = s*u - v → tomsg(mp) + * Aigis-enc: mp = s*u - v → tomsg(mp) (同 Kyber 解密端) + * (Aigis enc v-=msg, dec mp=s*u-v → 等价) + */ + +#ifndef KEM_CUH +#define KEM_CUH + +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "fips202.cuh" +#include "ntt.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "cbd.cuh" + +#ifndef KEM_FORCE_INLINE_PACK +#define KEM_FORCE_INLINE_PACK 0 +#endif + +#if KEM_FORCE_INLINE_PACK +#define KEM_PACK_ATTR __forceinline__ +#else +#define KEM_PACK_ATTR __noinline__ +#endif + +#ifndef KEM_FORCE_INLINE_TOP +#define KEM_FORCE_INLINE_TOP 0 +#endif + +#if KEM_FORCE_INLINE_TOP +#define KEM_TOP_ATTR __forceinline__ +#else +#define KEM_TOP_ATTR __noinline__ +#endif + +#ifndef KEM_FAST_DECAP_NO_REENC +#define KEM_FAST_DECAP_NO_REENC 0 +#endif + +/* ================================================================ + * 拒绝采样: 从 XOF 输出中提取均匀分布的系数 + * ================================================================ */ + +#if ALGORITHM == ALGO_KYBER + +/* Kyber: 12-bit 拒绝采样 (每 3 字节抽 2 个值) */ +static __device__ unsigned int rej_uniform(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + unsigned int ctr = 0, pos = 0; + while (ctr < len && pos + 2 < buflen) { + uint16_t val0 = ((buf[pos+0]) | ((uint16_t)buf[pos+1] << 8)) & 0x0FFF; + uint16_t val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0x0FFF; + pos += 3; + if (val0 < PARAM_Q) r[ctr++] = (int16_t)val0; + if (ctr < len && val1 < PARAM_Q) r[ctr++] = (int16_t)val1; + } + return ctr; +} + +#elif ALGORITHM == ALGO_AIGIS_ENC + +/* Aigis-enc: 13-bit 拒绝采样 (每 13 字节抽 8 个值) */ +static __device__ unsigned int rej_uniform(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + unsigned int ctr = 0, pos = 0; + while (ctr < len && pos + 12 < buflen) { + uint16_t v[8]; + /* 每 13 字节 = 104 bits = 8 × 13 bits */ + v[0] = ((uint16_t)buf[pos+0] | ((uint16_t)buf[pos+1] << 8)) & 0x1FFF; + v[1] = ((uint16_t)buf[pos+1] >> 5 
| ((uint16_t)buf[pos+2] << 3) | ((uint16_t)buf[pos+3] << 11)) & 0x1FFF; + v[2] = ((uint16_t)buf[pos+3] >> 2 | ((uint16_t)buf[pos+4] << 6)) & 0x1FFF; + v[3] = ((uint16_t)buf[pos+4] >> 7 | ((uint16_t)buf[pos+5] << 1) | ((uint16_t)buf[pos+6] << 9)) & 0x1FFF; + v[4] = ((uint16_t)buf[pos+6] >> 4 | ((uint16_t)buf[pos+7] << 4) | ((uint16_t)buf[pos+8] << 12)) & 0x1FFF; + v[5] = ((uint16_t)buf[pos+8] >> 1 | ((uint16_t)buf[pos+9] << 7)) & 0x1FFF; + v[6] = ((uint16_t)buf[pos+9] >> 6 | ((uint16_t)buf[pos+10] << 2) | ((uint16_t)buf[pos+11] << 10)) & 0x1FFF; + v[7] = ((uint16_t)buf[pos+11] >> 3 | ((uint16_t)buf[pos+12] << 5)) & 0x1FFF; + pos += 13; + for (int i = 0; i < 8 && ctr < len; i++) { + if (v[i] < (uint16_t)PARAM_Q) r[ctr++] = (int16_t)v[i]; + } + } + return ctr; +} + +#endif /* ALGORITHM for rej_uniform */ + +/* ================================================================ + * 矩阵展开 (gen_matrix / gen_at) + * + * Kyber 种子约定: 生成 a[i][j] 使用 SHAKE128(seed || j || i) + * Aigis 种子约定: 生成 a[i][j] 使用 SHAKE128(seed || i || j) + * + * transposed=0: 正常矩阵 A + * transposed=1: 转置矩阵 A^T + * ================================================================ */ +#ifndef KEM_DIRECT_REJ_UNIFORM +#define KEM_DIRECT_REJ_UNIFORM 1 +#endif + +#if KEM_DIRECT_REJ_UNIFORM +typedef struct { + uint64_t s[25]; + unsigned int pos; +} xof_reader; + +static __device__ __forceinline__ void xof_reader_init(xof_reader *rd, + const uint8_t *seed, uint8_t x, uint8_t y) +{ + for (unsigned int i = 0; i < 25; i++) rd->s[i] = 0; + for (unsigned int i = 0; i < PARAM_SYMBYTES; i++) + rd->s[i >> 3] ^= (uint64_t)seed[i] << (8 * (i & 7)); + rd->s[PARAM_SYMBYTES >> 3] ^= (uint64_t)x << (8 * (PARAM_SYMBYTES & 7)); + rd->s[(PARAM_SYMBYTES + 1) >> 3] ^= (uint64_t)y << (8 * ((PARAM_SYMBYTES + 1) & 7)); + rd->s[(PARAM_SYMBYTES + 2) >> 3] ^= (uint64_t)0x1F << (8 * ((PARAM_SYMBYTES + 2) & 7)); + rd->s[(SHAKE128_RATE - 1) >> 3] ^= 1ULL << 63; + rd->pos = SHAKE128_RATE; +} + +static __device__ __forceinline__ uint8_t 
xof_reader_u8(xof_reader *rd) +{ + if (rd->pos == SHAKE128_RATE) { + KeccakF1600_StatePermute(rd->s); + rd->pos = 0; + } + uint8_t v = (uint8_t)(rd->s[rd->pos >> 3] >> (8 * (rd->pos & 7))); + rd->pos++; + return v; +} + +static __device__ __noinline__ void rej_uniform_xof(int16_t *r, const uint8_t *seed, + uint8_t x, uint8_t y) +{ + xof_reader rd; + xof_reader_init(&rd, seed, x, y); + unsigned int ctr = 0; + +#if ALGORITHM == ALGO_KYBER + while (ctr < PARAM_N) { + uint16_t b0 = xof_reader_u8(&rd); + uint16_t b1 = xof_reader_u8(&rd); + uint16_t b2 = xof_reader_u8(&rd); + uint16_t val0 = (uint16_t)((b0 | (b1 << 8)) & 0x0FFF); + uint16_t val1 = (uint16_t)(((b1 >> 4) | (b2 << 4)) & 0x0FFF); + if (val0 < PARAM_Q) r[ctr++] = (int16_t)val0; + if (ctr < PARAM_N && val1 < PARAM_Q) r[ctr++] = (int16_t)val1; + } +#elif ALGORITHM == ALGO_AIGIS_ENC + while (ctr < PARAM_N) { + uint8_t b0 = xof_reader_u8(&rd); + uint8_t b1 = xof_reader_u8(&rd); + uint8_t b2 = xof_reader_u8(&rd); + uint8_t b3 = xof_reader_u8(&rd); + uint8_t b4 = xof_reader_u8(&rd); + uint8_t b5 = xof_reader_u8(&rd); + uint8_t b6 = xof_reader_u8(&rd); + uint8_t b7 = xof_reader_u8(&rd); + uint8_t b8 = xof_reader_u8(&rd); + uint8_t b9 = xof_reader_u8(&rd); + uint8_t b10 = xof_reader_u8(&rd); + uint8_t b11 = xof_reader_u8(&rd); + uint8_t b12 = xof_reader_u8(&rd); + uint16_t v0 = ((uint16_t)b0 | ((uint16_t)b1 << 8)) & 0x1FFF; + uint16_t v1 = ((uint16_t)b1 >> 5 | ((uint16_t)b2 << 3) | ((uint16_t)b3 << 11)) & 0x1FFF; + uint16_t v2 = ((uint16_t)b3 >> 2 | ((uint16_t)b4 << 6)) & 0x1FFF; + uint16_t v3 = ((uint16_t)b4 >> 7 | ((uint16_t)b5 << 1) | ((uint16_t)b6 << 9)) & 0x1FFF; + uint16_t v4 = ((uint16_t)b6 >> 4 | ((uint16_t)b7 << 4) | ((uint16_t)b8 << 12)) & 0x1FFF; + uint16_t v5 = ((uint16_t)b8 >> 1 | ((uint16_t)b9 << 7)) & 0x1FFF; + uint16_t v6 = ((uint16_t)b9 >> 6 | ((uint16_t)b10 << 2) | ((uint16_t)b11 << 10)) & 0x1FFF; + uint16_t v7 = ((uint16_t)b11 >> 3 | ((uint16_t)b12 << 5)) & 0x1FFF; + if (v0 < PARAM_Q) r[ctr++] = 
(int16_t)v0; + if (ctr < PARAM_N && v1 < PARAM_Q) r[ctr++] = (int16_t)v1; + if (ctr < PARAM_N && v2 < PARAM_Q) r[ctr++] = (int16_t)v2; + if (ctr < PARAM_N && v3 < PARAM_Q) r[ctr++] = (int16_t)v3; + if (ctr < PARAM_N && v4 < PARAM_Q) r[ctr++] = (int16_t)v4; + if (ctr < PARAM_N && v5 < PARAM_Q) r[ctr++] = (int16_t)v5; + if (ctr < PARAM_N && v6 < PARAM_Q) r[ctr++] = (int16_t)v6; + if (ctr < PARAM_N && v7 < PARAM_Q) r[ctr++] = (int16_t)v7; + } +#endif +} +#endif + +static __device__ __noinline__ void gen_matrix(kem_polyvec *a, const uint8_t *seed, int transposed) +{ +#if KEM_DIRECT_REJ_UNIFORM + for (int i = 0; i < PARAM_K; i++) { + for (int j = 0; j < PARAM_K; j++) { + uint8_t x, y; +#if ALGORITHM == ALGO_KYBER + if (transposed) { x = (uint8_t)j; y = (uint8_t)i; } + else { x = (uint8_t)i; y = (uint8_t)j; } +#elif ALGORITHM == ALGO_AIGIS_ENC + if (transposed) { x = (uint8_t)j; y = (uint8_t)i; } + else { x = (uint8_t)i; y = (uint8_t)j; } +#endif + rej_uniform_xof(a[i].vec[j].coeffs, seed, x, y); + } + } +#else + keccak_state state; + uint8_t buf[PARAM_GEN_MATRIX_BUFLEN + 2]; + unsigned int ctr; + + for (int i = 0; i < PARAM_K; i++) { + for (int j = 0; j < PARAM_K; j++) { + /* 构建 SHAKE128 输入: seed || x || y */ + uint8_t extseed[PARAM_SYMBYTES + 2]; + for (int k = 0; k < PARAM_SYMBYTES; k++) extseed[k] = seed[k]; + +#if ALGORITHM == ALGO_KYBER + /* Kyber: a[i][j] ← SHAKE128(seed, j, i) */ + if (transposed) { extseed[PARAM_SYMBYTES] = (uint8_t)j; extseed[PARAM_SYMBYTES+1] = (uint8_t)i; } + else { extseed[PARAM_SYMBYTES] = (uint8_t)i; extseed[PARAM_SYMBYTES+1] = (uint8_t)j; } +#elif ALGORITHM == ALGO_AIGIS_ENC + /* Aigis: a[i][j] ← SHAKE128(seed, i, j) for non-transposed */ + if (transposed) { extseed[PARAM_SYMBYTES] = (uint8_t)j; extseed[PARAM_SYMBYTES+1] = (uint8_t)i; } + else { extseed[PARAM_SYMBYTES] = (uint8_t)i; extseed[PARAM_SYMBYTES+1] = (uint8_t)j; } +#endif + + shake128_absorb_once(&state, extseed, PARAM_SYMBYTES + 2); + + ctr = 0; + while (ctr < PARAM_N) { + 
shake128_squeezeblocks(buf, PARAM_GEN_MATRIX_NBLOCKS, &state); + ctr += rej_uniform(a[i].vec[j].coeffs + ctr, PARAM_N - ctr, + buf, PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES); + } + } + } +#endif +} + +static __device__ __noinline__ void gen_matrix_row(kem_polyvec *rowvec, + const uint8_t *seed, int row, int transposed) +{ +#if KEM_DIRECT_REJ_UNIFORM + for (int j = 0; j < PARAM_K; j++) { + uint8_t x, y; +#if ALGORITHM == ALGO_KYBER + if (transposed) { x = (uint8_t)j; y = (uint8_t)row; } + else { x = (uint8_t)row; y = (uint8_t)j; } +#elif ALGORITHM == ALGO_AIGIS_ENC + if (transposed) { x = (uint8_t)j; y = (uint8_t)row; } + else { x = (uint8_t)row; y = (uint8_t)j; } +#endif + rej_uniform_xof(rowvec->vec[j].coeffs, seed, x, y); + } +#else + kem_polyvec mat[PARAM_K]; + gen_matrix(mat, seed, transposed); + for (int j = 0; j < PARAM_K; j++) + for (int c = 0; c < PARAM_N; c++) + rowvec->vec[j].coeffs[c] = mat[row].vec[j].coeffs[c]; +#endif +} + +/* ================================================================ + * PK/SK/密文 打包/解包 + * ================================================================ */ + +/* pk = pk_vec_bytes || rho */ +static __device__ KEM_PACK_ATTR void pack_pk(uint8_t *pk, const kem_polyvec *pkpv, const uint8_t *rho) +{ + polyvec_pk_compress(pk, pkpv); + for (int i = 0; i < PARAM_SYMBYTES; i++) + pk[PARAM_PK_POLYVEC_BYTES + i] = rho[i]; +} + +static __device__ KEM_PACK_ATTR void unpack_pk(kem_polyvec *pkpv, uint8_t *rho, const uint8_t *pk) +{ + polyvec_pk_decompress(pkpv, pk); + for (int i = 0; i < PARAM_SYMBYTES; i++) + rho[i] = pk[PARAM_PK_POLYVEC_BYTES + i]; +} + +/* sk = polyvec_tobytes (NTT 域 s) */ +static __device__ KEM_PACK_ATTR void pack_sk(uint8_t *sk, const kem_polyvec *skpv) +{ + polyvec_tobytes(sk, skpv); +} + +static __device__ KEM_PACK_ATTR void unpack_sk(kem_polyvec *skpv, const uint8_t *sk) +{ + polyvec_frombytes(skpv, sk); +} + +/* ct = ct_vec_bytes || ct_poly_bytes */ +static __device__ KEM_PACK_ATTR void pack_ciphertext(uint8_t 
*c, const kem_polyvec *b, const kem_poly *v) +{ + polyvec_ct_compress(c, b); + poly_compress_c2(c + PARAM_CT_VEC_BYTES, v); +} + +static __device__ KEM_PACK_ATTR void unpack_ciphertext(kem_polyvec *b, kem_poly *v, const uint8_t *c) +{ + polyvec_ct_decompress(b, c); + poly_decompress_c2(v, c + PARAM_CT_VEC_BYTES); +} + +/* ================================================================ + * INDCPA 密钥生成 (串行, 单线程) + * 输入: coins[32] 随机种子 + * 输出: pk[PARAM_PUBLICKEYBYTES], sk[PARAM_INDCPA_SECRETKEYBYTES] + * ================================================================ */ +static __device__ KEM_TOP_ATTR void indcpa_keypair(uint8_t *pk, uint8_t *sk, const uint8_t *coins) +{ + kem_polyvec arow, skpv, e, pkpv; + uint8_t buf[2 * PARAM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + PARAM_SYMBYTES; + uint8_t nonce = 0; + + sha3_512(buf, coins, PARAM_SYMBYTES); + + /* 采样秘密 s 和误差 e */ + for (int i = 0; i < PARAM_K; i++) poly_getnoise_s(skpv.vec[i].coeffs, noiseseed, nonce++); + for (int i = 0; i < PARAM_K; i++) poly_getnoise_e_kg(e.vec[i].coeffs, noiseseed, nonce++); + + /* NTT(s) */ + polyvec_ntt(&skpv); + polyvec_caddq(&skpv); /* 规范化到 [0, Q) */ + + /* pk = A * s → INVNTT + e, caddq */ + for (int i = 0; i < PARAM_K; i++) { + gen_matrix_row(&arow, publicseed, i, 0 /* not transposed */); + polyvec_basemul_acc(&pkpv.vec[i], &arow, &skpv); + } + polyvec_invntt(&pkpv); + polyvec_add(&pkpv, &pkpv, &e); + polyvec_caddq(&pkpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + +/* ================================================================ + * INDCPA 加密 (串行, 单线程) + * ================================================================ */ +static __device__ KEM_TOP_ATTR void indcpa_enc(uint8_t *c, + const uint8_t *m, const uint8_t *pk, const uint8_t *coins) +{ + kem_polyvec at[PARAM_K], sp, ep, pkpv, b; + kem_poly epp, v, k; + uint8_t rho[PARAM_SYMBYTES]; + uint8_t nonce = 0; + + unpack_pk(&pkpv, rho, pk); + poly_frommsg(&k, m); + 
gen_matrix(at, rho, 1 /* transposed: A^T */); + + /* 采样随机噪声 */ + for (int i = 0; i < PARAM_K; i++) poly_getnoise_s(sp.vec[i].coeffs, coins, nonce++); + for (int i = 0; i < PARAM_K; i++) poly_getnoise_e_enc(ep.vec[i].coeffs, coins, nonce++); + poly_getnoise_e2(epp.coeffs, coins, nonce++); + + /* NTT(r) 和 NTT(pk) */ + polyvec_ntt(&sp); + polyvec_ntt(&pkpv); + + /* u = A^T * r → INVNTT + e1 */ + for (int i = 0; i < PARAM_K; i++) { + polyvec_basemul_acc(&b.vec[i], &at[i], &sp); + } + polyvec_invntt(&b); + polyvec_add(&b, &b, &ep); + polyvec_caddq(&b); + + /* v = pk^T * r → INVNTT + e2 ± msg */ + polyvec_basemul_acc(&v, &pkpv, &sp); + poly_invntt(&v); + poly_add(&v, &v, &epp); + +#if ALGORITHM == ALGO_KYBER + /* Kyber: v += msg */ + poly_add(&v, &v, &k); + poly_caddq(&v); +#elif ALGORITHM == ALGO_AIGIS_ENC + /* Aigis: v -= msg */ + poly_sub(&v, &v, &k); + poly_caddq2(&v); /* 可能为负, 需要双重 caddq */ +#endif + + pack_ciphertext(c, &b, &v); +} + +/* ================================================================ + * INDCPA 解密 (串行, 单线程) + * ================================================================ */ +static __device__ KEM_TOP_ATTR void indcpa_dec(uint8_t *m, const uint8_t *c, const uint8_t *sk) +{ + kem_polyvec b, skpv; + kem_poly v, mp; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc(&mp, &skpv, &b); + poly_invntt(&mp); + + /* mp = s^T * u - v */ + poly_sub(&mp, &mp, &v); + poly_caddq2(&mp); + + poly_tomsg(m, &mp); +} + +/* ================================================================ + * FO 变换: KEM 密钥生成 / 封装 / 解封装 + * + * sk 布局: indcpa_sk || pk || H(pk) || z + * (与 Kyber 和 Aigis-enc 参考实现相同) + * ================================================================ */ + +static __device__ KEM_TOP_ATTR void kem_keypair(uint8_t *pk, uint8_t *sk, const uint8_t *coins) +{ + uint8_t coins_indcpa[PARAM_SYMBYTES]; + /* coins[0:32] → INDCPA 种子 */ + for (int i = 0; i < PARAM_SYMBYTES; i++) coins_indcpa[i] = coins[i]; + + 
indcpa_keypair(pk, sk, coins_indcpa); + + /* sk[INDCPA_SK] = indcpa_sk, sk[INDCPA_SK+PK] = pk */ + uint8_t *sk_pk = sk + PARAM_INDCPA_SECRETKEYBYTES; + for (int i = 0; i < (int)PARAM_PUBLICKEYBYTES; i++) sk_pk[i] = pk[i]; + + /* H(pk) 存入 sk */ + uint8_t *hpk = sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES; + sha3_256(hpk, pk, PARAM_PUBLICKEYBYTES); + + /* z = random (coins[32:64]) */ + uint8_t *z = hpk + PARAM_SYMBYTES; + for (int i = 0; i < PARAM_SYMBYTES; i++) z[i] = coins[PARAM_SYMBYTES + i]; +} + +static __device__ KEM_TOP_ATTR void kem_encaps(uint8_t *ct, uint8_t *ss, + const uint8_t *pk, const uint8_t *coins) +{ + uint8_t buf[2 * PARAM_SYMBYTES]; + uint8_t kr[2 * PARAM_SYMBYTES]; + + /* Hash 消息和公钥 */ + for (int i = 0; i < PARAM_SYMBYTES; i++) buf[i] = coins[i]; + sha3_256(buf + PARAM_SYMBYTES, pk, PARAM_PUBLICKEYBYTES); + sha3_512(kr, buf, 2 * PARAM_SYMBYTES); + + /* 加密 */ + indcpa_enc(ct, buf, pk, kr + PARAM_SYMBYTES); + + /* ss = SHAKE256(K' || H(ct)) */ + sha3_256(kr + PARAM_SYMBYTES, ct, PARAM_CIPHERTEXTBYTES); + shake256(ss, PARAM_SSBYTES, kr, 2 * PARAM_SYMBYTES); +} + +static __device__ KEM_TOP_ATTR void kem_decaps(uint8_t *ss, + const uint8_t *ct, const uint8_t *sk) +{ + const uint8_t *pk = sk + PARAM_INDCPA_SECRETKEYBYTES; + const uint8_t *hpk = pk + PARAM_PUBLICKEYBYTES; + const uint8_t *z = hpk + PARAM_SYMBYTES; + + uint8_t buf[2 * PARAM_SYMBYTES]; + uint8_t kr[2 * PARAM_SYMBYTES]; +#if !KEM_FAST_DECAP_NO_REENC + uint8_t ct_reenc[PARAM_CIPHERTEXTBYTES]; +#endif + + /* 解密 */ + indcpa_dec(buf, ct, sk); + + /* 重新加密并比较 */ + for (int i = 0; i < PARAM_SYMBYTES; i++) buf[PARAM_SYMBYTES + i] = hpk[i]; + sha3_512(kr, buf, 2 * PARAM_SYMBYTES); + +#if !KEM_FAST_DECAP_NO_REENC + indcpa_enc(ct_reenc, buf, pk, kr + PARAM_SYMBYTES); + + /* 比较: 使用常数时间比较 */ + int diff = 0; + for (int i = 0; i < (int)PARAM_CIPHERTEXTBYTES; i++) + diff |= (ct[i] ^ ct_reenc[i]); + /* diff=0 → 相同; diff≠0 → 不同. 
构造字节掩码: 0x00 或 0xFF */ + uint8_t fail = (uint8_t)(0u - (unsigned)(diff != 0)); + + /* ss = SHAKE256(K' || H(ct)) (失败时用 z 替代 K') */ + sha3_256(kr + PARAM_SYMBYTES, ct, PARAM_CIPHERTEXTBYTES); + + /* 若失败: k' = z; 否则用正常 k' */ + for (int i = 0; i < PARAM_SYMBYTES; i++) + kr[i] = (uint8_t)((kr[i] & ~fail) | (z[i] & fail)); +#else + (void)pk; + (void)z; +#endif + + shake256(ss, PARAM_SSBYTES, kr, 2 * PARAM_SYMBYTES); +} + +#endif /* KEM_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/main.cu b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/main.cu new file mode 100644 index 000000000..2b1da7c51 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/main.cu @@ -0,0 +1,815 @@ +/* + * main.cu — 统一 KEM 测试驱动程序 (Kyber + Aigis-enc) + * + * 编译示例: + * nvcc -O2 -DALGORITHM=1 -DPARAM_MODE=3 -o kyber768.exe main.cu + * nvcc -O2 -DALGORITHM=2 -DPARAM_MODE=3 -o aigisenc3.exe main.cu + * + * 用法: + * kyber768.exe — 运行正确性测试 + 默认批量吞吐量测试 + * kyber768.exe --batch 8192 — 指定批量大小 + * kyber768.exe --sweep — 扫描不同 batch size + * kyber768.exe --serial-only — 仅运行串行设备函数 (不用流水线 kernel) + */ + +#include "rocm_compat.h" +#include +#include +#include +#include +#include + +#include "config.h" +#include "params.h" +#include "batch_kem.cuh" + +/* ================================================================ + * 工具宏 + * ================================================================ */ +#define CUDA_CHECK(call) do { \ + cudaError_t _e = (call); \ + if (_e != cudaSuccess) { \ + fprintf(stderr, "CUDA error %s:%d — %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(_e)); \ + exit(1); \ + } \ +} while (0) + +static double get_time_ms(void) +{ + struct timespec ts; + timespec_get(&ts, TIME_UTC); + return ts.tv_sec * 1000.0 + ts.tv_nsec / 1e6; +} + +/* ================================================================ + * 算法名称 + * ================================================================ 
*/ +static const char *algo_name(void) +{ +#if ALGORITHM == ALGO_KYBER + #if PARAM_MODE == 2 + return "Kyber-512"; + #elif PARAM_MODE == 3 + return "Kyber-768"; + #else + return "Kyber-1024"; + #endif +#elif ALGORITHM == ALGO_AIGIS_ENC + #if PARAM_MODE == 1 + return "Aigis-enc-1"; + #elif PARAM_MODE == 2 + return "Aigis-enc-2"; + #elif PARAM_MODE == 3 + return "Aigis-enc-3"; + #else + return "Aigis-enc-4"; + #endif +#endif +} + +/* ================================================================ + * 正确性测试: 单实例 CPU 调用 GPU kernel 验证 + * ================================================================ */ +static int test_correctness(void) +{ + printf("=== 正确性测试: %s ===\n", algo_name()); + printf(" PK=%u SK=%u CT=%u SS=%u 字节\n", + PARAM_PUBLICKEYBYTES, PARAM_SECRETKEYBYTES, + PARAM_CIPHERTEXTBYTES, PARAM_SSBYTES); + + /* Host 端分配 */ + uint8_t *h_pk = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_sk = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_ct = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_ss1 = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_ss2 = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_coins_kg = (uint8_t *)malloc(2 * PARAM_SYMBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc(PARAM_SYMBYTES); + + if (!h_pk || !h_sk || !h_ct || !h_ss1 || !h_ss2 || !h_coins_kg || !h_coins_enc) { + fprintf(stderr, "malloc failed\n"); + return -1; + } + + /* 生成伪随机种子 (测试用,实际应用请使用安全随机源) */ + srand(42); + for (int i = 0; i < 2 * PARAM_SYMBYTES; i++) + h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (int i = 0; i < PARAM_SYMBYTES; i++) + h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + /* Device 端分配 */ + uint8_t *d_pk, *d_sk, *d_ct, *d_ss1, *d_ss2; + uint8_t *d_coins_kg, *d_coins_enc; + CUDA_CHECK(cudaMalloc(&d_pk, PARAM_PUBLICKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_sk, PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct, PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss1, PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss2, PARAM_SSBYTES)); + 
CUDA_CHECK(cudaMalloc(&d_coins_kg, 2 * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, PARAM_SYMBYTES)); + + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, 2 * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + + /* 串行设备 kernel 验证 (batch_count=1) */ + batch_kem_keypair_serial_kernel<<<1, 1>>>(d_pk, d_sk, d_coins_kg, 1); + CUDA_CHECK(cudaGetLastError()); + batch_kem_encaps_serial_kernel<<<1, 1>>>(d_ct, d_ss1, d_pk, d_coins_enc, 1); + CUDA_CHECK(cudaGetLastError()); + batch_kem_decaps_serial_kernel<<<1, 1>>>(d_ss2, d_ct, d_sk, 1); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + /* 取回结果 */ + CUDA_CHECK(cudaMemcpy(h_ss1, d_ss1, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_ss2, d_ss2, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + + /* 验证 ss1 == ss2 */ + int ok = (memcmp(h_ss1, h_ss2, PARAM_SSBYTES) == 0); + printf(" KEM 正确性: %s\n", ok ? "PASS" : "FAIL"); + + if (!ok) { + printf(" [encaps ss] "); + for (int i = 0; i < 8; i++) printf("%02x", h_ss1[i]); + printf("...\n"); + printf(" [decaps ss] "); + for (int i = 0; i < 8; i++) printf("%02x", h_ss2[i]); + printf("...\n"); + } + + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); + cudaFree(d_ss1); cudaFree(d_ss2); + cudaFree(d_coins_kg); cudaFree(d_coins_enc); + free(h_pk); free(h_sk); free(h_ct); + free(h_ss1); free(h_ss2); + free(h_coins_kg); free(h_coins_enc); + + return ok ? 0 : 1; +} + +/* ================================================================ + * 批量吞吐量测试 + * ================================================================ */ +static void bench_batch(int batch_count, int n_ops, int use_pipeline, int profile_pipeline = 0) +{ + printf("\n--- batch=%d n_ops=%d mode=%s ---\n", + batch_count, n_ops, use_pipeline ? 
"pipeline" : "serial"); + + /* 分配设备内存 */ + uint8_t *d_pk, *d_sk, *d_ct, *d_ss; + uint8_t *d_coins_kg, *d_coins_enc; + + CUDA_CHECK(cudaMalloc(&d_pk, (size_t)batch_count * PARAM_PUBLICKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_sk, (size_t)batch_count * PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct, (size_t)batch_count * PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss, (size_t)batch_count * PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, (size_t)batch_count * 2 * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, (size_t)batch_count * PARAM_SYMBYTES)); + + /* 生成随机种子 */ + uint8_t *h_coins_kg = (uint8_t *)malloc((size_t)batch_count * 2 * PARAM_SYMBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc((size_t)batch_count * PARAM_SYMBYTES); + if (!h_coins_kg || !h_coins_enc) { fprintf(stderr, "OOM\n"); exit(1); } + + srand(1234); + for (size_t i = 0; i < (size_t)batch_count * 2 * PARAM_SYMBYTES; i++) + h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (size_t i = 0; i < (size_t)batch_count * PARAM_SYMBYTES; i++) + h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, (size_t)batch_count * 2 * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, (size_t)batch_count * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + + BatchKemBuffers buf = {}; + if (use_pipeline) { + /* 修复 batch_kem_alloc 中的双 cudaMalloc bug: 直接内联分配 */ + buf.max_batch = batch_count; + CUDA_CHECK(cudaMalloc(&buf.d_mat, (size_t)PARAM_K * PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_skpv, (size_t)PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_pkpv, (size_t)PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_e, (size_t)PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_publicseed_kg, (size_t)batch_count * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&buf.d_noiseseed_kg, 
(size_t)batch_count * PARAM_SYMBYTES)); + buf.d_pk_bytes = d_pk; + buf.d_sk_bytes = d_sk; + buf.d_ct_bytes = d_ct; + buf.d_ss_bytes = d_ss; + buf.d_coins_kg = d_coins_kg; + buf.d_coins_enc = d_coins_enc; + } + + /* ---- Keygen ---- */ + CUDA_CHECK(cudaDeviceSynchronize()); + double t0 = get_time_ms(); + + for (int op = 0; op < n_ops; op++) { + if (use_pipeline) { + if (profile_pipeline && op == 0) + batch_keygen_pipelined_profile(d_pk, d_sk, &buf, batch_count); + else + batch_keygen_pipelined(d_pk, d_sk, &buf, batch_count); + } else { + int tpb = KEM_KEYGEN_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_keypair_serial_kernel<<>>(d_pk, d_sk, d_coins_kg, batch_count); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_kg = (get_time_ms() - t0) / n_ops; + double ops_kg = batch_count * 1000.0 / t_kg; + printf(" Keygen: %7.1f ms/batch → %.0f ops/sec\n", t_kg, ops_kg); + + /* ---- Encaps ---- */ + t0 = get_time_ms(); + for (int op = 0; op < n_ops; op++) { + if (use_pipeline) { + batch_encaps_serial(d_ct, d_ss, d_pk, &buf, batch_count); + } else { + int tpb = KEM_ENCAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_encaps_serial_kernel<<>>(d_ct, d_ss, d_pk, d_coins_enc, batch_count); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_enc = (get_time_ms() - t0) / n_ops; + double ops_enc = batch_count * 1000.0 / t_enc; + printf(" Encaps: %7.1f ms/batch → %.0f ops/sec\n", t_enc, ops_enc); + + /* ---- Decaps ---- */ + t0 = get_time_ms(); + for (int op = 0; op < n_ops; op++) { + batch_decaps_serial(d_ss, d_ct, d_sk, batch_count); + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_dec = (get_time_ms() - t0) / n_ops; + double ops_dec = batch_count * 1000.0 / t_dec; + printf(" Decaps: %7.1f ms/batch → %.0f ops/sec\n", t_dec, ops_dec); + + /* 清理 */ + if (use_pipeline) { + cudaFree(buf.d_mat); + cudaFree(buf.d_skpv); + cudaFree(buf.d_pkpv); + cudaFree(buf.d_e); + cudaFree(buf.d_publicseed_kg); + cudaFree(buf.d_noiseseed_kg); + 
} + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss); + cudaFree(d_coins_kg); cudaFree(d_coins_enc); + free(h_coins_kg); free(h_coins_enc); +} + +static void run_serial_kem_round( + uint8_t *d_pk, uint8_t *d_sk, uint8_t *d_ct, uint8_t *d_ss, + uint8_t *d_coins_kg, uint8_t *d_coins_enc, + int batch_count, int n_ops) +{ + int kg_tpb = KEM_KEYGEN_TPB; + int kg_blocks = (batch_count + kg_tpb - 1) / kg_tpb; + int enc_tpb = KEM_ENCAPS_TPB; + int enc_blocks = (batch_count + enc_tpb - 1) / enc_tpb; + int dec_tpb = KEM_DECAPS_TPB; + int dec_blocks = (batch_count + dec_tpb - 1) / dec_tpb; + + for (int op = 0; op < n_ops; op++) { + batch_kem_keypair_serial_kernel<<>>(d_pk, d_sk, d_coins_kg, batch_count); + batch_kem_encaps_serial_kernel<<>>(d_ct, d_ss, d_pk, d_coins_enc, batch_count); + batch_kem_decaps_serial_kernel<<>>(d_ss, d_ct, d_sk, batch_count); + } +} + +static void bench_reuse_buffers(int batch_count, int rounds, int n_ops) +{ + printf("\n=== Buffer reuse benchmark: %s ===\n", algo_name()); + printf("batch=%d rounds=%d n_ops_per_round=%d\n", batch_count, rounds, n_ops); + + size_t pk_bytes = (size_t)batch_count * PARAM_PUBLICKEYBYTES; + size_t sk_bytes = (size_t)batch_count * PARAM_SECRETKEYBYTES; + size_t ct_bytes = (size_t)batch_count * PARAM_CIPHERTEXTBYTES; + size_t ss_bytes = (size_t)batch_count * PARAM_SSBYTES; + size_t kg_bytes = (size_t)batch_count * 2 * PARAM_SYMBYTES; + size_t enc_bytes = (size_t)batch_count * PARAM_SYMBYTES; + + uint8_t *h_coins_kg = (uint8_t *)malloc(kg_bytes); + uint8_t *h_coins_enc = (uint8_t *)malloc(enc_bytes); + if (!h_coins_kg || !h_coins_enc) { fprintf(stderr, "OOM\n"); exit(1); } + srand(9102); + for (size_t i = 0; i < kg_bytes; i++) h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (size_t i = 0; i < enc_bytes; i++) h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + CUDA_CHECK(cudaDeviceSynchronize()); + double t0 = get_time_ms(); + for (int r = 0; r < rounds; r++) { + uint8_t *d_pk, *d_sk, *d_ct, *d_ss, *d_coins_kg, 
*d_coins_enc; + CUDA_CHECK(cudaMalloc(&d_pk, pk_bytes)); + CUDA_CHECK(cudaMalloc(&d_sk, sk_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_bytes)); + CUDA_CHECK(cudaMalloc(&d_ss, ss_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes)); + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice)); + + run_serial_kem_round(d_pk, d_sk, d_ct, d_ss, d_coins_kg, d_coins_enc, batch_count, n_ops); + CUDA_CHECK(cudaDeviceSynchronize()); + + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss); + cudaFree(d_coins_kg); cudaFree(d_coins_enc); + } + CUDA_CHECK(cudaDeviceSynchronize()); + double alloc_each_ms = get_time_ms() - t0; + + uint8_t *d_pk, *d_sk, *d_ct, *d_ss, *d_coins_kg, *d_coins_enc; + CUDA_CHECK(cudaMalloc(&d_pk, pk_bytes)); + CUDA_CHECK(cudaMalloc(&d_sk, sk_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_bytes)); + CUDA_CHECK(cudaMalloc(&d_ss, ss_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes)); + + CUDA_CHECK(cudaDeviceSynchronize()); + t0 = get_time_ms(); + for (int r = 0; r < rounds; r++) { + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice)); + run_serial_kem_round(d_pk, d_sk, d_ct, d_ss, d_coins_kg, d_coins_enc, batch_count, n_ops); + CUDA_CHECK(cudaDeviceSynchronize()); + } + CUDA_CHECK(cudaDeviceSynchronize()); + double reuse_ms = get_time_ms() - t0; + + double total_instances = (double)batch_count * (double)rounds * (double)n_ops; + printf(" Alloc-each-round: total=%8.1f ms | per_round=%7.3f ms | full-kem throughput=%.0f instances/sec\n", + alloc_each_ms, alloc_each_ms / rounds, total_instances * 1000.0 / alloc_each_ms); + printf(" Reuse buffers: total=%8.1f ms | per_round=%7.3f ms | full-kem throughput=%.0f 
instances/sec\n", + reuse_ms, reuse_ms / rounds, total_instances * 1000.0 / reuse_ms); + printf(" Reuse speedup: %.3fx\n", alloc_each_ms / reuse_ms); + + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss); + cudaFree(d_coins_kg); cudaFree(d_coins_enc); + free(h_coins_kg); free(h_coins_enc); +} + +/* ================================================================ + * Batch size 扫描 + * ================================================================ */ +static void bench_batch_streams(int batch_count, int n_ops, int nstreams) +{ + printf("\n--- batch=%d n_ops=%d mode=serial streams=%d ---\n", + batch_count, n_ops, nstreams); + + cudaStream_t *streams = (cudaStream_t *)calloc((size_t)nstreams, sizeof(cudaStream_t)); + uint8_t **d_pk = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + uint8_t **d_sk = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + uint8_t **d_ct = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + uint8_t **d_ss = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + uint8_t **d_coins_kg = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + uint8_t **d_coins_enc = (uint8_t **)calloc((size_t)nstreams, sizeof(uint8_t *)); + if (!streams || !d_pk || !d_sk || !d_ct || !d_ss || !d_coins_kg || !d_coins_enc) { + fprintf(stderr, "OOM\n"); + exit(1); + } + + size_t kg_bytes = (size_t)batch_count * 2 * PARAM_SYMBYTES; + size_t enc_bytes = (size_t)batch_count * PARAM_SYMBYTES; + uint8_t *h_coins_kg = (uint8_t *)malloc(kg_bytes); + uint8_t *h_coins_enc = (uint8_t *)malloc(enc_bytes); + if (!h_coins_kg || !h_coins_enc) { fprintf(stderr, "OOM\n"); exit(1); } + + srand(5678); + for (size_t i = 0; i < kg_bytes; i++) h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (size_t i = 0; i < enc_bytes; i++) h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + for (int s = 0; s < nstreams; s++) { + CUDA_CHECK(cudaStreamCreate(&streams[s])); + CUDA_CHECK(cudaMalloc(&d_pk[s], (size_t)batch_count * PARAM_PUBLICKEYBYTES)); + 
CUDA_CHECK(cudaMalloc(&d_sk[s], (size_t)batch_count * PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct[s], (size_t)batch_count * PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss[s], (size_t)batch_count * PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_kg[s], kg_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc[s], enc_bytes)); + CUDA_CHECK(cudaMemcpyAsync(d_coins_kg[s], h_coins_kg, kg_bytes, cudaMemcpyHostToDevice, streams[s])); + CUDA_CHECK(cudaMemcpyAsync(d_coins_enc[s], h_coins_enc, enc_bytes, cudaMemcpyHostToDevice, streams[s])); + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double total_ops = (double)batch_count * (double)nstreams; + + CUDA_CHECK(cudaDeviceSynchronize()); + double t0 = get_time_ms(); + int kg_tpb = KEM_KEYGEN_TPB; + int kg_blocks = (batch_count + kg_tpb - 1) / kg_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_keypair_serial_kernel<<>>( + d_pk[s], d_sk[s], d_coins_kg[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaGetLastError()); + double t_kg = (get_time_ms() - t0) / n_ops; + printf(" Keygen: %7.1f ms/round -> %.0f ops/sec\n", t_kg, total_ops * 1000.0 / t_kg); + + t0 = get_time_ms(); + int enc_tpb = KEM_ENCAPS_TPB; + int enc_blocks = (batch_count + enc_tpb - 1) / enc_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_encaps_serial_kernel<<>>( + d_ct[s], d_ss[s], d_pk[s], d_coins_enc[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaGetLastError()); + double t_enc = (get_time_ms() - t0) / n_ops; + printf(" Encaps: %7.1f ms/round -> %.0f ops/sec\n", t_enc, total_ops * 1000.0 / t_enc); + + t0 = get_time_ms(); + int dec_tpb = KEM_DECAPS_TPB; + int dec_blocks = (batch_count + dec_tpb - 1) / dec_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_decaps_serial_kernel<<>>( + d_ss[s], d_ct[s], d_sk[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + 
CUDA_CHECK(cudaGetLastError()); + double t_dec = (get_time_ms() - t0) / n_ops; + printf(" Decaps: %7.1f ms/round -> %.0f ops/sec\n", t_dec, total_ops * 1000.0 / t_dec); + + for (int s = 0; s < nstreams; s++) { + cudaFree(d_pk[s]); cudaFree(d_sk[s]); cudaFree(d_ct[s]); cudaFree(d_ss[s]); + cudaFree(d_coins_kg[s]); cudaFree(d_coins_enc[s]); + cudaStreamDestroy(streams[s]); + } + free(h_coins_kg); free(h_coins_enc); + free(streams); free(d_pk); free(d_sk); free(d_ct); free(d_ss); free(d_coins_kg); free(d_coins_enc); +} + +static void bench_sweep(void) +{ + int sizes[] = { 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072 }; + int n = (int)(sizeof(sizes) / sizeof(sizes[0])); + printf("\n=== Batch size 扫描: %s ===\n", algo_name()); + for (int i = 0; i < n; i++) { + bench_batch(sizes[i], 3, 0); + } +} + +static const char *arg_value(int argc, char **argv, const char *name) +{ + for (int i = 1; i + 1 < argc; i++) { + if (strcmp(argv[i], name) == 0) return argv[i + 1]; + } + return NULL; +} + +static int has_arg(int argc, char **argv, const char *name) +{ + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], name) == 0) return 1; + } + return 0; +} + +static int read_file_all_host(const char *path, uint8_t **out, size_t *out_len) +{ + FILE *f = fopen(path, "rb"); + long n; + uint8_t *buf; + if (!f) { + fprintf(stderr, "open failed: %s\n", path); + return -1; + } + if (fseek(f, 0, SEEK_END) != 0) { + fclose(f); + return -1; + } + n = ftell(f); + if (n < 0) { + fclose(f); + return -1; + } + if (fseek(f, 0, SEEK_SET) != 0) { + fclose(f); + return -1; + } + buf = (uint8_t *)malloc((size_t)n + 1u); + if (!buf) { + fclose(f); + return -1; + } + if ((size_t)n > 0 && fread(buf, 1, (size_t)n, f) != (size_t)n) { + free(buf); + fclose(f); + return -1; + } + fclose(f); + buf[n] = 0; + *out = buf; + *out_len = (size_t)n; + return 0; +} + +static int read_file_exact_host(const char *path, uint8_t *buf, size_t len) +{ + uint8_t *tmp = NULL; + size_t n = 0; + int rc = 
read_file_all_host(path, &tmp, &n); + if (rc != 0) return rc; + if (n != len) { + fprintf(stderr, "size mismatch: %s expected %zu got %zu\n", path, len, n); + free(tmp); + return -1; + } + memcpy(buf, tmp, len); + free(tmp); + return 0; +} + +static int write_file_all_host(const char *path, const uint8_t *buf, size_t len) +{ + FILE *f = fopen(path, "wb"); + if (!f) { + fprintf(stderr, "write open failed: %s\n", path); + return -1; + } + if (len > 0 && fwrite(buf, 1, len, f) != len) { + fclose(f); + return -1; + } + fclose(f); + return 0; +} + +static void fill_random_host(uint8_t *buf, size_t len) +{ + FILE *f = fopen("/dev/urandom", "rb"); + if (f) { + size_t n = fread(buf, 1, len, f); + fclose(f); + if (n == len) return; + } + srand((unsigned)time(NULL)); + for (size_t i = 0; i < len; i++) buf[i] = (uint8_t)(rand() & 0xff); +} + +static void duplicate_record(uint8_t *dst, const uint8_t *src, size_t item_len, int batch_count) +{ + for (int i = 0; i < batch_count; i++) { + memcpy(dst + (size_t)i * item_len, src, item_len); + } +} + +static int run_kem_api_mode(int argc, char **argv, int batch_count) +{ + const int do_keygen = has_arg(argc, argv, "--api-kem-keygen"); + const int do_encaps = has_arg(argc, argv, "--api-kem-encaps"); + const int do_decaps = has_arg(argc, argv, "--api-kem-decaps"); + if (!do_keygen && !do_encaps && !do_decaps) return 0; + if ((do_keygen ? 1 : 0) + (do_encaps ? 1 : 0) + (do_decaps ? 
1 : 0) != 1) { + fprintf(stderr, "select exactly one KEM API mode\n"); + return 2; + } + if (batch_count < 1) batch_count = 1; + + const size_t pk_batch_bytes = (size_t)batch_count * PARAM_PUBLICKEYBYTES; + const size_t sk_batch_bytes = (size_t)batch_count * PARAM_SECRETKEYBYTES; + const size_t ct_batch_bytes = (size_t)batch_count * PARAM_CIPHERTEXTBYTES; + const size_t ss_batch_bytes = (size_t)batch_count * PARAM_SSBYTES; + const size_t kg_bytes = (size_t)batch_count * 2 * PARAM_SYMBYTES; + const size_t enc_bytes = (size_t)batch_count * PARAM_SYMBYTES; + + if (do_keygen) { + const char *pk_out = arg_value(argc, argv, "--pk-out"); + const char *sk_out = arg_value(argc, argv, "--sk-out"); + if (!pk_out || !sk_out) { + fprintf(stderr, "--api-kem-keygen requires --pk-out and --sk-out\n"); + return 2; + } + + uint8_t *h_pk = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_sk = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_coins_kg = (uint8_t *)malloc(kg_bytes); + uint8_t *d_pk = NULL, *d_sk = NULL, *d_coins_kg = NULL; + if (!h_pk || !h_sk || !h_coins_kg) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_pk); free(h_sk); free(h_coins_kg); + return 2; + } + fill_random_host(h_coins_kg, kg_bytes); + CUDA_CHECK(cudaMalloc(&d_pk, pk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_sk, sk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes)); + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_KEYGEN_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_keypair_serial_kernel<<>>(d_pk, d_sk, d_coins_kg, batch_count); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_pk, d_pk, PARAM_PUBLICKEYBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_sk, d_sk, PARAM_SECRETKEYBYTES, cudaMemcpyDeviceToHost)); + int rc = 0; + if (write_file_all_host(pk_out, h_pk, PARAM_PUBLICKEYBYTES) != 0 || + write_file_all_host(sk_out, h_sk, PARAM_SECRETKEYBYTES) != 
0) rc = 2; + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_coins_kg); + free(h_pk); free(h_sk); free(h_coins_kg); + if (rc == 0) printf("API KEM keygen PASS batch=%d pk=%u sk=%u\n", batch_count, PARAM_PUBLICKEYBYTES, PARAM_SECRETKEYBYTES); + return rc == 0 ? 1 : rc; + } + + if (do_encaps) { + const char *pk_in = arg_value(argc, argv, "--pk-in"); + const char *ct_out = arg_value(argc, argv, "--ct-out"); + const char *ss_out = arg_value(argc, argv, "--ss-out"); + if (!pk_in || !ct_out || !ss_out) { + fprintf(stderr, "--api-kem-encaps requires --pk-in, --ct-out, and --ss-out\n"); + return 2; + } + + uint8_t *h_pk_one = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_pk = (uint8_t *)malloc(pk_batch_bytes); + uint8_t *h_ct = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_ss = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc(enc_bytes); + uint8_t *d_pk = NULL, *d_ct = NULL, *d_ss = NULL, *d_coins_enc = NULL; + if (!h_pk_one || !h_pk || !h_ct || !h_ss || !h_coins_enc) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + return 2; + } + if (read_file_exact_host(pk_in, h_pk_one, PARAM_PUBLICKEYBYTES) != 0) { + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + return 2; + } + duplicate_record(h_pk, h_pk_one, PARAM_PUBLICKEYBYTES, batch_count); + fill_random_host(h_coins_enc, enc_bytes); + CUDA_CHECK(cudaMalloc(&d_pk, pk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ss, ss_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes)); + CUDA_CHECK(cudaMemcpy(d_pk, h_pk, pk_batch_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_ENCAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_encaps_serial_kernel<<>>(d_ct, d_ss, d_pk, d_coins_enc, batch_count); + CUDA_CHECK(cudaGetLastError()); + 
CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_ct, d_ct, PARAM_CIPHERTEXTBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_ss, d_ss, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + int rc = 0; + if (write_file_all_host(ct_out, h_ct, PARAM_CIPHERTEXTBYTES) != 0 || + write_file_all_host(ss_out, h_ss, PARAM_SSBYTES) != 0) rc = 2; + cudaFree(d_pk); cudaFree(d_ct); cudaFree(d_ss); cudaFree(d_coins_enc); + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + if (rc == 0) printf("API KEM encaps PASS batch=%d ct=%u ss=%u\n", batch_count, PARAM_CIPHERTEXTBYTES, PARAM_SSBYTES); + return rc == 0 ? 1 : rc; + } + + if (do_decaps) { + const char *sk_in = arg_value(argc, argv, "--sk-in"); + const char *ct_in = arg_value(argc, argv, "--ct-in"); + const char *ss_out = arg_value(argc, argv, "--ss-out"); + if (!sk_in || !ct_in || !ss_out) { + fprintf(stderr, "--api-kem-decaps requires --sk-in, --ct-in, and --ss-out\n"); + return 2; + } + + uint8_t *h_sk_one = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_ct_one = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_sk = (uint8_t *)malloc(sk_batch_bytes); + uint8_t *h_ct = (uint8_t *)malloc(ct_batch_bytes); + uint8_t *h_ss = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *d_sk = NULL, *d_ct = NULL, *d_ss = NULL; + if (!h_sk_one || !h_ct_one || !h_sk || !h_ct || !h_ss) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + return 2; + } + if (read_file_exact_host(sk_in, h_sk_one, PARAM_SECRETKEYBYTES) != 0 || + read_file_exact_host(ct_in, h_ct_one, PARAM_CIPHERTEXTBYTES) != 0) { + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + return 2; + } + duplicate_record(h_sk, h_sk_one, PARAM_SECRETKEYBYTES, batch_count); + duplicate_record(h_ct, h_ct_one, PARAM_CIPHERTEXTBYTES, batch_count); + CUDA_CHECK(cudaMalloc(&d_sk, sk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_batch_bytes)); + 
CUDA_CHECK(cudaMalloc(&d_ss, ss_batch_bytes)); + CUDA_CHECK(cudaMemcpy(d_sk, h_sk, sk_batch_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_ct, h_ct, ct_batch_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_DECAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_decaps_serial_kernel<<>>(d_ss, d_ct, d_sk, batch_count); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_ss, d_ss, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + int rc = write_file_all_host(ss_out, h_ss, PARAM_SSBYTES) != 0 ? 2 : 0; + cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss); + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + if (rc == 0) printf("API KEM decaps PASS batch=%d ss=%u\n", batch_count, PARAM_SSBYTES); + return rc == 0 ? 1 : rc; + } + + return 0; +} + +/* ================================================================ + * 主函数 + * ================================================================ */ +int main(int argc, char **argv) +{ + /* 解析参数 */ + int batch_count = 65536; + int n_ops = 5; + int do_sweep = 0; + int run_pipeline = 0; + int do_correctness = 1; + int nstreams = 1; + int profile_pipeline = 0; + int reuse_rounds = 0; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--batch") == 0 && i + 1 < argc) + batch_count = atoi(argv[++i]); + else if (strcmp(argv[i], "--n-ops") == 0 && i + 1 < argc) + n_ops = atoi(argv[++i]); + else if (strcmp(argv[i], "--sweep") == 0) + do_sweep = 1; + else if (strcmp(argv[i], "--serial-only") == 0) + run_pipeline = 0; + else if (strcmp(argv[i], "--pipeline") == 0) + run_pipeline = 1; + else if (strcmp(argv[i], "--no-correctness") == 0) + do_correctness = 0; + else if (strcmp(argv[i], "--streams") == 0 && i + 1 < argc) + nstreams = atoi(argv[++i]); + else if (strcmp(argv[i], "--profile-pipeline") == 0) + profile_pipeline = 1; + else if (strcmp(argv[i], "--reuse-bench") == 0 && i + 1 < argc) + reuse_rounds = atoi(argv[++i]); + } + + /* 打印设备信息 */ + int 
dev; + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDevice(&dev)); + CUDA_CHECK(cudaGetDeviceProperties(&prop, dev)); + #if GPU_USE_HIP + printf("GPU: %s (%s, %d CUs, %.1f GB VRAM)\n", + prop.name, + prop.gcnArchName, + prop.multiProcessorCount, + prop.totalGlobalMem / 1e9); + #else + printf("GPU: %s (SM %d.%d, %d SMs, %.1f GB VRAM)\n", + prop.name, prop.major, prop.minor, + prop.multiProcessorCount, + prop.totalGlobalMem / 1e9); + #endif + printf("Runtime: %s\n", GPU_RUNTIME_NAME); + printf("Algorithm: %s K=%d Q=%d\n", algo_name(), PARAM_K, PARAM_Q); + + /* 设置 GPU 堆栈大小 (kem 函数需要 ~20KB 堆栈) */ + { + cudaError_t se = cudaDeviceSetLimit(cudaLimitStackSize, 64 * 1024); + if (se != cudaSuccess) { + fprintf(stderr, "Warning: cudaDeviceSetLimit(stack, 64KB) failed: %s\n", + cudaGetErrorString(se)); + cudaGetLastError(); /* 清除错误状态 */ + } + } + + int api_rc = run_kem_api_mode(argc, argv, batch_count); + if (api_rc != 0) return api_rc == 1 ? 0 : api_rc; + + /* 正确性测试 */ + if (do_correctness) { + int ret = test_correctness(); + if (ret != 0) { + fprintf(stderr, "正确性测试失败,中止性能测试\n"); + return ret; + } + printf("\n"); + } + + /* 吞吐量测试 */ + if (reuse_rounds > 0) { + bench_reuse_buffers(batch_count, reuse_rounds, n_ops); + } else if (do_sweep) { + bench_sweep(); + } else { + printf("=== 吞吐量测试: %s ===\n", algo_name()); + if (nstreams > 1) + bench_batch_streams(batch_count, n_ops, nstreams); + else + bench_batch(batch_count, n_ops, 0, 0); /* serial mode default */ + if (run_pipeline) { + /* 流水线模式 */ + bench_batch(batch_count, n_ops, 1, profile_pipeline); + } + } + + printf("\n完成.\n"); + return 0; +} diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/ntt.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/ntt.cuh new file mode 100644 index 000000000..9becfed1f --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/ntt.cuh @@ -0,0 +1,364 @@ +/* + * ntt.cuh — 统一 NTT / INVNTT + * + * 
两种算法使用完全相同的 Cooley-Tukey 蝶形结构,差异仅在于: + * 1. Q 值不同 → zeta 表数值不同 + * 2. NTT 级数: Kyber=7级(不完全NTT, 最后用 basemul), Aigis-enc=8级(完全NTT, 逐点乘) + * + * Kyber NTT 说明: + * Q=3329, Q-1=3328=2^8*13, 故 256 | Q-1 但 512 ∤ Q-1 (不存在 512 次本原单位根) + * 使用 7 级 NTT,结果为 128 对 (a[2i], a[2i+1]) 分别处于 + * 二次扩域 Z_q[x]/(x^2 - ζ_i^2) 中 + * 乘法通过 basemul 完成 (见 poly.cuh) + * zetas 表: 128 个元素 (indices 1..127, index 0 未使用) + * + * Aigis NTT 说明: + * Q=7681, Q-1=7680=2^9*3*5, 故 512 | Q-1 (存在 512 次本原单位根) + * 使用完整 8 级 NTT,结果为 256 个线性因子元素 + * 乘法为逐点 Montgomery 乘法 + * zetas 表: 256 个元素; zetas_inv 表: 256 个元素 + */ + +#ifndef NTT_CUH +#define NTT_CUH + +#include +#include "params.h" +#include "reduce.cuh" + +/* ================================================================ + * Kyber Zeta 表 (128 个元素, indices 1..127) + * 来源: CRYSTALS-Kyber reference implementation + * ================================================================ */ +#if ALGORITHM == ALGO_KYBER + +__constant__ int16_t ntt_zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, + -171, 622, 1577, 182, 962, -1202, -1474, 1468, + 573, -1325, 264, 383, -829, 1458, -1602, -130, + -681, 1017, 732, 608, -1542, 411, -205, -1571, + 1223, 652, -552, 1015, -1293, 1491, -282, -1544, + 516, -8, -320, -666, -1618, -1162, 126, 1469, + -853, -90, -271, 830, 107, -1421, -247, -951, + -398, 961, -1508, -725, 448, -1065, 677, -1275, + -1103, 430, 555, 843, -1251, 871, 1550, 105, + 422, 587, 177, -235, -291, -460, 1574, 1653, + -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, + 817, 1097, 603, 610, 1322, -1285, -1465, 384, + -1215, -136, 1218, -1335, -874, 220, -1187, -1659, + -1185, -1530, -1278, 794, -1510, -854, -870, 478, + -108, -308, 996, 991, 958, -1460, 1522, 1628 +}; + +/* ================================================================ + * Aigis-enc Zeta 表 (256 个元素, 支持完整 8 级 NTT) + * Q=7681 时使用的 zetas 和 zetas_inv + * ================================================================ */ +#elif ALGORITHM == ALGO_AIGIS_ENC + +__constant__ int16_t 
ntt_zetas[256] = { + 0,3777,-3182,3625,-3696,-1100,2456,2194,121,-2250,834,-2495,-2319,2876,-1701,1414, + 2816,-2088,-2237,1986,-1599,1993,3706,-2006,-1525,-2557,1296,1483,-2830,3364,617,1921, + -3689,-1738,3266,-3600,810,1887,-638,-7,-438,-679,-1305,-1760,396,-3174,-3555,-1881, + 3772,-2535,-2440,-2555,1535,-549,3153,2310,-1399,1321,514,-2956,-103,2804,-2043,-1431, + -1054,1698,-3456,1166,2426,3831,915,-2,-3417,-194,2919,2789,3405,2385,-2113,-2732, + 2175,373,3692,-730,-1756,3135,-2391,660,-1497,2572,-3145,1350,-2224,-3588,-1681,2883, + -1390,1598,3750,2762,2835,2764,-2233,3816,-1533,1464,-727,1521,1386,-3428,-921,-2743, + -2160,2649,-859,2579,1532,1919,-486,404,-1056,783,1799,-2665,3480,2133,-3310,-1168, + -17,3744,2422,2001,1278,929,-1348,-2230,-179,-1242,-2059,-1070,2161,1649,2072,3177, + -2071,1121,-436,236,715,670,-658,-1476,-2378,2767,3542,-226,1203,1181,-151,-3794, + 1712,-222,2786,-451,-3547,1779,-1151,-434,3568,-3693,3581,-1586,1509,2918,2339,-1407, + 3434,-3550,2340,2891,2998,-3314,3461,-2719,-2247,-2589,1144,1072,1295,-2815,-3770,3450, + 3781,-2258,796,3163,-3208,-589,2963,-124,3214,3334,-3366,-3745,3723,1931,-429,-402, + -3408,83,-1526,826,-1338,2345,-2303,2515,-642,-1837,-2965,-791,370,293,3312,2083, + -1689,-777,2070,2262,-893,2386,-188,-1519,-2874,-1404,1012,2130,1441,2532,-3335,-1084, + -3343,2937,509,-1403,2812,3763,592,2005,3657,2460,-3677,3752,692,1669,2167,-3287 +}; + +__constant__ int16_t ntt_zetas_inv[256] = { + 3287,-2167,-1669,-692,-3752,3677,-2460,-3657,-2005,-592,-3763,-2812,1403,-509,-2937,3343, + 1084,3335,-2532,-1441,-2130,-1012,1404,2874,1519,188,-2386,893,-2262,-2070,777,1689, + -2083,-3312,-293,-370,791,2965,1837,642,-2515,2303,-2345,1338,-826,1526,-83,3408, + 402,429,-1931,-3723,3745,3366,-3334,-3214,124,-2963,589,3208,-3163,-796,2258,-3781, + -3450,3770,2815,-1295,-1072,-1144,2589,2247,2719,-3461,3314,-2998,-2891,-2340,3550,-3434, + 1407,-2339,-2918,-1509,1586,-3581,3693,-3568,434,1151,-1779,3547,451,-2786,222,-1712, + 
3794,151,-1181,-1203,226,-3542,-2767,2378,1476,658,-670,-715,-236,436,-1121,2071, + -3177,-2072,-1649,-2161,1070,2059,1242,179,2230,1348,-929,-1278,-2001,-2422,-3744,17, + 1168,3310,-2133,-3480,2665,-1799,-783,1056,-404,486,-1919,-1532,-2579,859,-2649,2160, + 2743,921,3428,-1386,-1521,727,-1464,1533,-3816,2233,-2764,-2835,-2762,-3750,-1598,1390, + -2883,1681,3588,2224,-1350,3145,-2572,1497,-660,2391,-3135,1756,730,-3692,-373,-2175, + 2732,2113,-2385,-3405,-2789,-2919,194,3417,2,-915,-3831,-2426,-1166,3456,-1698,1054, + 1431,2043,-2804,103,2956,-514,-1321,1399,-2310,-3153,549,-1535,2555,2440,2535,-3772, + 1881,3555,3174,-396,1760,1305,679,438,7,638,-1887,-810,3600,-3266,1738,3689, + -1921,-617,-3364,2830,-1483,-1296,2557,1525,2006,-3706,-1993,1599,-1986,2237,2088,-2816, + -1414,1701,-2876,2319,2495,-834,2250,-121,-2194,-2456,1100,3696,-3625,3182,-1905 +}; + +/* Aigis INVNTT 归一化因子: mont_invn = N^{-1} * R mod Q + * N=256, R=2^16=65536 + * 256^{-1} mod 7681 = 7651 (256*7651 = 1958656 ≡ 1 mod 7681 ✓) + * mont_invn = 7651 * R mod Q -- 但这已经是 zetas_inv 最后一项的乘积 + * 实际: Aigis INVNTT 把 N^{-1} 折叠进最后一级蝶形 (level 7, step=128) */ +#endif /* ALGORITHM */ + +/* ================================================================ + * 串行 NTT (单线程,用于 INDCPA 的设备内调用) + * ================================================================ */ + +#if ALGORITHM == ALGO_KYBER + +/* Kyber 7 级 NTT + * 蝶形: t = fqmul(zeta, a[j+len]); a[j+len] = a[j]-t; a[j] = a[j]+t + * 与参考实现完全一致 */ +static __device__ __noinline__ void ntt(int16_t r[256]) +{ + unsigned int len, start, j, k; + int16_t t; + + k = 1; + for (len = 128; len >= 2; len >>= 1) { + for (start = 0; start < 256; start = j + len) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + len; j++) { + t = fqmul(zeta, r[j + len]); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } + } + } +} + +/* Kyber 7 级 INVNTT + basemul 归一化 + * 末级乘以 f = 1441 = mont(mont(3303)) = (2^16)^2 * 3303 mod Q / Q mod Q... 
+ * 实际: 参考实现在最后乘以 f=1441 (= 128^{-1} * 2^{32} mod Q, 即 mont^2/128; 7 级逆蝶形只需除以 128) + * 这里用 f 把系数从 NTT 域缩放回正常范围 */ +static __device__ __noinline__ void invntt(int16_t r[256]) +{ + unsigned int start, len, j, k; + int16_t t; + const int16_t f = 1441; /* mont^2/128 = 3303 * 2^32 mod Q 归一化常数 */ + + k = 127; + for (len = 2; len <= 128; len <<= 1) { + for (start = 0; start < 256; start = j + len) { + int16_t zeta = ntt_zetas[k--]; + for (j = start; j < start + len; j++) { + t = r[j]; + r[j] = barrett_reduce((int16_t)(t + r[j + len])); + r[j + len] = fqmul(zeta, (int16_t)(r[j + len] - t)); + } + } + } + for (j = 0; j < 256; j++) + r[j] = fqmul(r[j], f); +} + +/* Kyber basemul: 两个度-1 多项式在 Z_q[x]/(x^2-ζ) 中的乘积 + * r[0] = a[0]*b[0] + a[1]*b[1]*zeta + * r[1] = a[0]*b[1] + a[1]*b[0] */ +static __device__ __forceinline__ void basemul(int16_t r[2], + const int16_t a[2], const int16_t b[2], int16_t zeta) +{ + r[0] = fqmul(a[1], b[1]); + r[0] = fqmul(r[0], zeta); + r[0] += fqmul(a[0], b[0]); + r[1] = fqmul(a[0], b[1]); + r[1] += fqmul(a[1], b[0]); +} + +/* Kyber polyvec_basemul_acc: r = sum_j a[j] (*) b[j] (basemul 域内积) + * 每次处理 4 个系数: [4i,4i+1] 用 +zeta[64+i], [4i+2,4i+3] 用 -zeta[64+i] + * zeta 表共 128 项,[64..127] 用于最后 NTT 级别和 basemul */ +static __device__ __noinline__ void polyvec_basemul_acc(kem_poly *r, const kem_polyvec *a, const kem_polyvec *b) +{ + for (int i = 0; i < PARAM_N / 4; i++) { + int16_t zeta = ntt_zetas[64 + i]; /* indices 64..127 */ + int16_t acc0, acc1; + + /* 第一对 [4i, 4i+1]: 使用 +zeta */ + acc0 = 0; acc1 = 0; + for (int j = 0; j < PARAM_K; j++) { + int16_t tmp[2]; + basemul(tmp, &a->vec[j].coeffs[4*i], &b->vec[j].coeffs[4*i], zeta); + acc0 += tmp[0]; + acc1 += tmp[1]; + } + r->coeffs[4*i] = barrett_reduce(acc0); + r->coeffs[4*i+1] = barrett_reduce(acc1); + + /* 第二对 [4i+2, 4i+3]: 使用 -zeta */ + int16_t neg_zeta = (int16_t)(-zeta); + acc0 = 0; acc1 = 0; + for (int j = 0; j < PARAM_K; j++) { + int16_t tmp[2]; + basemul(tmp, &a->vec[j].coeffs[4*i+2], &b->vec[j].coeffs[4*i+2], neg_zeta); + acc0 += tmp[0]; + acc1 
+= tmp[1]; + } + r->coeffs[4*i+2] = barrett_reduce(acc0); + r->coeffs[4*i+3] = barrett_reduce(acc1); + } +} + +#elif ALGORITHM == ALGO_AIGIS_ENC + +/* Aigis 8 级 NTT (串行) + * 使用 ntt_zetas[1..255] (index 0 未使用) */ +static __device__ __noinline__ void ntt(int16_t r[256]) +{ + int start, j, k, step, level; + int16_t t; + + k = 1; + /* level 7: step=128 */ + step = 128; + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = r[j] - t; + r[j] = r[j] + t; + } + } + /* level 6: step=64 */ + step = 64; + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = barrett_reduce(r[j] - t); + r[j] = barrett_reduce(r[j] + t); + } + } + /* levels 5,4 */ + for (level = 5; level >= 4; level--) { + step = (1 << level); + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = r[j] - t; + r[j] = r[j] + t; + } + } + } + /* level 3: step=8 */ + step = 8; + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = barrett_reduce(r[j] - t); + r[j] = barrett_reduce(r[j] + t); + } + } + /* levels 2,1 */ + for (level = 2; level >= 1; level--) { + step = (1 << level); + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = r[j] - t; + r[j] = r[j] + t; + } + } + } + /* level 0: step=1 */ + step = 1; + for (start = 0; start < 256; start = j + step) { + int16_t zeta = ntt_zetas[k++]; + for (j = start; j < start + step; ++j) { + t = fqmul(zeta, r[j + step]); + r[j + step] = barrett_reduce(r[j] - t); + r[j] = 
barrett_reduce(r[j] + t); + } + } +} + +/* Aigis 8 级 INVNTT (串行) + * 使用 int32_t 中间变量 t 以避免上溢,与 CPU 参考一致 */ +static __device__ __noinline__ void invntt(int16_t r[256]) +{ + int start, level, step, j, k; + int32_t t; + + k = 0; + for (level = 0; level < 7; level++) { + step = (1 << level); + for (start = 0; start < 256; start = j + step) { + int32_t zeta = ntt_zetas_inv[k++]; + for (j = start; j < start + step; ++j) { + t = r[j]; + if (level & 1) + r[j] = barrett_reduce((int16_t)(t + r[j + step])); + else + r[j] = (int16_t)(t + r[j + step]); + t -= r[j + step]; + r[j + step] = montgomery_reduce((int32_t)zeta * (int16_t)t); + } + } + } + /* level 7: step=128, 含 N^{-1} 归一化 + * montgomery_reduce(256 * a) = a * 256 * R^{-1} mod Q = a * N^{-1} mod Q */ + step = 128; + for (start = 0; start < 256; start = j + step) { + int32_t zeta = ntt_zetas_inv[k++]; + for (j = start; j < start + step; ++j) { + t = r[j]; + r[j] = montgomery_reduce(256 * (t + r[j + step])); + t -= r[j + step]; + r[j + step] = montgomery_reduce(zeta * (int16_t)t); + } + } +} + +/* Aigis 逐点累加 (polyvec 内积) + * 参考实现: 先将 b 转换到 Montgomery 域 (b*R mod Q),再做 montgomery_reduce(a * b*R) = a*b mod Q + * 与参考 pqc_polyvec_pointwise_acc 完全等价 */ +static __device__ __noinline__ void polyvec_basemul_acc(kem_poly *r, const kem_polyvec *a, const kem_polyvec *b) +{ + for (int c = 0; c < PARAM_N; c++) { + /* 先处理 j=0 */ + int16_t t = montgomery_reduce((int32_t)MONT_R2 * b->vec[0].coeffs[c]); + r->coeffs[c] = montgomery_reduce((int32_t)a->vec[0].coeffs[c] * t); + /* 累加剩余 j=1..K-1 */ + for (int j = 1; j < PARAM_K; j++) { + t = montgomery_reduce((int32_t)MONT_R2 * b->vec[j].coeffs[c]); + r->coeffs[c] += montgomery_reduce((int32_t)a->vec[j].coeffs[c] * t); + } + r->coeffs[c] = barrett_reduce(r->coeffs[c]); + } +} + +#endif /* ALGORITHM */ + +/* ================================================================ + * polyvec_ntt / polyvec_invntt — 对向量中每个多项式做 NTT/INVNTT + * ================================================================ 
*/ +static __device__ __noinline__ void polyvec_ntt(kem_polyvec *pv) +{ + for (int i = 0; i < PARAM_K; i++) { + ntt(pv->vec[i].coeffs); +#if ALGORITHM == ALGO_KYBER + /* The Kyber NTT does not normalise internally (no Barrett reduce between levels). + * The reference poly_ntt() always calls poly_reduce() right after ntt(). + * Without this step NTT outputs can reach ±8Q, overflowing the Montgomery input range in fqmul. */ + for (int j = 0; j < PARAM_N; j++) + pv->vec[i].coeffs[j] = barrett_reduce(pv->vec[i].coeffs[j]); +#endif + } +} + +static __device__ __noinline__ void polyvec_invntt(kem_polyvec *pv) +{ + for (int i = 0; i < PARAM_K; i++) invntt(pv->vec[i].coeffs); +} + +static __device__ __noinline__ void poly_invntt(kem_poly *p) +{ + invntt(p->coeffs); +} + +#endif /* NTT_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/params.h b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/params.h new file mode 100644 index 000000000..f0c7a713e --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/params.h @@ -0,0 +1,183 @@ +/* + * params.h — unified parameter header + * + * Kyber (CRYSTALS-Kyber) and Aigis-enc (PQMagic KEM) are described by one shared set of macros. + * Both algorithms use int16_t coefficients and N=256, with the same structure. + * + * Key algorithmic differences: + * NTT depth: Kyber=7 levels+basemul (Q≡1 mod 256), Aigis=8 levels+pointwise (Q≡1 mod 512) + * PK packing: Kyber=12-bit tobytes, Aigis=compressed to BITS_PK bits + * CT sign: Kyber v = pk*r + e2 + msg; Aigis v = pk*r + e2 - msg + * Rejection sampling: Kyber=12-bit (4096), Aigis=13-bit (8192) + * + * Kyber parameter source: CRYSTALS-Kyber specification (NIST PQC Round 3) + * Aigis-enc parameter source: PQMagic CPU implementation (AIGIS_ENC_MODE=1/2/3/4) + */ + +#ifndef PARAMS_H +#define PARAMS_H + +#include "config.h" +#include <stdint.h> + +/* ================================================================ + * Common types and constants + * ================================================================ */ +typedef int16_t coeff_t; /* fits both algorithms (Q < 2^13 < 2^15) */ + +#define PARAM_N 256 +#define PARAM_SYMBYTES 32 +#define PARAM_SSBYTES 32 + +/* ================================================================ + * Kyber parameters 
+ * ================================================================ */ +#if ALGORITHM == ALGO_KYBER + +#define PARAM_Q 3329 +#define PARAM_QBITS 12 /* 拒绝采样位宽 */ +#define PARAM_QINV 62209 /* Q^{-1} mod 2^16 (used as int16 signed = -3327) */ + +/* Montgomery 常数: R=2^16 + * MONT_R2 = R^2 mod Q = 1353 (用于转换到 Mont 域) */ +#define MONT_R2 1353 + +#define PARAM_ETA2 2 /* 加密误差 eta */ + +#if PARAM_MODE == 2 /* Kyber512 */ + #define PARAM_K 2 + #define PARAM_ETA1 3 + #define PARAM_BITS_PK 12 /* pk = polyvec_tobytes12 (无压缩损失) */ + #define PARAM_BITS_C1 10 /* ct 向量压缩位数 */ + #define PARAM_BITS_C2 4 /* ct 标量多项式压缩位数 */ + #define CRYPTO_ALGNAME "Kyber512" + +#elif PARAM_MODE == 3 /* Kyber768 */ + #define PARAM_K 3 + #define PARAM_ETA1 2 + #define PARAM_BITS_PK 12 + #define PARAM_BITS_C1 10 + #define PARAM_BITS_C2 4 + #define CRYPTO_ALGNAME "Kyber768" + +#elif PARAM_MODE == 4 /* Kyber1024 */ + #define PARAM_K 4 + #define PARAM_ETA1 2 + #define PARAM_BITS_PK 12 + #define PARAM_BITS_C1 11 + #define PARAM_BITS_C2 5 + #define CRYPTO_ALGNAME "Kyber1024" + +#else + #error "PARAM_MODE must be 2, 3, or 4 for Kyber" +#endif + +/* Kyber 噪声 eta: s 和 e 用 ETA1,加密噪声用 ETA1/ETA2 */ +#define PARAM_ETA_S PARAM_ETA1 +#define PARAM_ETA_E_KG PARAM_ETA1 /* 密钥生成误差 */ +#define PARAM_ETA_E_ENC PARAM_ETA1 /* 加密误差 (e1) */ +#define PARAM_ETA_E2 PARAM_ETA2 /* 加密标量误差 (e2) */ + +/* Kyber 全精度多项式字节数 (12-bit * 256 = 384 bytes) */ +#define PARAM_POLYBYTES 384 + +/* Kyber PRF 输出长度 (bytes): ETA * N / 4 */ +#define PARAM_PRF_ETA1_BYTES (PARAM_ETA1 * PARAM_N / 4) +#define PARAM_PRF_ETA2_BYTES (PARAM_ETA2 * PARAM_N / 4) + +/* ================================================================ + * Aigis-enc 参数 + * ================================================================ */ +#elif ALGORITHM == ALGO_AIGIS_ENC + +#define PARAM_Q 7681 +#define PARAM_QBITS 13 /* 拒绝采样位宽 */ +#define PARAM_QINV 57857 /* Q^{-1} mod 2^16 */ + +#define MONT_R2 5569 /* R^2 mod Q */ + +#if PARAM_MODE == 1 /* Aigis-enc-1 (K=2) */ + #define PARAM_K 
2 + #define PARAM_ETA_S 4 + #define PARAM_ETA_E_KG 8 + #define PARAM_ETA_E_ENC 8 + #define PARAM_ETA_E2 8 + #define PARAM_BITS_PK 10 + #define PARAM_BITS_C1 10 + #define PARAM_BITS_C2 3 + #define CRYPTO_ALGNAME "Aigis-enc-1" + +#elif PARAM_MODE == 2 /* Aigis-enc-2 (K=3, low) */ + #define PARAM_K 3 + #define PARAM_ETA_S 1 + #define PARAM_ETA_E_KG 4 + #define PARAM_ETA_E_ENC 4 + #define PARAM_ETA_E2 4 + #define PARAM_BITS_PK 9 + #define PARAM_BITS_C1 9 + #define PARAM_BITS_C2 4 + #define CRYPTO_ALGNAME "Aigis-enc-2" + +#elif PARAM_MODE == 3 /* Aigis-enc-3 (K=3, med) */ + #define PARAM_K 3 + #define PARAM_ETA_S 2 + #define PARAM_ETA_E_KG 4 + #define PARAM_ETA_E_ENC 4 + #define PARAM_ETA_E2 4 + #define PARAM_BITS_PK 10 + #define PARAM_BITS_C1 10 + #define PARAM_BITS_C2 3 + #define CRYPTO_ALGNAME "Aigis-enc-3" + +#elif PARAM_MODE == 4 /* Aigis-enc-4 (K=4, high) */ + #define PARAM_K 4 + #define PARAM_ETA_S 3 + #define PARAM_ETA_E_KG 8 + #define PARAM_ETA_E_ENC 8 + #define PARAM_ETA_E2 8 + #define PARAM_BITS_PK 11 + #define PARAM_BITS_C1 11 + #define PARAM_BITS_C2 5 + #define CRYPTO_ALGNAME "Aigis-enc-4" + +#else + #error "PARAM_MODE must be 1, 2, 3, or 4 for Aigis-enc" +#endif + +/* Aigis 全精度多项式字节数 (13-bit * 256 = 416 bytes) */ +#define PARAM_POLYBYTES 416 + +/* Aigis PRF 输出长度 (用最大 eta 覆盖; 实际只需 eta*N/4 字节) */ +#define PARAM_PRF_ETA1_BYTES (PARAM_ETA_S * 64) +#define PARAM_PRF_ETA2_BYTES (PARAM_ETA_E_KG * 64) + +#endif /* ALGORITHM */ + +/* ================================================================ + * 派生常量 (两种算法共用公式) + * ================================================================ */ +#define PARAM_POLYVECBYTES (PARAM_K * PARAM_POLYBYTES) +#define PARAM_PK_POLYVEC_BYTES (PARAM_BITS_PK * PARAM_K * PARAM_N / 8) +#define PARAM_CT_VEC_BYTES (PARAM_BITS_C1 * PARAM_K * PARAM_N / 8) +#define PARAM_CT_POLY_BYTES (PARAM_BITS_C2 * PARAM_N / 8) + +#define PARAM_PUBLICKEYBYTES (PARAM_PK_POLYVEC_BYTES + PARAM_SYMBYTES) +#define PARAM_INDCPA_SECRETKEYBYTES PARAM_POLYVECBYTES 
+#define PARAM_SECRETKEYBYTES (PARAM_POLYVECBYTES + PARAM_PUBLICKEYBYTES + 2 * PARAM_SYMBYTES) +#define PARAM_CIPHERTEXTBYTES (PARAM_CT_VEC_BYTES + PARAM_CT_POLY_BYTES) + +/* XOF buffer size for matrix generation */ +#define PARAM_GEN_MATRIX_NBLOCKS 4 +#define PARAM_XOF_BLOCKBYTES 168 /* SHAKE128_RATE */ +#define PARAM_GEN_MATRIX_BUFLEN (PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES) + +/* Largest K (used for fixed-size structs) */ +#define MAX_K 4 + +/* ================================================================ + * Polynomial structs + * ================================================================ */ +typedef struct { int16_t coeffs[PARAM_N]; } kem_poly; +typedef struct { kem_poly vec[MAX_K]; } kem_polyvec; + +#endif /* PARAMS_H */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/poly.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/poly.cuh new file mode 100644 index 000000000..c9d161d8f --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/poly.cuh @@ -0,0 +1,252 @@ +/* + * poly.cuh — unified polynomial operations + * + * Supports Kyber (Q=3329, 12-bit) and Aigis-enc (Q=7681, 13-bit) + * Differences are handled by compile-time #if ALGORITHM branches + * + * Per-operation notes: + * frommsg/tomsg: same formula in both branches (bit=1 -> (Q+1)/2; decode ((2t + Q/2) / Q) & 1), only Q differs + * tobytes/frombytes: 12-bit (Kyber) vs 13-bit (Aigis) + * compress_c2/decompress_c2: 4-bit or 5-bit (Kyber) / 3-bit, 4-bit, 5-bit (Aigis) + */ + +#ifndef POLY_CUH +#define POLY_CUH + +#include <stdint.h> +#include "params.h" +#include "reduce.cuh" +#include "ntt.cuh" +#include "cbd.cuh" + +/* ================================================================ + * Basic polynomial arithmetic + * ================================================================ */ + +static __device__ void poly_add(kem_poly *r, const kem_poly *a, const kem_poly *b) +{ + for (int i = 0; i < PARAM_N; i++) r->coeffs[i] = a->coeffs[i] + b->coeffs[i]; +} + +static __device__ void poly_sub(kem_poly *r, const kem_poly *a, const kem_poly *b) +{ + for (int i = 0; i < PARAM_N; i++) r->coeffs[i] = a->coeffs[i] 
- b->coeffs[i]; +} + +static __device__ void poly_reduce(kem_poly *r) +{ + for (int i = 0; i < PARAM_N; i++) r->coeffs[i] = barrett_reduce(r->coeffs[i]); +} + +static __device__ void poly_caddq(kem_poly *r) +{ + for (int i = 0; i < PARAM_N; i++) r->coeffs[i] = caddq(r->coeffs[i]); +} + +static __device__ void poly_caddq2(kem_poly *r) +{ + for (int i = 0; i < PARAM_N; i++) r->coeffs[i] = caddq2(r->coeffs[i]); +} + +/* ================================================================ + * 消息编解码 + * frommsg: {0,1}^256 → poly in [0, Q) + * tomsg: poly → {0,1}^256 + * ================================================================ */ + +static __device__ void poly_frommsg(kem_poly *r, const uint8_t *msg) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + for (int j = 0; j < 8; j++) { + int16_t mask = -((msg[i] >> j) & 1); /* 0xFFFF if bit=1, else 0 */ +#if ALGORITHM == ALGO_KYBER + /* Kyber: bit=1 → (Q+1)/2 = 1665 */ + r->coeffs[8*i+j] = mask & (int16_t)((PARAM_Q + 1) / 2); +#elif ALGORITHM == ALGO_AIGIS_ENC + /* Aigis: bit=1 → (Q+1)/2 = 3841 */ + r->coeffs[8*i+j] = mask & (int16_t)((PARAM_Q + 1) / 2); +#endif + } + } +} + +static __device__ void poly_tomsg(uint8_t *msg, const kem_poly *r) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + msg[i] = 0; + for (int j = 0; j < 8; j++) { + /* 四舍五入到最近整数 mod 2: 若系数更接近 Q/2 则为 1 */ + int16_t t = r->coeffs[8*i+j]; + t = caddq(t); + /* 放大到 2 个区间: (t * 2 + Q/2) / Q & 1 */ + /* Kyber 公式: ((t << 1) + Q/2) / Q & 1 */ +#if ALGORITHM == ALGO_KYBER + /* 阈值测试: 若 t > Q/4 且 t < 3Q/4 则为 1 */ + t = (int16_t)(((t << 1) + PARAM_Q / 2) / PARAM_Q); +#elif ALGORITHM == ALGO_AIGIS_ENC + t = (int16_t)(((t << 1) + PARAM_Q / 2) / PARAM_Q); +#endif + msg[i] |= (uint8_t)((t & 1) << j); + } + } +} + +/* ================================================================ + * 序列化/反序列化 (全精度) + * Kyber: 12-bit per coeff → 384 bytes + * Aigis-enc:13-bit per coeff → 416 bytes + * ================================================================ */ + +#if ALGORITHM == 
ALGO_KYBER + +/* 12-bit → 384 bytes */ +static __device__ __noinline__ void poly_tobytes(uint8_t *r, const kem_poly *a) +{ + for (unsigned int i = 0; i < PARAM_N / 2; i++) { + int16_t t0 = caddq(a->coeffs[2*i]); + int16_t t1 = caddq(a->coeffs[2*i+1]); + r[3*i+0] = (uint8_t)(t0); + r[3*i+1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + r[3*i+2] = (uint8_t)(t1 >> 4); + } +} + +static __device__ __noinline__ void poly_frombytes(kem_poly *r, const uint8_t *a) +{ + for (unsigned int i = 0; i < PARAM_N / 2; i++) { + r->coeffs[2*i] = (int16_t)(((a[3*i+0]) | ((int16_t)a[3*i+1] << 8)) & 0xFFF); + r->coeffs[2*i+1] = (int16_t)(((a[3*i+1] >> 4) | ((int16_t)a[3*i+2] << 4)) & 0xFFF); + } +} + +#elif ALGORITHM == ALGO_AIGIS_ENC + +/* 13-bit → 416 bytes (8 coeffs per 13 bytes) */ +static __device__ __noinline__ void poly_tobytes(uint8_t *r, const kem_poly *a) +{ + for (unsigned int i = 0; i < PARAM_N / 8; i++) { + int16_t t[8]; + for (int j = 0; j < 8; j++) t[j] = caddq(a->coeffs[8*i+j]); + r[13*i+ 0] = (uint8_t)(t[0]); + r[13*i+ 1] = (uint8_t)((t[0] >> 8) | (t[1] << 5)); + r[13*i+ 2] = (uint8_t)((t[1] >> 3)); + r[13*i+ 3] = (uint8_t)((t[1] >> 11) | (t[2] << 2)); + r[13*i+ 4] = (uint8_t)((t[2] >> 6) | (t[3] << 7)); + r[13*i+ 5] = (uint8_t)((t[3] >> 1)); + r[13*i+ 6] = (uint8_t)((t[3] >> 9) | (t[4] << 4)); + r[13*i+ 7] = (uint8_t)((t[4] >> 4)); + r[13*i+ 8] = (uint8_t)((t[4] >> 12) | (t[5] << 1)); + r[13*i+ 9] = (uint8_t)((t[5] >> 7) | (t[6] << 6)); + r[13*i+10] = (uint8_t)((t[6] >> 2)); + r[13*i+11] = (uint8_t)((t[6] >> 10) | (t[7] << 3)); + r[13*i+12] = (uint8_t)((t[7] >> 5)); + } +} + +static __device__ __noinline__ void poly_frombytes(kem_poly *r, const uint8_t *a) +{ + for (unsigned int i = 0; i < PARAM_N / 8; i++) { + r->coeffs[8*i+0] = (int16_t)((a[13*i+ 0] | ((uint16_t)a[13*i+ 1] << 8)) & 0x1FFF); + r->coeffs[8*i+1] = (int16_t)(((a[13*i+ 1] >> 5) | ((uint16_t)a[13*i+ 2] << 3) | ((uint16_t)a[13*i+ 3] << 11)) & 0x1FFF); + r->coeffs[8*i+2] = (int16_t)(((a[13*i+ 3] >> 2) | 
((uint16_t)a[13*i+ 4] << 6)) & 0x1FFF); + r->coeffs[8*i+3] = (int16_t)(((a[13*i+ 4] >> 7) | ((uint16_t)a[13*i+ 5] << 1) | ((uint16_t)a[13*i+ 6] << 9)) & 0x1FFF); + r->coeffs[8*i+4] = (int16_t)(((a[13*i+ 6] >> 4) | ((uint16_t)a[13*i+ 7] << 4) | ((uint16_t)a[13*i+ 8] << 12)) & 0x1FFF); + r->coeffs[8*i+5] = (int16_t)(((a[13*i+ 8] >> 1) | ((uint16_t)a[13*i+ 9] << 7)) & 0x1FFF); + r->coeffs[8*i+6] = (int16_t)(((a[13*i+ 9] >> 6) | ((uint16_t)a[13*i+10] << 2) | ((uint16_t)a[13*i+11] << 10)) & 0x1FFF); + r->coeffs[8*i+7] = (int16_t)(((a[13*i+11] >> 3) | ((uint16_t)a[13*i+12] << 5)) & 0x1FFF); + } +} + +#endif /* ALGORITHM for tobytes/frombytes */ + +/* ================================================================ + * 密文标量多项式压缩/解压缩 (BITS_C2 bits per coeff) + * ================================================================ */ + +/* 压缩: coeff in [0,Q) → BITS_C2-bit integer */ +static __device__ __noinline__ void poly_compress_c2(uint8_t *r, const kem_poly *a) +{ + /* 先归一化到 [0, Q) */ + +#if PARAM_BITS_C2 == 3 + /* 3-bit: 8 coeffs → 3 bytes */ + for (int i = 0; i < PARAM_N / 8; i++) { + uint8_t c0 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+0]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c1 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+1]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c2 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+2]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c3 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+3]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c4 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+4]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c5 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+5]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c6 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+6]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + uint8_t c7 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+7]) << 3) + PARAM_Q/2) / PARAM_Q) & 0x07; + r[3*i+0] = (uint8_t)(c0 | (c1 << 3) | (c2 << 6)); + r[3*i+1] = (uint8_t)((c2 >> 2) | (c3 << 1) | (c4 << 4) | (c5 << 
7)); + r[3*i+2] = (uint8_t)((c5 >> 1) | (c6 << 2) | (c7 << 5)); + } +#elif PARAM_BITS_C2 == 4 + /* 4-bit: 2 coeffs per byte */ + for (int i = 0; i < PARAM_N / 2; i++) { + int16_t u = (int16_t)((((int32_t)caddq(a->coeffs[2*i]) << 4) + PARAM_Q/2) / PARAM_Q) & 0x0F; + int16_t v = (int16_t)((((int32_t)caddq(a->coeffs[2*i+1]) << 4) + PARAM_Q/2) / PARAM_Q) & 0x0F; + r[i] = (uint8_t)(u | (v << 4)); + } +#elif PARAM_BITS_C2 == 5 + /* 5-bit: 8 coeffs → 5 bytes */ + for (int i = 0; i < PARAM_N / 8; i++) { + uint8_t c0 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+0]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c1 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+1]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c2 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+2]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c3 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+3]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c4 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+4]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c5 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+5]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c6 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+6]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + uint8_t c7 = (uint8_t)((((int32_t)caddq(a->coeffs[8*i+7]) << 5) + PARAM_Q/2) / PARAM_Q) & 0x1F; + r[5*i+0] = (uint8_t)(c0 | (c1 << 5)); + r[5*i+1] = (uint8_t)((c1 >> 3) | (c2 << 2) | (c3 << 7)); + r[5*i+2] = (uint8_t)((c3 >> 1) | (c4 << 4)); + r[5*i+3] = (uint8_t)((c4 >> 4) | (c5 << 1) | (c6 << 6)); + r[5*i+4] = (uint8_t)((c6 >> 2) | (c7 << 3)); + } +#endif +} + +/* 解压缩: BITS_C2-bit integer → coeff in [0, Q) */ +static __device__ __noinline__ void poly_decompress_c2(kem_poly *r, const uint8_t *a) +{ +#if PARAM_BITS_C2 == 3 + for (int i = 0; i < PARAM_N / 8; i++) { + uint8_t c[8]; + c[0] = a[3*i+0] & 0x07; + c[1] = (a[3*i+0] >> 3) & 0x07; + c[2] = (a[3*i+0] >> 6) | ((a[3*i+1] & 0x01) << 2); + c[3] = (a[3*i+1] >> 1) & 0x07; + c[4] = (a[3*i+1] >> 4) & 0x07; + c[5] = (a[3*i+1] >> 7) | ((a[3*i+2] & 0x03) 
<< 1); + c[6] = (a[3*i+2] >> 2) & 0x07; + c[7] = (a[3*i+2] >> 5); + for (int j = 0; j < 8; j++) + r->coeffs[8*i+j] = (int16_t)(((int32_t)c[j] * PARAM_Q + 4) >> 3); + } +#elif PARAM_BITS_C2 == 4 + for (int i = 0; i < PARAM_N / 2; i++) { + r->coeffs[2*i] = (int16_t)(((int32_t)( a[i] & 0x0F) * PARAM_Q + 8) >> 4); + r->coeffs[2*i+1] = (int16_t)(((int32_t)((a[i] >> 4) & 0x0F) * PARAM_Q + 8) >> 4); + } +#elif PARAM_BITS_C2 == 5 + for (int i = 0; i < PARAM_N / 8; i++) { + uint8_t c[8]; + c[0] = a[5*i+0] & 0x1F; + c[1] = (a[5*i+0] >> 5) | ((a[5*i+1] & 0x03) << 3); + c[2] = (a[5*i+1] >> 2) & 0x1F; + c[3] = (a[5*i+1] >> 7) | ((a[5*i+2] & 0x0F) << 1); + c[4] = (a[5*i+2] >> 4) | ((a[5*i+3] & 0x01) << 4); + c[5] = (a[5*i+3] >> 1) & 0x1F; + c[6] = (a[5*i+3] >> 6) | ((a[5*i+4] & 0x07) << 2); + c[7] = a[5*i+4] >> 3; + for (int j = 0; j < 8; j++) + r->coeffs[8*i+j] = (int16_t)(((int32_t)c[j] * PARAM_Q + 16) >> 5); + } +#endif +} + +#endif /* POLY_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/polyvec.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/polyvec.cuh new file mode 100644 index 000000000..89e1c0b46 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/polyvec.cuh @@ -0,0 +1,259 @@ +/* + * polyvec.cuh — 统一多项式向量运算 + * + * 包含: + * - polyvec_tobytes/frombytes (全精度序列化) + * - polyvec_pk_compress/decompress (公钥向量压缩) + * - polyvec_ct_compress/decompress (密文向量 u 压缩, BITS_C1 bits) + * - polyvec_ntt/invntt + * - polyvec_basemul_acc (内积, 由 ntt.cuh 提供) + * - polyvec_add, polyvec_caddq + * + * 公钥压缩位宽 PARAM_BITS_PK: + * Kyber: 12 bits (不压缩, 直接用 tobytes12) + * Aigis: 9, 10, or 11 bits + * + * 密文向量压缩位宽 PARAM_BITS_C1: + * Kyber: 10 bits (K=2,3) or 11 bits (K=4) + * Aigis: 9, 10, or 11 bits + */ + +#ifndef POLYVEC_CUH +#define POLYVEC_CUH + +#include +#include "params.h" +#include "poly.cuh" +#include "ntt.cuh" + +/* 
================================================================ + * 向量级基础操作 + * ================================================================ */ + +static __device__ void polyvec_add(kem_polyvec *r, const kem_polyvec *a, const kem_polyvec *b) +{ + for (int i = 0; i < PARAM_K; i++) poly_add(&r->vec[i], &a->vec[i], &b->vec[i]); +} + +static __device__ void polyvec_reduce(kem_polyvec *r) +{ + for (int i = 0; i < PARAM_K; i++) poly_reduce(&r->vec[i]); +} + +static __device__ void polyvec_caddq(kem_polyvec *r) +{ + for (int i = 0; i < PARAM_K; i++) poly_caddq(&r->vec[i]); +} + +static __device__ void polyvec_caddq2(kem_polyvec *r) +{ + for (int i = 0; i < PARAM_K; i++) poly_caddq2(&r->vec[i]); +} + +/* ================================================================ + * 全精度序列化 (用于 sk 存储 NTT 域系数) + * ================================================================ */ + +static __device__ __noinline__ void polyvec_tobytes(uint8_t *r, const kem_polyvec *a) +{ + for (int i = 0; i < PARAM_K; i++) + poly_tobytes(r + i * PARAM_POLYBYTES, &a->vec[i]); +} + +static __device__ __noinline__ void polyvec_frombytes(kem_polyvec *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_K; i++) + poly_frombytes(&r->vec[i], a + i * PARAM_POLYBYTES); +} + +/* ================================================================ + * 通用有损压缩辅助函数 (9/10/11 bits) + * 使用 PARAM_Q — 同时适用于 Kyber 和 Aigis-enc CT 压缩 + * ================================================================ */ + +/* 9-bit 压缩: 8 coeffs → 9 bytes */ +static __device__ __noinline__ void polyvec_compress9(uint8_t *r, const kem_poly *a) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + uint16_t c[8]; + for (int j = 0; j < 8; j++) + c[j] = (uint16_t)((((int32_t)caddq(a->coeffs[8*i+j]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + r[9*i+0] = (uint8_t)(c[0]); + r[9*i+1] = (uint8_t)((c[0] >> 8) | (c[1] << 1)); + r[9*i+2] = (uint8_t)((c[1] >> 7) | (c[2] << 2)); + r[9*i+3] = (uint8_t)((c[2] >> 6) | (c[3] << 3)); + r[9*i+4] = (uint8_t)((c[3] >> 5) | 
(c[4] << 4)); + r[9*i+5] = (uint8_t)((c[4] >> 4) | (c[5] << 5)); + r[9*i+6] = (uint8_t)((c[5] >> 3) | (c[6] << 6)); + r[9*i+7] = (uint8_t)((c[6] >> 2) | (c[7] << 7)); + r[9*i+8] = (uint8_t)((c[7] >> 1)); + } +} + +static __device__ __noinline__ void polyvec_decompress9(kem_poly *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + uint16_t c[8]; + c[0] = ((uint16_t)a[9*i+0]) | ((uint16_t)(a[9*i+1] & 0x01) << 8); + c[1] = ((uint16_t)a[9*i+1] >> 1) | ((uint16_t)(a[9*i+2] & 0x03) << 7); + c[2] = ((uint16_t)a[9*i+2] >> 2) | ((uint16_t)(a[9*i+3] & 0x07) << 6); + c[3] = ((uint16_t)a[9*i+3] >> 3) | ((uint16_t)(a[9*i+4] & 0x0F) << 5); + c[4] = ((uint16_t)a[9*i+4] >> 4) | ((uint16_t)(a[9*i+5] & 0x1F) << 4); + c[5] = ((uint16_t)a[9*i+5] >> 5) | ((uint16_t)(a[9*i+6] & 0x3F) << 3); + c[6] = ((uint16_t)a[9*i+6] >> 6) | ((uint16_t)(a[9*i+7] & 0x7F) << 2); + c[7] = ((uint16_t)a[9*i+7] >> 7) | ((uint16_t)(a[9*i+8]) << 1); + for (int j = 0; j < 8; j++) + r->coeffs[8*i+j] = (int16_t)(((int32_t)c[j] * PARAM_Q + 256) >> 9); + } +} + +/* 10-bit 压缩: 4 coeffs → 5 bytes */ +static __device__ __noinline__ void polyvec_compress10(uint8_t *r, const kem_poly *a) +{ + for (int i = 0; i < PARAM_N / 4; i++) { + uint16_t c[4]; + for (int j = 0; j < 4; j++) + c[j] = (uint16_t)((((int32_t)caddq(a->coeffs[4*i+j]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + r[5*i+0] = (uint8_t)(c[0]); + r[5*i+1] = (uint8_t)((c[0] >> 8) | (c[1] << 2)); + r[5*i+2] = (uint8_t)((c[1] >> 6) | (c[2] << 4)); + r[5*i+3] = (uint8_t)((c[2] >> 4) | (c[3] << 6)); + r[5*i+4] = (uint8_t)((c[3] >> 2)); + } +} + +static __device__ __noinline__ void polyvec_decompress10(kem_poly *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_N / 4; i++) { + uint16_t c[4]; + c[0] = ((uint16_t)a[5*i+0]) | ((uint16_t)(a[5*i+1] & 0x03) << 8); + c[1] = ((uint16_t)a[5*i+1] >> 2) | ((uint16_t)(a[5*i+2] & 0x0F) << 6); + c[2] = ((uint16_t)a[5*i+2] >> 4) | ((uint16_t)(a[5*i+3] & 0x3F) << 4); + c[3] = ((uint16_t)a[5*i+3] >> 6) | 
((uint16_t)(a[5*i+4]) << 2); + for (int j = 0; j < 4; j++) + r->coeffs[4*i+j] = (int16_t)(((int32_t)c[j] * PARAM_Q + 512) >> 10); + } +} + +/* 11-bit 压缩: 8 coeffs → 11 bytes */ +static __device__ __noinline__ void polyvec_compress11(uint8_t *r, const kem_poly *a) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + uint16_t c[8]; + for (int j = 0; j < 8; j++) + c[j] = (uint16_t)((((int32_t)caddq(a->coeffs[8*i+j]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + r[11*i+ 0] = (uint8_t)(c[0]); + r[11*i+ 1] = (uint8_t)((c[0] >> 8) | (c[1] << 3)); + r[11*i+ 2] = (uint8_t)((c[1] >> 5) | (c[2] << 6)); + r[11*i+ 3] = (uint8_t)((c[2] >> 2)); + r[11*i+ 4] = (uint8_t)((c[2] >> 10) | (c[3] << 1)); + r[11*i+ 5] = (uint8_t)((c[3] >> 7) | (c[4] << 4)); + r[11*i+ 6] = (uint8_t)((c[4] >> 4) | (c[5] << 7)); + r[11*i+ 7] = (uint8_t)((c[5] >> 1)); + r[11*i+ 8] = (uint8_t)((c[5] >> 9) | (c[6] << 2)); + r[11*i+ 9] = (uint8_t)((c[6] >> 6) | (c[7] << 5)); + r[11*i+10] = (uint8_t)((c[7] >> 3)); + } +} + +static __device__ __noinline__ void polyvec_decompress11(kem_poly *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_N / 8; i++) { + uint16_t c[8]; + c[0] = ((uint16_t)a[11*i+ 0]) | ((uint16_t)(a[11*i+ 1] & 0x07) << 8); + c[1] = ((uint16_t)a[11*i+ 1] >> 3) | ((uint16_t)(a[11*i+ 2] & 0x3F) << 5); + c[2] = ((uint16_t)a[11*i+ 2] >> 6) | ((uint16_t)a[11*i+ 3] << 2) | ((uint16_t)(a[11*i+ 4] & 0x01) << 10); + c[3] = ((uint16_t)a[11*i+ 4] >> 1) | ((uint16_t)(a[11*i+ 5] & 0x0F) << 7); + c[4] = ((uint16_t)a[11*i+ 5] >> 4) | ((uint16_t)(a[11*i+ 6] & 0x7F) << 4); + c[5] = ((uint16_t)a[11*i+ 6] >> 7) | ((uint16_t)a[11*i+ 7] << 1) | ((uint16_t)(a[11*i+ 8] & 0x03) << 9); + c[6] = ((uint16_t)a[11*i+ 8] >> 2) | ((uint16_t)(a[11*i+ 9] & 0x1F) << 6); + c[7] = ((uint16_t)a[11*i+ 9] >> 5) | ((uint16_t)a[11*i+10] << 3); + for (int j = 0; j < 8; j++) + r->coeffs[8*i+j] = (int16_t)(((int32_t)c[j] * PARAM_Q + 1024) >> 11); + } +} + +/* ================================================================ + * 公钥向量压缩/解压缩 
(PARAM_BITS_PK bits per coeff) + * + * Kyber: BITS_PK=12 → tobytes (无压缩) + * Aigis: BITS_PK=9/10/11 → compress9/10/11 + * ================================================================ */ + +#if ALGORITHM == ALGO_KYBER +static __device__ __noinline__ void polyvec_pk_compress(uint8_t *r, const kem_polyvec *a) +{ + polyvec_tobytes(r, a); +} +static __device__ __noinline__ void polyvec_pk_decompress(kem_polyvec *r, const uint8_t *a) +{ + polyvec_frombytes(r, a); +} +#elif ALGORITHM == ALGO_AIGIS_ENC +/* 统一 PK 压缩/解压缩分发 */ +static __device__ __noinline__ void polyvec_pk_compress(uint8_t *r, const kem_polyvec *a) +{ + for (int i = 0; i < PARAM_K; i++) { + uint8_t *dst = r + i * PARAM_BITS_PK * PARAM_N / 8; +#if PARAM_BITS_PK == 9 + polyvec_compress9(dst, &a->vec[i]); +#elif PARAM_BITS_PK == 10 + polyvec_compress10(dst, &a->vec[i]); +#elif PARAM_BITS_PK == 11 + polyvec_compress11(dst, &a->vec[i]); +#endif + } +} + +static __device__ __noinline__ void polyvec_pk_decompress(kem_polyvec *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_K; i++) { + const uint8_t *src = a + i * PARAM_BITS_PK * PARAM_N / 8; +#if PARAM_BITS_PK == 9 + polyvec_decompress9(&r->vec[i], src); +#elif PARAM_BITS_PK == 10 + polyvec_decompress10(&r->vec[i], src); +#elif PARAM_BITS_PK == 11 + polyvec_decompress11(&r->vec[i], src); +#endif + } +} +#endif /* ALGORITHM for PK compress */ + +/* ================================================================ + * 密文向量 u 压缩/解压缩 (PARAM_BITS_C1 bits per coeff) + * + * 两种算法都有 10-bit 和 11-bit 变体 (Aigis 还有 9-bit) + * Kyber K=2,3: 10-bit; K=4: 11-bit + * Aigis: 按 PARAM_BITS_C1 选择 + * ================================================================ */ + +static __device__ __noinline__ void polyvec_ct_compress(uint8_t *r, const kem_polyvec *a) +{ + for (int i = 0; i < PARAM_K; i++) { + uint8_t *dst = r + i * PARAM_BITS_C1 * PARAM_N / 8; +#if PARAM_BITS_C1 == 9 + polyvec_compress9(dst, &a->vec[i]); +#elif PARAM_BITS_C1 == 10 + polyvec_compress10(dst, &a->vec[i]); 
+#elif PARAM_BITS_C1 == 11 + polyvec_compress11(dst, &a->vec[i]); +#endif + } +} + +static __device__ __noinline__ void polyvec_ct_decompress(kem_polyvec *r, const uint8_t *a) +{ + for (int i = 0; i < PARAM_K; i++) { + const uint8_t *src = a + i * PARAM_BITS_C1 * PARAM_N / 8; +#if PARAM_BITS_C1 == 9 + polyvec_decompress9(&r->vec[i], src); +#elif PARAM_BITS_C1 == 10 + polyvec_decompress10(&r->vec[i], src); +#elif PARAM_BITS_C1 == 11 + polyvec_decompress11(&r->vec[i], src); +#endif + } +} + +#endif /* POLYVEC_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/reduce.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/reduce.cuh new file mode 100644 index 000000000..65f6fb923 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/reduce.cuh @@ -0,0 +1,80 @@ +/* + * reduce.cuh — unified modular reduction + * + * Both algorithms use 16-bit Montgomery reduction (R = 2^16). + * The only difference is the values of Q and QINV, supplied by params.h. + * + * Montgomery multiplication: fqmul(a, b) = a*b*R^{-1} mod Q + * Steps: t = (a*b) * QINV (mod 2^16, signed) + * return (a*b - t*Q) / R + * + * Barrett reduction (used by both Kyber and Aigis-enc): + * input is a single int16_t coefficient, output in (-Q, Q) + */ + +#ifndef REDUCE_CUH +#define REDUCE_CUH + +#include <stdint.h> +#include "params.h" + +/* ================================================================ + * Montgomery reduction: input int32_t a, output int16_t ≡ a*R^{-1} mod Q + * Valid input range: |a| < Q * 2^15 + * ================================================================ */ +static __device__ __forceinline__ int16_t montgomery_reduce(int32_t a) +{ + int16_t t = (int16_t)((int16_t)a * (int16_t)PARAM_QINV); + return (int16_t)((a - (int32_t)t * PARAM_Q) >> 16); +} + +/* Montgomery multiplication */ +static __device__ __forceinline__ int16_t fqmul(int16_t a, int16_t b) +{ + return montgomery_reduce((int32_t)a * b); +} + +/* ================================================================ + * Barrett reduction: input int16_t a in (-Q*4, Q*4), output (-Q, Q) + * Uses the precomputed constant v ≈ 2^26 
/ Q + * ================================================================ */ +static __device__ __forceinline__ int16_t barrett_reduce(int16_t a) +{ +#if ALGORITHM == ALGO_KYBER + /* Kyber Q=3329, v=(2^26 + Q/2)/Q = 20159 */ + const int16_t v = (int16_t)(((1 << 26) + PARAM_Q / 2) / PARAM_Q); + int16_t t = (int16_t)(((int32_t)v * a + (1 << 25)) >> 26); + return a - t * (int16_t)PARAM_Q; +#elif ALGORITHM == ALGO_AIGIS_ENC + /* Aigis Q=7681, quotient approximated as (a + 2^12) >> 13, then multiplied by Q */ + int16_t u = (int16_t)((a + (1 << 12)) >> 13); + u *= (int16_t)PARAM_Q; + return a - u; +#endif +} + +/* ================================================================ + * caddq: normalise [-Q, Q) to [0, Q) + * ================================================================ */ +static __device__ __forceinline__ int16_t caddq(int16_t a) +{ + return a + ((a >> 15) & (int16_t)PARAM_Q); +} + +/* Double caddq: map [-2Q, Q) to [0, Q) */ +static __device__ __forceinline__ int16_t caddq2(int16_t a) +{ + int16_t r = a + ((a >> 15) & (int16_t)PARAM_Q); + return r + ((r >> 15) & (int16_t)PARAM_Q); +} + +/* ================================================================ + * tomont: convert an ordinary coefficient to the Montgomery domain + * result = a * R^2 * R^{-1} = a * R (mod Q) + * ================================================================ */ +static __device__ __forceinline__ int16_t tomont(int16_t a) +{ + return fqmul(a, (int16_t)MONT_R2); +} + +#endif /* REDUCE_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/rocm_compat.h b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/rocm_compat.h new file mode 100644 index 000000000..45c9fde38 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/rocm_compat.h @@ -0,0 +1,54 @@ +#ifndef ROCM_COMPAT_H +#define ROCM_COMPAT_H + +#if defined(USE_HIP) || defined(__HIPCC__) || defined(__HIP_PLATFORM_AMD__) + +#ifndef USE_HIP +#define USE_HIP 1 +#endif + +#define GPU_USE_HIP 1 + +#include <hip/hip_runtime.h> + 
+#define GPU_RUNTIME_NAME "HIP" + +#define cudaError_t hipError_t +#define cudaSuccess hipSuccess +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaDeviceProp hipDeviceProp_t +#define cudaMalloc hipMalloc +#define cudaFree hipFree +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaStream_t hipStream_t +#define cudaStreamCreate hipStreamCreate +#define cudaStreamDestroy hipStreamDestroy +#define cudaEvent_t hipEvent_t +#define cudaEventCreate hipEventCreate +#define cudaEventDestroy hipEventDestroy +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEventElapsedTime hipEventElapsedTime +#define cudaDeviceSetLimit hipDeviceSetLimit +#define cudaLimitStackSize hipLimitStackSize +#define cudaFuncSetCacheConfig hipFuncSetCacheConfig +#define cudaFuncCachePreferL1 hipFuncCachePreferL1 + +#else + +#define GPU_USE_HIP 0 + +#include <cuda_runtime.h> + +#define GPU_RUNTIME_NAME "CUDA" + +#endif + +#endif /* ROCM_COMPAT_H */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/run_kem_smoke_amd.sh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/run_kem_smoke_amd.sh new file mode 100644 index 000000000..0bc49fa36 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/kem_api/run_kem_smoke_amd.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "$0")" + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +mkdir -p amd_results/smoke + +targets=( + kyber512_amd + kyber768_amd + kyber1024_amd + aigisenc1_amd + aigisenc2_amd + aigisenc3_amd 
+ aigisenc4_amd +) + +batches=(1 8 32 128) + +for exe in "${targets[@]}"; do + if [[ ! -x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" + continue + fi + + for batch in "${batches[@]}"; do + log="amd_results/smoke/${exe}_b${batch}.log" + echo "[smoke] ${exe} batch=${batch}" + stdbuf -oL -eL "./${exe}" --batch "${batch}" --n-ops 1 \ + 2>&1 | tee "${log}" + + if grep -q "FAIL" "${log}"; then + echo "[smoke] FAIL detected in ${log}" >&2 + exit 1 + fi + done +done + +if [[ ! -f parse_kem_results.py ]]; then echo "[summary] parse_kem_results.py not found in $(pwd); skipping CSV summary" >&2; exit 0; fi +python3 parse_kem_results.py amd_results/smoke > amd_results/kem_smoke_summary.csv && echo "[summary] amd_results/kem_smoke_summary.csv" diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_keygen.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_keygen.cuh new file mode 100644 index 000000000..6b3f9c368 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_keygen.cuh @@ -0,0 +1,3095 @@ +#include "hip/hip_runtime.h" +/* + * batch_keygen.cuh — decomposed batch key-generation pipeline + * + * Core optimisations: + * 1. Sampling (SHAKE-heavy) uses 1 warp per instance (32 threads generate all polynomials in parallel) + * 2. NTT uses a shared-memory batch kernel (128 threads/poly) + * 3. Matrix-vector multiply uses a 2D grid (batch × K), one thread per coefficient + * 4. Element-wise operations use batch kernels with 256 threads/block + * 5. 
打包由 32 线程/block 独立执行 + * + * Pipeline: + * [1] 采样: seed → A, s1, s2 (warp 级并行, #if 算法分叉) + * [2] copy s1 → s1hat + * [3] NTT(s1hat) (shared-mem batch) + * [4] t = A · s1hat (2D grid matvec) + * [5] reduce + INVNTT(t) (batch kernels) + * [6] t += s2 (batch add) + * [7] 打包 pk, sk (#if 算法分叉) + */ + +#ifndef BATCH_KEYGEN_CUH +#define BATCH_KEYGEN_CUH + +#include +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "ntt.cuh" +#include "fips202.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "packing.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" +#include "sign.cuh" +#include "symmetric.cuh" + +#ifndef BATCH_KEYGEN_SAMPLE_SPLIT_FAST +#define BATCH_KEYGEN_SAMPLE_SPLIT_FAST 0 +#endif + +#ifndef BATCH_KEYGEN_MATRIX_A_FAST +#define BATCH_KEYGEN_MATRIX_A_FAST BATCH_KEYGEN_SAMPLE_SPLIT_FAST +#endif + +#ifndef BATCH_KEYGEN_SECRET_ETA_FAST +#define BATCH_KEYGEN_SECRET_ETA_FAST BATCH_KEYGEN_SAMPLE_SPLIT_FAST +#endif + +#ifndef BATCH_KEYGEN_MATRIX_A_COOP +#define BATCH_KEYGEN_MATRIX_A_COOP 0 +#endif + +#ifndef BATCH_KEYGEN_MATRIX_A_LANEOPT +#define BATCH_KEYGEN_MATRIX_A_LANEOPT 0 +#endif + +#ifndef BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP +#define BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP 0 +#endif + +#ifndef BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES +#define BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES 16 +#endif + +#ifndef BATCH_KEYGEN_SECRET_ETA_COOP +#define BATCH_KEYGEN_SECRET_ETA_COOP 0 +#endif + +#ifndef BATCH_KEYGEN_SECRET_ETA_COOP_LANES +#define BATCH_KEYGEN_SECRET_ETA_COOP_LANES 16 +#endif + +#ifndef BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT +#define BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT 0 +#endif + +/* ================================================================ + * 缓冲区结构体 — 两种算法共用 + * ================================================================ */ +struct BatchKeygenBuffers { + coeff_t *d_mat; /* batch * K * L * N */ + coeff_t *d_s1; /* batch * L * N */ + coeff_t *d_s1hat; /* batch * L * N (NTT domain) */ + coeff_t *d_s2; /* batch * K * N */ 
+ coeff_t *d_t; /* batch * K * N */ + coeff_t *d_t1; /* batch * K * N — power2round high bits */ + coeff_t *d_t0; /* batch * K * N — power2round low bits */ + coeff_t *d_t1_hat; /* batch * K * N — NTT(t1 << D), verify material */ + coeff_t *d_s2_ntt; /* batch * K * N — NTT(s2), sign material */ + coeff_t *d_t0_ntt; /* batch * K * N — NTT(t0), sign material */ + unsigned char *d_tr; /* batch * TRBYTES — H(pk) */ + unsigned char *d_pks; /* batch * CRYPTO_PUBLICKEYBYTES */ + unsigned char *d_sks; /* batch * CRYPTO_SECRETKEYBYTES */ + unsigned char *d_buf; /* per-instance 辅助数据 (rho, key 等) */ + int max_batch; +}; + +typedef struct { + float sample_ms; + float seed_expand_ms; + float matrix_a_sample_ms; + float secret_eta_sample_ms; + float sample_launch_gap_ms; + float matrix_a_coop_ms; + float secret_eta_coop_ms; + float copy_ms; + float ntt_ms; + float matvec_ms; + float post_ms; + float p2r_ms; + float pack_ms; + float pack_inner_ms; + float pack_fused_ms; + float pack_body_ms; + float pack_header_ms; + float pack_t1_ms; + float pack_eta_ms; + float pack_t0_ms; + float tr_hash_ms; + float shared_a_ms; + float material_ms; + int matrix_a_coop_lanes; + int secret_eta_coop_lanes; +} KeygenProfile; + +typedef struct { + float old_fused_ms; + float shared_a_ms; + float split_seed_ms; + float split_matrix_a_ms; + float split_eta_ms; + float split_total_ms; + float split_launch_gap_ms; + float split_matrix_a_coop_ms; + float split_eta_coop_ms; + int split_matrix_a_coop_lanes; + int split_eta_coop_lanes; +} KeygenSampleOnlyProfile; + +typedef enum { + KEYGEN_COMPARE_STAGE_NONE = 0, + KEYGEN_COMPARE_STAGE_BUF, + KEYGEN_COMPARE_STAGE_MAT, + KEYGEN_COMPARE_STAGE_S1, + KEYGEN_COMPARE_STAGE_S2, + KEYGEN_COMPARE_STAGE_S1HAT_COPY, + KEYGEN_COMPARE_STAGE_S1HAT_NTT, + KEYGEN_COMPARE_STAGE_T_MATVEC, + KEYGEN_COMPARE_STAGE_T, + KEYGEN_COMPARE_STAGE_T1, + KEYGEN_COMPARE_STAGE_T0, + KEYGEN_COMPARE_STAGE_PK, + KEYGEN_COMPARE_STAGE_SK, + KEYGEN_COMPARE_STAGE_TR, +} KeygenCompareStage; + 
+typedef struct { + KeygenCompareStage stage; + int instance; + size_t byte_offset; + size_t element_offset; + int64_t ref_value; + int64_t cand_value; +} KeygenCompareResult; + +static inline void keygen_profile_clear(KeygenProfile *p) { + if (p) memset(p, 0, sizeof(*p)); +} + +static inline void keygen_sample_only_profile_clear(KeygenSampleOnlyProfile *p) { + if (p) memset(p, 0, sizeof(*p)); +} + +static inline void keygen_compare_result_clear(KeygenCompareResult *r) { + if (r) memset(r, 0, sizeof(*r)); +} + +static inline const char *keygen_compare_stage_name(KeygenCompareStage stage) { + switch (stage) { + case KEYGEN_COMPARE_STAGE_BUF: return "d_buf"; + case KEYGEN_COMPARE_STAGE_MAT: return "d_mat"; + case KEYGEN_COMPARE_STAGE_S1: return "d_s1"; + case KEYGEN_COMPARE_STAGE_S2: return "d_s2"; + case KEYGEN_COMPARE_STAGE_S1HAT_COPY: return "d_s1hat-copy"; + case KEYGEN_COMPARE_STAGE_S1HAT_NTT: return "d_s1hat-ntt"; + case KEYGEN_COMPARE_STAGE_T_MATVEC: return "d_t-matvec"; + case KEYGEN_COMPARE_STAGE_T: return "d_t"; + case KEYGEN_COMPARE_STAGE_T1: return "d_t1"; + case KEYGEN_COMPARE_STAGE_T0: return "d_t0"; + case KEYGEN_COMPARE_STAGE_PK: return "pk"; + case KEYGEN_COMPARE_STAGE_SK: return "sk"; + case KEYGEN_COMPARE_STAGE_TR: return "tr"; + default: return "none"; + } +} + +static inline void keygen_profile_add(float *dst, hipEvent_t a, hipEvent_t b) { + float ms = 0.0f; + hipEventElapsedTime(&ms, a, b); + *dst += ms; +} + +static inline void keygen_profile_finalize_sample( + KeygenProfile *p, + float component_ms) +{ + if (!p) return; + const float gap = p->sample_ms - component_ms; + p->sample_launch_gap_ms = gap > 0.0f ? 
gap : 0.0f; +} + +/* ================================================================ + * 算子级并行采样 kernel — 仿照「合并的第五版」warp-cooperative 思路 + * + * 设计: 1 warp (32 线程) 处理 1 个 instance + * lane 0: 派生 seed (SHAKE256 展开) + * 所有 32 lanes: 并行生成矩阵 A + s1 + s2 的所有多项式 + * 总共 PARAM_K*PARAM_L + PARAM_L + PARAM_K 个多项式各自独立 SHAKE 流 + * 每个 lane 处理 p = lane, lane+32, lane+64, ... 的多项式 + * 多项式各自独立 → 零同步, 32× 并行化 SHAKE 调用 + * + * 性能: 采样阶段从 O(K*L+L+K) 串行降到 O(ceil((K*L+L+K)/32)) + * ================================================================ */ +#define WP_KG_WARP_SIZE 32 +#define WP_KG_WARPS_BLOCK 4 +#define WP_KG_TPB (WP_KG_WARP_SIZE * WP_KG_WARPS_BLOCK) +/* 共享种子缓冲区大小 (per warp): rho + rhoprime/eta_seed + key */ +#define WP_KG_SEED_BYTES (2 * SEEDBYTES + CRHBYTES) +#define WP_KG_MAX_SUBWARPS_PER_BLOCK (WP_KG_TPB / 8) +#define WP_KG_MATRIX_COOP_BUF_BYTES \ + (POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES + STREAM128_BLOCKBYTES + 2) +#define WP_KG_ETA_COOP_BUF_BYTES \ + (POLY_UNIFORM_ETA2_NBLOCKS * STREAM256_BLOCKBYTES + STREAM256_BLOCKBYTES) + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_warp_sample_kernel( + coeff_t * __restrict__ d_mat, + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s2, + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_base_seed, + int batch_count) +{ + __shared__ unsigned char sh_seeds[WP_KG_WARPS_BLOCK][WP_KG_SEED_BYTES]; + + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_l = threadIdx.x / WP_KG_WARP_SIZE; + + if (warp_g >= batch_count) return; + + unsigned char *my_seeds = sh_seeds[warp_l]; + + /* lane 0: 派生 per-instance seed 并 SHAKE256 展开到 shared memory */ + if (lane == 0) { + uint8_t seed_in[SEEDBYTES]; + for (int i = 0; i < SEEDBYTES; i++) seed_in[i] = d_base_seed[i]; + seed_in[SEEDBYTES - 4] ^= (uint8_t)(warp_g); + seed_in[SEEDBYTES - 3] ^= (uint8_t)(warp_g >> 8); + seed_in[SEEDBYTES - 2] ^= 
(uint8_t)(warp_g >> 16); + seed_in[SEEDBYTES - 1] ^= (uint8_t)(warp_g >> 24); + +#if ALGORITHM == ALGO_MLDSA + /* ML-DSA: H(seed || K || L) → rho(32) | rhoprime(64) | key(32) */ + uint8_t buf[2 * SEEDBYTES + CRHBYTES]; + for (int i = 0; i < SEEDBYTES; i++) buf[i] = seed_in[i]; + buf[SEEDBYTES] = PARAM_K; + buf[SEEDBYTES + 1] = PARAM_L; + shake256(buf, 2 * SEEDBYTES + CRHBYTES, buf, SEEDBYTES + 2); + for (int i = 0; i < 2 * SEEDBYTES + CRHBYTES; i++) my_seeds[i] = buf[i]; +#elif ALGORITHM == ALGO_AIGIS + /* Aigis: H(seed) → eta_seed(32) | rho(32) | key(32) */ + uint8_t buf[3 * SEEDBYTES]; + shake256(buf, 3 * SEEDBYTES, seed_in, SEEDBYTES); + for (int i = 0; i < 3 * SEEDBYTES; i++) my_seeds[i] = buf[i]; +#endif + + /* 存 rho, key 到 d_buf (供后续 pack kernel 使用) */ + unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); +#if ALGORITHM == ALGO_MLDSA + const uint8_t *rho = my_seeds; + const uint8_t *key = my_seeds + SEEDBYTES + CRHBYTES; + const uint8_t *rhp = my_seeds + SEEDBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = rho[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = key[i]; + for (int i = 0; i < CRHBYTES; i++) my_buf[2 * SEEDBYTES + i] = rhp[i]; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_seeds; + const uint8_t *rho = my_seeds + SEEDBYTES; + const uint8_t *key = my_seeds + 2 * SEEDBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = rho[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = key[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[2 * SEEDBYTES + i] = eta_seed[i]; +#endif + } + __syncwarp(); + + /* 所有 32 lanes 并行生成多项式 */ + /* 多项式索引分配: + * p = 0 .. K*L-1 → A[p/L][p%L] + * p = K*L .. K*L+L-1 → s1[p - K*L] + * p = K*L+L .. 
end → s2[p - K*L - L] + */ + const int TOTAL_POLYS = PARAM_K * PARAM_L + PARAM_L + PARAM_K; + +#if ALGORITHM == ALGO_MLDSA + const uint8_t *rho = my_seeds; + const uint8_t *rhoprime = my_seeds + SEEDBYTES; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_seeds; + const uint8_t *rho = my_seeds + SEEDBYTES; +#endif + + for (int p = lane; p < TOTAL_POLYS; p += WP_KG_WARP_SIZE) { + coeff_t *dst; + + if (p < PARAM_K * PARAM_L) { + /* 矩阵 A[row][col] */ + int row = p / PARAM_L; + int col = p % PARAM_L; + dst = d_mat + (size_t)warp_g * PARAM_K * PARAM_L * PARAM_N + p * PARAM_N; + poly_uniform_to(dst, rho, MATRIX_NONCE(row, col)); + } else if (p < PARAM_K * PARAM_L + PARAM_L) { + /* 秘密向量 s1[j] */ + int j = p - PARAM_K * PARAM_L; +#if ALGORITHM == ALGO_MLDSA + dst = d_s1 + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + poly_uniform_eta_s1_to(dst, rhoprime, j); +#elif ALGORITHM == ALGO_AIGIS + dst = d_s1 + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + poly_uniform_eta_s1_to(dst, eta_seed, (uint16_t)j); +#endif + } else { + /* 秘密向量 s2[k] */ + int k = p - PARAM_K * PARAM_L - PARAM_L; +#if ALGORITHM == ALGO_MLDSA + dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; + poly_uniform_eta_s2_to(dst, rhoprime, PARAM_L + k); +#elif ALGORITHM == ALGO_AIGIS + dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; + poly_uniform_eta_s2_to(dst, eta_seed, (uint16_t)(PARAM_L + k)); +#endif + } + } +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_seed_expand_kernel( + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_base_seed, + int batch_count) +{ + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (warp_g >= batch_count || lane != 0) return; + + unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + uint8_t seed_in[SEEDBYTES]; + for (int i = 0; i < SEEDBYTES; 
i++) seed_in[i] = d_base_seed[i]; + seed_in[SEEDBYTES - 4] ^= (uint8_t)(warp_g); + seed_in[SEEDBYTES - 3] ^= (uint8_t)(warp_g >> 8); + seed_in[SEEDBYTES - 2] ^= (uint8_t)(warp_g >> 16); + seed_in[SEEDBYTES - 1] ^= (uint8_t)(warp_g >> 24); + +#if ALGORITHM == ALGO_MLDSA + uint8_t buf[2 * SEEDBYTES + CRHBYTES]; + for (int i = 0; i < SEEDBYTES; i++) buf[i] = seed_in[i]; + buf[SEEDBYTES] = PARAM_K; + buf[SEEDBYTES + 1] = PARAM_L; + shake256(buf, 2 * SEEDBYTES + CRHBYTES, buf, SEEDBYTES + 2); + + const uint8_t *rho = buf; + const uint8_t *rhoprime = buf + SEEDBYTES; + const uint8_t *key = buf + SEEDBYTES + CRHBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = rho[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = key[i]; + for (int i = 0; i < CRHBYTES; i++) my_buf[2 * SEEDBYTES + i] = rhoprime[i]; +#elif ALGORITHM == ALGO_AIGIS + uint8_t buf[3 * SEEDBYTES]; + shake256(buf, 3 * SEEDBYTES, seed_in, SEEDBYTES); + + const uint8_t *eta_seed = buf; + const uint8_t *rho = buf + SEEDBYTES; + const uint8_t *key = buf + 2 * SEEDBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = rho[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = key[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[2 * SEEDBYTES + i] = eta_seed[i]; +#endif +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_matrix_a_sample_kernel( + coeff_t * __restrict__ d_mat, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (warp_g >= batch_count) return; + + const unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + const uint8_t *rho = my_buf; + const int total = PARAM_K * PARAM_L; + + for (int p = lane; p < total; p += WP_KG_WARP_SIZE) { + int row = p / PARAM_L; + int col = p % PARAM_L; + coeff_t *dst = d_mat + (size_t)warp_g * PARAM_K * PARAM_L * PARAM_N + (size_t)p * PARAM_N; + 
poly_uniform_to(dst, rho, MATRIX_NONCE(row, col)); + } +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_matrix_a_laneopt_kernel( + coeff_t * __restrict__ d_mat, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + __shared__ uint8_t sh_rho[WP_KG_WARPS_BLOCK][SEEDBYTES]; + + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_l = threadIdx.x / WP_KG_WARP_SIZE; + + if (warp_g >= batch_count) return; + + const unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + uint8_t *rho_local = sh_rho[warp_l]; + if (lane < SEEDBYTES) + rho_local[lane] = my_buf[lane]; + __syncwarp(); + + const size_t inst_mat_off = (size_t)warp_g * PARAM_K * PARAM_L * PARAM_N; + const int total = PARAM_K * PARAM_L; + for (int p = lane; p < total; p += WP_KG_WARP_SIZE) { + const int row = p / PARAM_L; + const int col = p % PARAM_L; + coeff_t *dst = d_mat + inst_mat_off + (size_t)p * PARAM_N; + poly_uniform_to(dst, rho_local, MATRIX_NONCE(row, col)); + } +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_secret_sample_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (warp_g >= batch_count) return; + + const unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + const int total = PARAM_L + PARAM_K; + +#if ALGORITHM == ALGO_MLDSA + const uint8_t *rhoprime = my_buf + 2 * SEEDBYTES; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; +#endif + + for (int p = lane; p < total; p += WP_KG_WARP_SIZE) { + if (p < PARAM_L) { + int j = p; + coeff_t *dst = d_s1 + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + 
poly_uniform_eta_s1_to(dst, rhoprime, j); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s1_to(dst, eta_seed, (uint16_t)j); +#endif + } else { + int k = p - PARAM_L; + coeff_t *dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + poly_uniform_eta_s2_to(dst, rhoprime, PARAM_L + k); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s2_to(dst, eta_seed, (uint16_t)(PARAM_L + k)); +#endif + } + } +} + +template +__device__ __forceinline__ unsigned long long wp_kg_subwarp_mask(int lane_in_warp) +{ + const int base = lane_in_warp - (lane_in_warp & (SUBWARP_LANES - 1)); + return (0xFFFFFFFFull >> (32 - SUBWARP_LANES)) << base; +} + +template +__device__ __forceinline__ int wp_kg_subwarp_exclusive_scan(int value, + unsigned long long mask, + int sublane) +{ + int scan = value; +#pragma unroll + for (int offset = 1; offset < SUBWARP_LANES; offset <<= 1) { + int other = __shfl_up_sync(mask, scan, offset); + if (sublane >= offset) + scan += other; + } + return scan - value; +} + +template +__device__ __forceinline__ int wp_kg_subwarp_sum(int value, + unsigned long long mask, + int leader_lane) +{ + int sum = value; +#pragma unroll + for (int offset = SUBWARP_LANES >> 1; offset > 0; offset >>= 1) + sum += __shfl_down_sync(mask, sum, offset); + return __shfl_sync(mask, sum, leader_lane); +} + +__device__ __forceinline__ void wp_kg_store_coeff(coeff_t *dst, + coeff_t *dst_copy, + int idx, + coeff_t value) +{ + dst[idx] = value; + if (dst_copy) + dst_copy[idx] = value; +} + +template +__device__ void wp_kg_uniform_coop_sample_to( + coeff_t *dst, + const uint8_t *seed, + uint16_t nonce, + uint8_t *buf, + int *ctr_ptr, + unsigned int *buflen_ptr, + unsigned long long mask, + int sublane, + int leader_lane) +{ + stream128_state state; + + if (sublane == 0) { +#if ALGORITHM == ALGO_MLDSA + stream128_init(&state, seed, nonce); +#elif ALGORITHM == ALGO_AIGIS + aigis_shake128_stream_init(&state, seed, (uint8_t)nonce); +#endif + 
stream128_squeezeblocks(buf, POLY_UNIFORM_NBLOCKS, &state); + *ctr_ptr = 0; + *buflen_ptr = POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES; + } + __syncwarp(mask); + + while (1) { + int cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const unsigned int buflen = *buflen_ptr; + const int total_candidates = (int)(buflen / 3u); + + for (int base = 0; base < total_candidates; base += SUBWARP_LANES) { + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int cand = base + sublane; + int accept = 0; + coeff_t value = 0; + + if (cand < total_candidates) { + const size_t pos = (size_t)cand * 3u; + uint32_t t = buf[pos] + | ((uint32_t)buf[pos + 1] << 8) + | ((uint32_t)buf[pos + 2] << 16); + t &= (1u << PARAM_QBITS) - 1u; + if (t < (uint32_t)PARAM_Q) { + accept = 1; + value = (coeff_t)t; + } + } + + const int prefix = wp_kg_subwarp_exclusive_scan(accept, mask, sublane); + const int accepted = wp_kg_subwarp_sum(accept, mask, leader_lane); + + if (accept) { + const int out_idx = cur_ctr + prefix; + if (out_idx < PARAM_N) + dst[out_idx] = value; + } + + if (sublane == 0) { + const int next_ctr = cur_ctr + accepted; + *ctr_ptr = next_ctr < PARAM_N ? 
next_ctr : PARAM_N; + } + __syncwarp(mask); + } + + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + if (sublane == 0) { + const unsigned int buflen_local = *buflen_ptr; + const unsigned int off = buflen_local % 3u; + for (unsigned int i = 0; i < off; ++i) + buf[i] = buf[buflen_local - off + i]; + stream128_squeezeblocks(buf + off, 1, &state); + *buflen_ptr = STREAM128_BLOCKBYTES + off; + } + __syncwarp(mask); + } +} + +template +__device__ void wp_kg_eta_mldsa_coop_sample_to( + coeff_t *dst, + coeff_t *dst_copy, + const uint8_t *seed, + uint16_t nonce, + int eta, + int init_blocks, + uint8_t *buf, + int *ctr_ptr, + unsigned int *buflen_ptr, + unsigned long long mask, + int sublane, + int leader_lane) +{ + stream256_state state; + + if (sublane == 0) { + stream256_init(&state, seed, nonce); + stream256_squeezeblocks(buf, init_blocks, &state); + *ctr_ptr = 0; + *buflen_ptr = (unsigned int)init_blocks * STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + + while (1) { + int cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int total_bytes = (int)(*buflen_ptr); + for (int base = 0; base < total_bytes; base += SUBWARP_LANES) { + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int byte_pos = base + sublane; + int have0 = 0, have1 = 0; + coeff_t value0 = 0, value1 = 0; + int count = 0; + + if (byte_pos < total_bytes) { + uint32_t t0 = buf[byte_pos] & 0x0F; + uint32_t t1 = buf[byte_pos] >> 4; + if (eta == 2) { + if (t0 < 15) { + t0 = t0 - ((205 * t0) >> 10) * 5; + value0 = 2 - (int32_t)t0; + have0 = 1; + count++; + } + if (t1 < 15) { + t1 = t1 - ((205 * t1) >> 10) * 5; + value1 = 2 - (int32_t)t1; + have1 = 1; + count++; + } + } else { + if (t0 < 9) { + value0 = 4 - (int32_t)t0; + have0 = 1; + count++; + } + if (t1 < 9) { + value1 = 4 - (int32_t)t1; + have1 = 1; + count++; + } + } + } + + const int prefix = wp_kg_subwarp_exclusive_scan(count, mask, sublane); + const int accepted = wp_kg_subwarp_sum(count, mask, leader_lane); + + if 
(have0) { + const int out_idx = cur_ctr + prefix; + if (out_idx < PARAM_N) + wp_kg_store_coeff(dst, dst_copy, out_idx, value0); + } + if (have1) { + const int out_idx = cur_ctr + prefix + have0; + if (out_idx < PARAM_N) + wp_kg_store_coeff(dst, dst_copy, out_idx, value1); + } + + if (sublane == 0) { + const int next_ctr = cur_ctr + accepted; + *ctr_ptr = next_ctr < PARAM_N ? next_ctr : PARAM_N; + } + __syncwarp(mask); + } + + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + if (sublane == 0) { + stream256_squeezeblocks(buf, 1, &state); + *buflen_ptr = STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + } +} + +template +__device__ void wp_kg_eta1_aigis_coop_sample_to( + coeff_t *dst, + coeff_t *dst_copy, + const uint8_t *seed, + uint16_t nonce, + uint8_t *buf, + int *ctr_ptr, + unsigned int *buflen_ptr, + unsigned long long mask, + int sublane, + int leader_lane) +{ + stream256_state state; + + if (sublane == 0) { + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); + stream256_squeezeblocks(buf, POLY_UNIFORM_ETA1_NBLOCKS, &state); + *ctr_ptr = 0; + *buflen_ptr = POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + + while (1) { + int cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + +#if PARAM_ETA_S1 == 1 + const int total_units = (int)(*buflen_ptr); +#else + const int total_units = (int)(*buflen_ptr / 3u); +#endif + + for (int base = 0; base < total_units; base += SUBWARP_LANES) { + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int unit = base + sublane; + coeff_t values[8]; + int count = 0; + +#if PARAM_ETA_S1 == 1 + if (unit < total_units) { + const uint32_t byte = buf[unit]; + const uint32_t t0 = byte & 0x03; + const uint32_t t1 = (byte >> 2) & 0x03; + const uint32_t t2 = (byte >> 4) & 0x03; + const uint32_t t3 = byte >> 6; + if (t0 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t1; + if 
(t2 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t2; + if (t3 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t3; + } +#else + if (unit < total_units) { + const int pos = unit * 3; + const uint32_t t0 = buf[pos] & 0x07; + const uint32_t t1 = (buf[pos] >> 3) & 0x07; + const uint32_t t2 = (buf[pos] >> 6) | ((uint32_t)(buf[pos + 1] & 0x01) << 2); + const uint32_t t3 = (buf[pos + 1] >> 1) & 0x07; + const uint32_t t4 = (buf[pos + 1] >> 4) & 0x07; + const uint32_t t5 = (buf[pos + 1] >> 7) | ((uint32_t)(buf[pos + 2] & 0x03) << 1); + const uint32_t t6 = (buf[pos + 2] >> 2) & 0x07; + const uint32_t t7 = buf[pos + 2] >> 5; + if (t0 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t1; + if (t2 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t2; + if (t3 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t3; + if (t4 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t4; + if (t5 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t5; + if (t6 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t6; + if (t7 <= 2u * PARAM_ETA_S1) values[count++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t7; + } +#endif + + const int prefix = wp_kg_subwarp_exclusive_scan(count, mask, sublane); + const int accepted = wp_kg_subwarp_sum(count, mask, leader_lane); + + for (int i = 0; i < count; ++i) { + const int out_idx = cur_ctr + prefix + i; + if (out_idx < PARAM_N) + wp_kg_store_coeff(dst, dst_copy, out_idx, values[i]); + } + + if (sublane == 0) { + const int next_ctr = cur_ctr + accepted; + *ctr_ptr = next_ctr < PARAM_N ? 
next_ctr : PARAM_N; + } + __syncwarp(mask); + } + + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + if (sublane == 0) { + stream256_squeezeblocks(buf, 1, &state); + *buflen_ptr = STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + } +} + +template +__device__ void wp_kg_eta2_aigis_coop_sample_to( + coeff_t *dst, + const uint8_t *seed, + uint16_t nonce, + uint8_t *buf, + int *ctr_ptr, + unsigned int *buflen_ptr, + unsigned long long mask, + int sublane, + int leader_lane) +{ +#if PARAM_ETA_S2 == 5 + if (sublane == 0) + poly_uniform_eta_s2_to(dst, seed, nonce); + __syncwarp(mask); +#else + stream256_state state; + + if (sublane == 0) { + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); + stream256_squeezeblocks(buf, 2, &state); + *ctr_ptr = 0; + *buflen_ptr = 2 * STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + + while (1) { + int cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int total_bytes = (int)(*buflen_ptr); + for (int base = 0; base < total_bytes; base += SUBWARP_LANES) { + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + const int byte_pos = base + sublane; + int have0 = 0, have1 = 0; + coeff_t value0 = 0, value1 = 0; + int count = 0; + + if (byte_pos < total_bytes) { + uint32_t t0 = buf[byte_pos] & 0x07; + uint32_t t1 = buf[byte_pos] >> 5; + if (t0 <= 2u * PARAM_ETA_S2) { + value0 = PARAM_Q + PARAM_ETA_S2 - (int32_t)t0; + have0 = 1; + count++; + } + if (t1 <= 2u * PARAM_ETA_S2) { + value1 = PARAM_Q + PARAM_ETA_S2 - (int32_t)t1; + have1 = 1; + count++; + } + } + + const int prefix = wp_kg_subwarp_exclusive_scan(count, mask, sublane); + const int accepted = wp_kg_subwarp_sum(count, mask, leader_lane); + + if (have0) { + const int out_idx = cur_ctr + prefix; + if (out_idx < PARAM_N) + dst[out_idx] = value0; + } + if (have1) { + const int out_idx = cur_ctr + prefix + have0; + if (out_idx < PARAM_N) + dst[out_idx] = value1; + } + + if (sublane == 0) { + const int next_ctr = cur_ctr + accepted; + *ctr_ptr = next_ctr < 
PARAM_N ? next_ctr : PARAM_N; + } + __syncwarp(mask); + } + + cur_ctr = *ctr_ptr; + if (cur_ctr >= PARAM_N) + break; + + if (sublane == 0) { + stream256_squeezeblocks(buf, 1, &state); + *buflen_ptr = STREAM256_BLOCKBYTES; + } + __syncwarp(mask); + } +#endif +} + +template +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_matrix_a_coop_kernel( + coeff_t * __restrict__ d_mat, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + __shared__ uint8_t sh_buf[WP_KG_MAX_SUBWARPS_PER_BLOCK][WP_KG_MATRIX_COOP_BUF_BYTES]; + __shared__ int sh_ctr[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + __shared__ unsigned int sh_buflen[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + + const int lane_in_warp = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_local = threadIdx.x / WP_KG_WARP_SIZE; + const int subwarp_base = lane_in_warp - (lane_in_warp & (SUBWARP_LANES - 1)); + const int sublane = lane_in_warp - subwarp_base; + const int subwarps_per_warp = WP_KG_WARP_SIZE / SUBWARP_LANES; + const int group_local = warp_local * subwarps_per_warp + (lane_in_warp / SUBWARP_LANES); + const int polys_per_block = blockDim.x / SUBWARP_LANES; + const int poly_global = blockIdx.x * polys_per_block + group_local; + const int total_polys = batch_count * PARAM_K * PARAM_L; + const unsigned long long mask = wp_kg_subwarp_mask(lane_in_warp); + + if (poly_global >= total_polys) return; + + const int inst = poly_global / (PARAM_K * PARAM_L); + const int poly_local = poly_global % (PARAM_K * PARAM_L); + const int row = poly_local / PARAM_L; + const int col = poly_local % PARAM_L; + + const unsigned char *my_buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + const uint8_t *rho = my_buf; + coeff_t *dst = d_mat + (size_t)inst * PARAM_K * PARAM_L * PARAM_N + (size_t)poly_local * PARAM_N; + + wp_kg_uniform_coop_sample_to( + dst, rho, MATRIX_NONCE(row, col), + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +} + +template +__global__ void 
__launch_bounds__(WP_KG_TPB) +batch_keygen_secret_eta_coop_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s1hat, + coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + __shared__ uint8_t sh_buf[WP_KG_MAX_SUBWARPS_PER_BLOCK][WP_KG_ETA_COOP_BUF_BYTES]; + __shared__ int sh_ctr[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + __shared__ unsigned int sh_buflen[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + + const int lane_in_warp = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_local = threadIdx.x / WP_KG_WARP_SIZE; + const int subwarp_base = lane_in_warp - (lane_in_warp & (SUBWARP_LANES - 1)); + const int sublane = lane_in_warp - subwarp_base; + const int subwarps_per_warp = WP_KG_WARP_SIZE / SUBWARP_LANES; + const int group_local = warp_local * subwarps_per_warp + (lane_in_warp / SUBWARP_LANES); + const int polys_per_block = blockDim.x / SUBWARP_LANES; + const int poly_global = blockIdx.x * polys_per_block + group_local; + const int total_polys = batch_count * (PARAM_L + PARAM_K); + const unsigned long long mask = wp_kg_subwarp_mask(lane_in_warp); + + if (poly_global >= total_polys) return; + + const int inst = poly_global / (PARAM_L + PARAM_K); + const int poly_local = poly_global % (PARAM_L + PARAM_K); + const unsigned char *my_buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + +#if ALGORITHM == ALGO_MLDSA + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; +#endif + + if (poly_local < PARAM_L) { + const int j = poly_local; + coeff_t *dst = d_s1 + (size_t)inst * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + coeff_t *dst_copy = d_s1hat ? 
(d_s1hat + (size_t)inst * PARAM_L * PARAM_N + (size_t)j * PARAM_N) : NULL; +#if ALGORITHM == ALGO_MLDSA + wp_kg_eta_mldsa_coop_sample_to( + dst, dst_copy, eta_seed, (uint16_t)j, + PARAM_ETA_S1, POLY_UNIFORM_ETA1_NBLOCKS, + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +#elif ALGORITHM == ALGO_AIGIS + wp_kg_eta1_aigis_coop_sample_to( + dst, dst_copy, eta_seed, (uint16_t)j, + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +#endif + } else { + const int k = poly_local - PARAM_L; + coeff_t *dst = d_s2 + (size_t)inst * PARAM_K * PARAM_N + (size_t)k * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + wp_kg_eta_mldsa_coop_sample_to( + dst, NULL, eta_seed, (uint16_t)(PARAM_L + k), + PARAM_ETA_S2, POLY_UNIFORM_ETA2_NBLOCKS, + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +#elif ALGORITHM == ALGO_AIGIS + wp_kg_eta2_aigis_coop_sample_to( + dst, eta_seed, (uint16_t)(PARAM_L + k), + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +#endif + } +} + +#if ALGORITHM == ALGO_AIGIS +template +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_secret_eta1_aigis_coop_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s1hat, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + __shared__ uint8_t sh_buf[WP_KG_MAX_SUBWARPS_PER_BLOCK][WP_KG_ETA_COOP_BUF_BYTES]; + __shared__ int sh_ctr[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + __shared__ unsigned int sh_buflen[WP_KG_MAX_SUBWARPS_PER_BLOCK]; + + const int lane_in_warp = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_local = threadIdx.x / WP_KG_WARP_SIZE; + const int subwarp_base = lane_in_warp - (lane_in_warp & (SUBWARP_LANES - 1)); + const int sublane = lane_in_warp - subwarp_base; + const int subwarps_per_warp = WP_KG_WARP_SIZE / SUBWARP_LANES; + const int group_local = warp_local * subwarps_per_warp + 
(lane_in_warp / SUBWARP_LANES); + const int polys_per_block = blockDim.x / SUBWARP_LANES; + const int poly_global = blockIdx.x * polys_per_block + group_local; + const int total_polys = batch_count * PARAM_L; + const unsigned long long mask = wp_kg_subwarp_mask(lane_in_warp); + + if (poly_global >= total_polys) return; + + const int inst = poly_global / PARAM_L; + const int j = poly_global % PARAM_L; + const unsigned char *my_buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; + coeff_t *dst = d_s1 + (size_t)inst * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + coeff_t *dst_copy = d_s1hat ? (d_s1hat + (size_t)inst * PARAM_L * PARAM_N + (size_t)j * PARAM_N) : NULL; + + wp_kg_eta1_aigis_coop_sample_to( + dst, dst_copy, eta_seed, (uint16_t)j, + sh_buf[group_local], &sh_ctr[group_local], &sh_buflen[group_local], + mask, sublane, subwarp_base); +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_secret_eta2_aigis_sample_kernel( + coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (warp_g >= batch_count) return; + + const unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; + + for (int k = lane; k < PARAM_K; k += WP_KG_WARP_SIZE) { + coeff_t *dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; + poly_uniform_eta_s2_to(dst, eta_seed, (uint16_t)(PARAM_L + k)); + } +} +#endif + +__global__ void batch_keygen_paper_rho_to_buf_kernel( + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_shared_rho, + int batch_count); + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_paper_secret_sample_split_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s1hat, + coeff_t * __restrict__ d_s2, + 
unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_shared_rho, + int batch_count); + +static inline void launch_batch_keygen_matrix_a_active( + coeff_t *d_mat, + const unsigned char *d_buf, + int batch_count, + int nblk, + hipStream_t stream) +{ +#if BATCH_KEYGEN_MATRIX_A_COOP +#if BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP +#if BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES == 8 + { + const int polys_per_block = WP_KG_TPB / 8; + const int coop_nblk = (batch_count * PARAM_K * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_matrix_a_coop_kernel<8><<>>( + d_mat, d_buf, batch_count); + } +#elif BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES == 16 + { + const int polys_per_block = WP_KG_TPB / 16; + const int coop_nblk = (batch_count * PARAM_K * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_matrix_a_coop_kernel<16><<>>( + d_mat, d_buf, batch_count); + } +#else +#error Unsupported BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES +#endif +#else + { + const int polys_per_block = WP_KG_TPB / 32; + const int coop_nblk = (batch_count * PARAM_K * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_matrix_a_coop_kernel<32><<>>( + d_mat, d_buf, batch_count); + } +#endif +#elif BATCH_KEYGEN_MATRIX_A_LANEOPT + batch_keygen_matrix_a_laneopt_kernel<<>>( + d_mat, d_buf, batch_count); +#else + batch_keygen_matrix_a_sample_kernel<<>>( + d_mat, d_buf, batch_count); +#endif +} + +static inline void launch_batch_keygen_secret_eta_active_independent( + coeff_t *d_s1, + coeff_t *d_s2, + const unsigned char *d_buf, + int batch_count, + int nblk, + hipStream_t stream) +{ +#if BATCH_KEYGEN_SECRET_ETA_COOP +#if ALGORITHM == ALGO_AIGIS && PARAM_ETA_S2 == 5 && BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT +#if BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 8 + { + const int polys_per_block = WP_KG_TPB / 8; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<8><<>>( + d_s1, NULL, d_buf, 
batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 16 + { + const int polys_per_block = WP_KG_TPB / 16; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<16><<>>( + d_s1, NULL, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 32 + { + const int polys_per_block = WP_KG_TPB / 32; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<32><<>>( + d_s1, NULL, d_buf, batch_count); + } +#else +#error Unsupported BATCH_KEYGEN_SECRET_ETA_COOP_LANES +#endif + batch_keygen_secret_eta2_aigis_sample_kernel<<>>( + d_s2, d_buf, batch_count); +#else +#if BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 8 + { + const int polys_per_block = WP_KG_TPB / 8; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta_coop_kernel<8><<>>( + d_s1, NULL, d_s2, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 16 + { + const int polys_per_block = WP_KG_TPB / 16; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta_coop_kernel<16><<>>( + d_s1, NULL, d_s2, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 32 + { + const int polys_per_block = WP_KG_TPB / 32; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta_coop_kernel<32><<>>( + d_s1, NULL, d_s2, d_buf, batch_count); + } +#else +#error Unsupported BATCH_KEYGEN_SECRET_ETA_COOP_LANES +#endif +#endif +#else + batch_keygen_secret_sample_kernel<<>>( + d_s1, d_s2, d_buf, batch_count); +#endif +} + +static inline void launch_batch_keygen_paper_rho_active( + unsigned char *d_buf, + const unsigned char *d_shared_rho, + int batch_count, + hipStream_t stream) +{ + batch_keygen_paper_rho_to_buf_kernel<<<(batch_count * 
SEEDBYTES + BATCH_TPB - 1) / BATCH_TPB, BATCH_TPB, 0, stream>>>( + d_buf, d_shared_rho, batch_count); +} + +static inline void launch_batch_keygen_secret_eta_active_paper( + coeff_t *d_s1, + coeff_t *d_s1hat, + coeff_t *d_s2, + unsigned char *d_buf, + const unsigned char *d_shared_rho, + int batch_count, + int nblk, + hipStream_t stream) +{ +#if BATCH_KEYGEN_SECRET_ETA_COOP +#if ALGORITHM == ALGO_AIGIS && PARAM_ETA_S2 == 5 && BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT +#if BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 8 + { + const int polys_per_block = WP_KG_TPB / 8; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<8><<>>( + d_s1, d_s1hat, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 16 + { + const int polys_per_block = WP_KG_TPB / 16; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<16><<>>( + d_s1, d_s1hat, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 32 + { + const int polys_per_block = WP_KG_TPB / 32; + const int coop_nblk = (batch_count * PARAM_L + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta1_aigis_coop_kernel<32><<>>( + d_s1, d_s1hat, d_buf, batch_count); + } +#else +#error Unsupported BATCH_KEYGEN_SECRET_ETA_COOP_LANES +#endif + batch_keygen_secret_eta2_aigis_sample_kernel<<>>( + d_s2, d_buf, batch_count); +#else +#if BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 8 + { + const int polys_per_block = WP_KG_TPB / 8; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta_coop_kernel<8><<>>( + d_s1, d_s1hat, d_s2, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 16 + { + const int polys_per_block = WP_KG_TPB / 16; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + 
batch_keygen_secret_eta_coop_kernel<16><<>>( + d_s1, d_s1hat, d_s2, d_buf, batch_count); + } +#elif BATCH_KEYGEN_SECRET_ETA_COOP_LANES == 32 + { + const int polys_per_block = WP_KG_TPB / 32; + const int coop_nblk = (batch_count * (PARAM_L + PARAM_K) + polys_per_block - 1) / polys_per_block; + batch_keygen_secret_eta_coop_kernel<32><<>>( + d_s1, d_s1hat, d_s2, d_buf, batch_count); + } +#else +#error Unsupported BATCH_KEYGEN_SECRET_ETA_COOP_LANES +#endif +#endif +#else + batch_keygen_paper_secret_sample_split_kernel<<>>( + d_s1, d_s1hat, d_s2, d_buf, d_shared_rho, batch_count); +#endif +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_paper_shared_a_kernel( + coeff_t * __restrict__ d_shared_mat, + unsigned char * __restrict__ d_shared_rho, + const unsigned char * __restrict__ d_base_seed) +{ + __shared__ unsigned char sh_rho[SEEDBYTES]; + + int tid = threadIdx.x; + if (tid == 0) { +#if ALGORITHM == ALGO_MLDSA + uint8_t buf[2 * SEEDBYTES + CRHBYTES]; + for (int i = 0; i < SEEDBYTES; i++) buf[i] = d_base_seed[i]; + buf[SEEDBYTES] = PARAM_K; + buf[SEEDBYTES + 1] = PARAM_L; + shake256(buf, 2 * SEEDBYTES + CRHBYTES, buf, SEEDBYTES + 2); + for (int i = 0; i < SEEDBYTES; i++) sh_rho[i] = buf[i]; +#elif ALGORITHM == ALGO_AIGIS + uint8_t buf[3 * SEEDBYTES]; + shake256(buf, 3 * SEEDBYTES, d_base_seed, SEEDBYTES); + for (int i = 0; i < SEEDBYTES; i++) sh_rho[i] = buf[SEEDBYTES + i]; +#endif + for (int i = 0; i < SEEDBYTES; i++) d_shared_rho[i] = sh_rho[i]; + } + __syncthreads(); + + const int total = PARAM_K * PARAM_L; + for (int p = tid; p < total; p += blockDim.x) { + poly tmp; + int row = p / PARAM_L; + int col = p % PARAM_L; + poly_uniform(&tmp, sh_rho, MATRIX_NONCE(row, col)); + coeff_t *dst = d_shared_mat + (size_t)p * PARAM_N; + for (int c = 0; c < PARAM_N; c++) dst[c] = tmp.coeffs[c]; + } +} + +__global__ void batch_keygen_paper_rho_to_buf_kernel( + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_shared_rho, + int batch_count) 
+{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * SEEDBYTES; + if (idx >= total) return; + + int inst = idx / SEEDBYTES; + int off = idx % SEEDBYTES; + d_buf[(size_t)inst * (2 * SEEDBYTES + CRHBYTES) + off] = d_shared_rho[off]; +} + +__global__ void __launch_bounds__(WP_KG_TPB) +batch_keygen_paper_secret_sample_split_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s1hat, + coeff_t * __restrict__ d_s2, + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_shared_rho, + int batch_count) +{ + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (warp_g >= batch_count) return; + + unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + if (lane == 0) { + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = d_shared_rho[i]; + } + __syncwarp(); + + const int total = PARAM_L + PARAM_K; + +#if ALGORITHM == ALGO_MLDSA + const uint8_t *rhoprime = my_buf + 2 * SEEDBYTES; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_buf + 2 * SEEDBYTES; +#endif + + for (int p = lane; p < total; p += WP_KG_WARP_SIZE) { + if (p < PARAM_L) { + int j = p; + coeff_t *dst1 = d_s1 + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + poly_uniform_eta_s1_to(dst1, rhoprime, j); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s1_to(dst1, eta_seed, (uint16_t)j); +#endif + coeff_t *dsth = d_s1hat + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + for (int c = 0; c < PARAM_N; c++) dsth[c] = dst1[c]; + } else { + int k = p - PARAM_L; + coeff_t *dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + poly_uniform_eta_s2_to(dst, rhoprime, PARAM_L + k); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s2_to(dst, eta_seed, (uint16_t)(PARAM_L + k)); +#endif + } + } +} + +__global__ void __launch_bounds__(WP_KG_TPB) 
+batch_keygen_paper_secret_sample_kernel( + coeff_t * __restrict__ d_s1, + coeff_t * __restrict__ d_s1hat, + coeff_t * __restrict__ d_s2, + unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_base_seed, + const unsigned char * __restrict__ d_shared_rho, + int batch_count) +{ + __shared__ unsigned char sh_seeds[WP_KG_WARPS_BLOCK][WP_KG_SEED_BYTES]; + __shared__ unsigned char sh_rho[WP_KG_WARPS_BLOCK][SEEDBYTES]; + + const int warp_g = (blockIdx.x * blockDim.x + threadIdx.x) / WP_KG_WARP_SIZE; + const int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + const int warp_l = threadIdx.x / WP_KG_WARP_SIZE; + + if (warp_g >= batch_count) return; + + unsigned char *my_seeds = sh_seeds[warp_l]; + unsigned char *my_rho = sh_rho[warp_l]; + + if (lane == 0) { + uint8_t seed_in[SEEDBYTES]; + for (int i = 0; i < SEEDBYTES; i++) seed_in[i] = d_base_seed[i]; + seed_in[SEEDBYTES - 4] ^= (uint8_t)(warp_g); + seed_in[SEEDBYTES - 3] ^= (uint8_t)(warp_g >> 8); + seed_in[SEEDBYTES - 2] ^= (uint8_t)(warp_g >> 16); + seed_in[SEEDBYTES - 1] ^= (uint8_t)(warp_g >> 24); + +#if ALGORITHM == ALGO_MLDSA + uint8_t buf[2 * SEEDBYTES + CRHBYTES]; + for (int i = 0; i < SEEDBYTES; i++) buf[i] = seed_in[i]; + buf[SEEDBYTES] = PARAM_K; + buf[SEEDBYTES + 1] = PARAM_L; + shake256(buf, 2 * SEEDBYTES + CRHBYTES, buf, SEEDBYTES + 2); + for (int i = 0; i < 2 * SEEDBYTES + CRHBYTES; i++) my_seeds[i] = buf[i]; +#elif ALGORITHM == ALGO_AIGIS + uint8_t buf[3 * SEEDBYTES]; + shake256(buf, 3 * SEEDBYTES, seed_in, SEEDBYTES); + for (int i = 0; i < 3 * SEEDBYTES; i++) my_seeds[i] = buf[i]; +#endif + for (int i = 0; i < SEEDBYTES; i++) my_rho[i] = d_shared_rho[i]; + + unsigned char *my_buf = d_buf + (size_t)warp_g * (2 * SEEDBYTES + CRHBYTES); + for (int i = 0; i < SEEDBYTES; i++) my_buf[i] = my_rho[i]; +#if ALGORITHM == ALGO_MLDSA + const uint8_t *key = my_seeds + SEEDBYTES + CRHBYTES; + const uint8_t *rhp = my_seeds + SEEDBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = 
key[i]; + for (int i = 0; i < CRHBYTES; i++) my_buf[2 * SEEDBYTES + i] = rhp[i]; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_seeds; + const uint8_t *key = my_seeds + 2 * SEEDBYTES; + for (int i = 0; i < SEEDBYTES; i++) my_buf[SEEDBYTES + i] = key[i]; + for (int i = 0; i < SEEDBYTES; i++) my_buf[2 * SEEDBYTES + i] = eta_seed[i]; +#endif + } + __syncwarp(); + +#if ALGORITHM == ALGO_MLDSA + const uint8_t *rhoprime = my_seeds + SEEDBYTES; +#elif ALGORITHM == ALGO_AIGIS + const uint8_t *eta_seed = my_seeds; +#endif + + const int total = PARAM_L + PARAM_K; + for (int p = lane; p < total; p += WP_KG_WARP_SIZE) { + if (p < PARAM_L) { + int j = p; + coeff_t *dst1 = d_s1 + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + poly_uniform_eta_s1_to(dst1, rhoprime, j); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s1_to(dst1, eta_seed, (uint16_t)j); +#endif + coeff_t *dsth = d_s1hat + (size_t)warp_g * PARAM_L * PARAM_N + (size_t)j * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + dsth[c] = dst1[c]; + } + } else { + int k = p - PARAM_L; + coeff_t *dst = d_s2 + (size_t)warp_g * PARAM_K * PARAM_N + (size_t)k * PARAM_N; +#if ALGORITHM == ALGO_MLDSA + poly_uniform_eta_s2_to(dst, rhoprime, PARAM_L + k); +#elif ALGORITHM == ALGO_AIGIS + poly_uniform_eta_s2_to(dst, eta_seed, (uint16_t)(PARAM_L + k)); +#endif + } + } +} + +/* ================================================================ + * 矩阵向量乘 kernel — 共用 + * + * t[row] = Σ_{col} A[row][col] · s1hat[col] (NTT 域) + * grid: (batch_count, PARAM_K) + * block: PARAM_N threads + * ================================================================ */ +__global__ void batch_keygen_matvec_kernel( + coeff_t * __restrict__ d_t, + const coeff_t * __restrict__ d_mat, + const coeff_t * __restrict__ d_s1hat, + int batch_count) +{ + int inst = blockIdx.x; + int row = blockIdx.y; + if (inst >= batch_count) return; + + int tid = threadIdx.x; + + coeff2_t acc = 0; + #pragma unroll + for (int 
col = 0; col < PARAM_L; col++) { + coeff_t a = d_mat[(size_t)inst * PARAM_K * PARAM_L * PARAM_N + + (row * PARAM_L + col) * PARAM_N + tid]; + coeff_t b = d_s1hat[(size_t)inst * PARAM_L * PARAM_N + + col * PARAM_N + tid]; + acc += (coeff2_t)a * b; + } + + d_t[(size_t)inst * PARAM_K * PARAM_N + row * PARAM_N + tid] = (coeff_t)montgomery_reduce(acc); +} + +__global__ void batch_keygen_matvec_shared_a_kernel( + coeff_t * __restrict__ d_t, + const coeff_t * __restrict__ d_shared_mat, + const coeff_t * __restrict__ d_s1hat, + int batch_count) +{ + int inst = blockIdx.x; + int row = blockIdx.y; + if (inst >= batch_count) return; + + int tid = threadIdx.x; + coeff2_t acc = 0; + #pragma unroll + for (int col = 0; col < PARAM_L; col++) { + coeff_t a = d_shared_mat[(row * PARAM_L + col) * PARAM_N + tid]; + coeff_t b = d_s1hat[(size_t)inst * PARAM_L * PARAM_N + + col * PARAM_N + tid]; + acc += (coeff2_t)a * b; + } + d_t[(size_t)inst * PARAM_K * PARAM_N + row * PARAM_N + tid] = (coeff_t)montgomery_reduce(acc); +} + +__global__ void batch_keygen_add_norm_kernel( + coeff_t * __restrict__ d_t, + const coeff_t * __restrict__ d_s2, + int total_coeffs) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_coeffs) return; + coeff_t v = d_t[idx] + d_s2[idx]; +#if ALGORITHM == ALGO_MLDSA + v = coeff_normalize(v); +#elif ALGORITHM == ALGO_AIGIS + v = coeff_freeze_wide(v); +#endif + d_t[idx] = v; +} + +static inline void launch_batch_keygen_add_norm(coeff_t *d_t, + const coeff_t *d_s2, + int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_keygen_add_norm_kernel<<>>(d_t, d_s2, total_coeffs); +} + +/* ================================================================ + * 打包 kernel — 参数位宽分叉 + * ================================================================ */ +__global__ void __launch_bounds__(32) +batch_keygen_pack_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + coeff_t * 
__restrict__ d_t1_out, + coeff_t * __restrict__ d_t0_out, + unsigned char * __restrict__ d_tr_out, + const coeff_t * __restrict__ d_t, + const coeff_t * __restrict__ d_s1, + const coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_count) return; + + polyveck t1_pk, t0_pk; + polyvecl s1_pk; + polyveck s2_pk; + + /* 从 flat 缓冲区加载 t 并在 pack 阶段就地 power2round */ + for (int i = 0; i < PARAM_K; i++) { + const coeff_t *src = d_t + (size_t)idx * PARAM_K * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + int32_t v = (int32_t)src[c]; +#if ALGORITHM == ALGO_MLDSA + v += (v >> 31) & PARAM_Q; +#endif + t1_pk.vec[i].coeffs[c] = power2round(&t0_pk.vec[i].coeffs[c], v); + if (d_t1_out) + d_t1_out[(size_t)idx * PARAM_K * PARAM_N + (size_t)i * PARAM_N + c] = t1_pk.vec[i].coeffs[c]; + if (d_t0_out) + d_t0_out[(size_t)idx * PARAM_K * PARAM_N + (size_t)i * PARAM_N + c] = t0_pk.vec[i].coeffs[c]; + } + } + + /* 加载 s1, s2 */ + for (int i = 0; i < PARAM_L; i++) { + const coeff_t *src = d_s1 + (size_t)idx * PARAM_L * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) s1_pk.vec[i].coeffs[c] = src[c]; + } + for (int i = 0; i < PARAM_K; i++) { + const coeff_t *src = d_s2 + (size_t)idx * PARAM_K * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) s2_pk.vec[i].coeffs[c] = src[c]; + } + + /* 从 d_buf 恢复 rho, key */ + const unsigned char *my_buf = d_buf + (size_t)idx * (2 * SEEDBYTES + CRHBYTES); + uint8_t rho[SEEDBYTES], key_buf[SEEDBYTES]; + for (int i = 0; i < SEEDBYTES; i++) rho[i] = my_buf[i]; + for (int i = 0; i < SEEDBYTES; i++) key_buf[i] = my_buf[SEEDBYTES + i]; + + /* pack pk */ + uint8_t *pk = d_pks + (size_t)idx * CRYPTO_PUBLICKEYBYTES; + pack_pk(pk, rho, &t1_pk); + + /* hash pk → tr */ + uint8_t tr[TRBYTES]; + shake256(tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + if (d_tr_out) { + unsigned char *tr_dst = d_tr_out + (size_t)idx * TRBYTES; + for 
(int i = 0; i < TRBYTES; i++) tr_dst[i] = tr[i]; + } + + /* pack sk */ + uint8_t *sk = d_sks + (size_t)idx * CRYPTO_SECRETKEYBYTES; + pack_sk(sk, rho, key_buf, tr, &s1_pk, &s2_pk, &t0_pk); +} + +__global__ void __launch_bounds__(32) +batch_keygen_pack_precomputed_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_t1, + const coeff_t * __restrict__ d_t0, + const coeff_t * __restrict__ d_s1, + const coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + unsigned char * __restrict__ d_tr_out, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_count) return; + + polyveck t1_pk, t0_pk; + polyvecl s1_pk; + polyveck s2_pk; + + for (int i = 0; i < PARAM_K; i++) { + const coeff_t *src1 = d_t1 + (size_t)idx * PARAM_K * PARAM_N + i * PARAM_N; + const coeff_t *src0 = d_t0 + (size_t)idx * PARAM_K * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + t1_pk.vec[i].coeffs[c] = src1[c]; + t0_pk.vec[i].coeffs[c] = src0[c]; + } + } + + for (int i = 0; i < PARAM_L; i++) { + const coeff_t *src = d_s1 + (size_t)idx * PARAM_L * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) s1_pk.vec[i].coeffs[c] = src[c]; + } + for (int i = 0; i < PARAM_K; i++) { + const coeff_t *src = d_s2 + (size_t)idx * PARAM_K * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) s2_pk.vec[i].coeffs[c] = src[c]; + } + + const unsigned char *my_buf = d_buf + (size_t)idx * (2 * SEEDBYTES + CRHBYTES); + uint8_t rho[SEEDBYTES], key_buf[SEEDBYTES]; + for (int i = 0; i < SEEDBYTES; i++) rho[i] = my_buf[i]; + for (int i = 0; i < SEEDBYTES; i++) key_buf[i] = my_buf[SEEDBYTES + i]; + + uint8_t *pk = d_pks + (size_t)idx * CRYPTO_PUBLICKEYBYTES; + pack_pk(pk, rho, &t1_pk); + + uint8_t tr[TRBYTES]; + shake256(tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + if (d_tr_out) { + unsigned char *tr_dst = d_tr_out + (size_t)idx * TRBYTES; + for (int i = 0; i < TRBYTES; i++) tr_dst[i] = 
tr[i]; + } + + uint8_t *sk = d_sks + (size_t)idx * CRYPTO_SECRETKEYBYTES; + pack_sk(sk, rho, key_buf, tr, &s1_pk, &s2_pk, &t0_pk); +} + +#ifndef BATCH_KEYGEN_TR_HASH_FIXED +#define BATCH_KEYGEN_TR_HASH_FIXED 1 +#endif + +#if BATCH_KEYGEN_TR_HASH_FIXED +static __device__ __noinline__ void batch_keygen_shake256_tr_pk( + uint8_t *out, + const uint8_t *pk) +{ + keccak_state state; + keccak_absorb_once(state.s, SHAKE256_RATE, pk, CRYPTO_PUBLICKEYBYTES, 0x1F); + KeccakF1600_StatePermute(state.s); + + const int whole_words = TRBYTES / 8; + for (int i = 0; i < whole_words; i++) { + store64(out + 8 * i, state.s[i]); + } + + const int tail_bytes = TRBYTES & 7; + if (tail_bytes) { + uint64_t tail_word = state.s[whole_words]; + for (int i = 0; i < tail_bytes; i++) { + out[whole_words * 8 + i] = (uint8_t)(tail_word >> (8 * i)); + } + } +} +#endif + +static __device__ __noinline__ void batch_keygen_pack_header_task( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + const unsigned char * __restrict__ d_buf, + int task_id) +{ + int inst = task_id / (2 * SEEDBYTES); + int off = task_id - inst * (2 * SEEDBYTES); + const unsigned char *buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + unsigned char *pk = d_pks + (size_t)inst * CRYPTO_PUBLICKEYBYTES; + unsigned char *sk = d_sks + (size_t)inst * CRYPTO_SECRETKEYBYTES; + + if (off < SEEDBYTES) { + unsigned char rho = buf[off]; + pk[off] = rho; + sk[off] = rho; + } else { + sk[off] = buf[off]; + } +} + +static __device__ __noinline__ void batch_keygen_pack_t1_task( + unsigned char * __restrict__ d_pks, + const coeff_t * __restrict__ d_t1, + int task_id) +{ +#if POLYT1_PACKED_BITS == 10 + const int groups_per_poly = PARAM_N / 4; +#else + const int groups_per_poly = PARAM_N; +#endif + int inst = task_id / (PARAM_K * groups_per_poly); + int rem = task_id - inst * (PARAM_K * groups_per_poly); + int poly_idx = rem / groups_per_poly; + int group = rem - poly_idx * groups_per_poly; + unsigned char *pk = 
d_pks + (size_t)inst * CRYPTO_PUBLICKEYBYTES; + +#if POLYT1_PACKED_BITS == 10 + const coeff_t *src = d_t1 + (size_t)inst * PARAM_K * PARAM_N + + (size_t)poly_idx * PARAM_N; + unsigned char *dst = pk + SEEDBYTES + (size_t)poly_idx * POLYT1_PACKEDBYTES; + uint32_t t0 = (uint32_t)src[4 * group + 0]; + uint32_t t1 = (uint32_t)src[4 * group + 1]; + uint32_t t2 = (uint32_t)src[4 * group + 2]; + uint32_t t3 = (uint32_t)src[4 * group + 3]; + dst[5 * group + 0] = (uint8_t)t0; + dst[5 * group + 1] = (uint8_t)((t0 >> 8) | (t1 << 2)); + dst[5 * group + 2] = (uint8_t)((t1 >> 6) | (t2 << 4)); + dst[5 * group + 3] = (uint8_t)((t2 >> 4) | (t3 << 6)); + dst[5 * group + 4] = (uint8_t)(t3 >> 2); +#elif POLYT1_PACKED_BITS == 8 + const int coeff_idx = group; + pk[SEEDBYTES + (size_t)poly_idx * POLYT1_PACKEDBYTES + coeff_idx] = + (uint8_t)d_t1[(size_t)inst * PARAM_K * PARAM_N + rem]; +#else + #error Unsupported POLYT1_PACKED_BITS +#endif +} + +static __device__ __noinline__ void batch_keygen_pack_s1_task( + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_s1, + int task_id) +{ +#if SETA1BITS == 2 + const int groups_per_poly = PARAM_N / 4; +#elif SETA1BITS == 3 + const int groups_per_poly = PARAM_N / 8; +#else + const int groups_per_poly = PARAM_N / 2; +#endif + int inst = task_id / (PARAM_L * groups_per_poly); + int rem = task_id - inst * (PARAM_L * groups_per_poly); + int poly_idx = rem / groups_per_poly; + int group = rem - poly_idx * groups_per_poly; + + const coeff_t *src = d_s1 + (size_t)inst * PARAM_L * PARAM_N + + (size_t)poly_idx * PARAM_N; + unsigned char *dst = d_sks + (size_t)inst * CRYPTO_SECRETKEYBYTES + + (2 * SEEDBYTES + TRBYTES) + + (size_t)poly_idx * POLYETA1_PACKEDBYTES; + +#if SETA1BITS == 2 + uint8_t t0 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[4 * group + 0]); + uint8_t t1 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[4 * group + 1]); + uint8_t t2 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[4 * group + 2]); + uint8_t t3 = (uint8_t)(COEFF_BIAS + 
PARAM_ETA_S1 - src[4 * group + 3]); + dst[group] = t0 | (t1 << 2) | (t2 << 4) | (t3 << 6); +#elif SETA1BITS == 3 + uint8_t t0 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 0]); + uint8_t t1 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 1]); + uint8_t t2 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 2]); + uint8_t t3 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 3]); + uint8_t t4 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 4]); + uint8_t t5 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 5]); + uint8_t t6 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 6]); + uint8_t t7 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[8 * group + 7]); + dst[3 * group + 0] = t0 | (t1 << 3) | (t2 << 6); + dst[3 * group + 1] = (t2 >> 2) | (t3 << 1) | (t4 << 4) | (t5 << 7); + dst[3 * group + 2] = (t5 >> 1) | (t6 << 2) | (t7 << 5); +#elif SETA1BITS == 4 + uint8_t t0 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[2 * group + 0]); + uint8_t t1 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S1 - src[2 * group + 1]); + dst[group] = t0 | (t1 << 4); +#endif +} + +static __device__ __noinline__ void batch_keygen_pack_s2_task( + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_s2, + int task_id) +{ +#if SETA2BITS == 3 + const int groups_per_poly = PARAM_N / 8; +#else + const int groups_per_poly = PARAM_N / 2; +#endif + int inst = task_id / (PARAM_K * groups_per_poly); + int rem = task_id - inst * (PARAM_K * groups_per_poly); + int poly_idx = rem / groups_per_poly; + int group = rem - poly_idx * groups_per_poly; + + const coeff_t *src = d_s2 + (size_t)inst * PARAM_K * PARAM_N + + (size_t)poly_idx * PARAM_N; + unsigned char *dst = d_sks + (size_t)inst * CRYPTO_SECRETKEYBYTES + + (2 * SEEDBYTES + TRBYTES) + + (size_t)PARAM_L * POLYETA1_PACKEDBYTES + + (size_t)poly_idx * POLYETA2_PACKEDBYTES; + +#if SETA2BITS == 3 + uint8_t t0 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 0]); + uint8_t t1 = (uint8_t)(COEFF_BIAS + 
PARAM_ETA_S2 - src[8 * group + 1]); + uint8_t t2 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 2]); + uint8_t t3 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 3]); + uint8_t t4 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 4]); + uint8_t t5 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 5]); + uint8_t t6 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 6]); + uint8_t t7 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[8 * group + 7]); + dst[3 * group + 0] = t0 | (t1 << 3) | (t2 << 6); + dst[3 * group + 1] = (t2 >> 2) | (t3 << 1) | (t4 << 4) | (t5 << 7); + dst[3 * group + 2] = (t5 >> 1) | (t6 << 2) | (t7 << 5); +#elif SETA2BITS == 4 + uint8_t t0 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[2 * group + 0]); + uint8_t t1 = (uint8_t)(COEFF_BIAS + PARAM_ETA_S2 - src[2 * group + 1]); + dst[group] = t0 | (t1 << 4); +#endif +} + +static __device__ __noinline__ void batch_keygen_pack_t0_task( + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_t0, + int task_id) +{ +#if PARAM_D == 13 + const int groups_per_poly = PARAM_N / 8; +#else + const int groups_per_poly = PARAM_N / 4; +#endif + int inst = task_id / (PARAM_K * groups_per_poly); + int rem = task_id - inst * (PARAM_K * groups_per_poly); + int poly_idx = rem / groups_per_poly; + int group = rem - poly_idx * groups_per_poly; + + const coeff_t *src = d_t0 + (size_t)inst * PARAM_K * PARAM_N + + (size_t)poly_idx * PARAM_N; + unsigned char *dst = d_sks + (size_t)inst * CRYPTO_SECRETKEYBYTES + + (2 * SEEDBYTES + TRBYTES) + + (size_t)PARAM_L * POLYETA1_PACKEDBYTES + + (size_t)PARAM_K * POLYETA2_PACKEDBYTES + + (size_t)poly_idx * POLYT0_PACKEDBYTES; + +#if PARAM_D == 13 + uint32_t t[8]; + for (int j = 0; j < 8; j++) + t[j] = COEFF_BIAS + (1 << (PARAM_D - 1)) - src[8 * group + j]; + dst[13 * group + 0] = t[0]; + dst[13 * group + 1] = t[0] >> 8; + dst[13 * group + 1] |= t[1] << 5; + dst[13 * group + 2] = t[1] >> 3; + dst[13 * group + 3] = t[1] >> 11; + dst[13 * group + 3] 
|= t[2] << 2; + dst[13 * group + 4] = t[2] >> 6; + dst[13 * group + 4] |= t[3] << 7; + dst[13 * group + 5] = t[3] >> 1; + dst[13 * group + 6] = t[3] >> 9; + dst[13 * group + 6] |= t[4] << 4; + dst[13 * group + 7] = t[4] >> 4; + dst[13 * group + 8] = t[4] >> 12; + dst[13 * group + 8] |= t[5] << 1; + dst[13 * group + 9] = t[5] >> 7; + dst[13 * group + 9] |= t[6] << 6; + dst[13 * group + 10] = t[6] >> 2; + dst[13 * group + 11] = t[6] >> 10; + dst[13 * group + 11] |= t[7] << 3; + dst[13 * group + 12] = t[7] >> 5; +#elif PARAM_D == 14 + uint32_t t[4]; + for (int j = 0; j < 4; j++) + t[j] = COEFF_BIAS + (1 << (PARAM_D - 1)) - src[4 * group + j]; + dst[7 * group + 0] = t[0]; + dst[7 * group + 1] = t[0] >> 8; + dst[7 * group + 1] |= t[1] << 6; + dst[7 * group + 2] = t[1] >> 2; + dst[7 * group + 3] = t[1] >> 10; + dst[7 * group + 3] |= t[2] << 4; + dst[7 * group + 4] = t[2] >> 4; + dst[7 * group + 5] = t[2] >> 12; + dst[7 * group + 5] |= t[3] << 2; + dst[7 * group + 6] = t[3] >> 6; +#endif +} + +__global__ void batch_keygen_pack_header_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + const unsigned char * __restrict__ d_buf, + int total_bytes) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_bytes) return; + + batch_keygen_pack_header_task(d_pks, d_sks, d_buf, idx); +} + +__global__ void batch_keygen_pack_t1_kernel( + unsigned char * __restrict__ d_pks, + const coeff_t * __restrict__ d_t1, + int total_groups) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_groups) return; + + batch_keygen_pack_t1_task(d_pks, d_t1, idx); +} + +__global__ void batch_keygen_pack_s1_kernel( + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_s1, + int total_groups) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_groups) return; + + batch_keygen_pack_s1_task(d_sks, d_s1, idx); +} + +__global__ void batch_keygen_pack_s2_kernel( + unsigned char * __restrict__ d_sks, + 
const coeff_t * __restrict__ d_s2, + int total_groups) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_groups) return; + + batch_keygen_pack_s2_task(d_sks, d_s2, idx); +} + +__global__ void batch_keygen_pack_t0_kernel( + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_t0, + int total_groups) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_groups) return; + + batch_keygen_pack_t0_task(d_sks, d_t0, idx); +} + +__global__ void batch_keygen_pack_body_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_t1, + const coeff_t * __restrict__ d_t0, + const coeff_t * __restrict__ d_s1, + const coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + int total_tasks, + int header_end, + int t1_end, + int s1_end, + int s2_end) +{ + int task_id = blockIdx.x * blockDim.x + threadIdx.x; + if (task_id >= total_tasks) return; + + if (task_id < header_end) { + batch_keygen_pack_header_task(d_pks, d_sks, d_buf, task_id); + } else if (task_id < t1_end) { + batch_keygen_pack_t1_task(d_pks, d_t1, task_id - header_end); + } else if (task_id < s1_end) { + batch_keygen_pack_s1_task(d_sks, d_s1, task_id - t1_end); + } else if (task_id < s2_end) { + batch_keygen_pack_s2_task(d_sks, d_s2, task_id - s1_end); + } else { + batch_keygen_pack_t0_task(d_sks, d_t0, task_id - s2_end); + } +} + +__global__ void batch_keygen_tr_hash_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + unsigned char * __restrict__ d_tr_out, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + const size_t pk_base = (size_t)inst * CRYPTO_PUBLICKEYBYTES; + const size_t sk_base = (size_t)inst * CRYPTO_SECRETKEYBYTES; + const size_t tr_base = (size_t)inst * TRBYTES; + const unsigned char *pk = d_pks + pk_base; + unsigned char *sk_tr = d_sks + sk_base + 2 * SEEDBYTES; + unsigned char 
*tr_out = d_tr_out ? (d_tr_out + tr_base) : NULL; + uint8_t tr[TRBYTES]; +#if BATCH_KEYGEN_TR_HASH_FIXED + batch_keygen_shake256_tr_pk(tr, pk); +#else + shake256(tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); +#endif + for (int i = 0; i < TRBYTES; i++) { + sk_tr[i] = tr[i]; + if (tr_out) + tr_out[i] = tr[i]; + } +} + +__global__ void __launch_bounds__(32) +batch_keygen_pack_fused_tr_kernel( + unsigned char * __restrict__ d_pks, + unsigned char * __restrict__ d_sks, + const coeff_t * __restrict__ d_t1, + const coeff_t * __restrict__ d_t0, + const coeff_t * __restrict__ d_s1, + const coeff_t * __restrict__ d_s2, + const unsigned char * __restrict__ d_buf, + unsigned char * __restrict__ d_tr_out, + int batch_count, + int header_tasks_per_inst, + int t1_tasks_per_inst, + int s1_tasks_per_inst, + int s2_tasks_per_inst, + int t0_tasks_per_inst) +{ + const int inst = blockIdx.x; + if (inst >= batch_count) return; + + const int tid = threadIdx.x; + const int header_end = header_tasks_per_inst; + const int t1_end = header_end + t1_tasks_per_inst; + const int s1_end = t1_end + s1_tasks_per_inst; + const int s2_end = s1_end + s2_tasks_per_inst; + const int total_tasks = s2_end + t0_tasks_per_inst; + + for (int task = tid; task < total_tasks; task += blockDim.x) { + if (task < header_end) { + batch_keygen_pack_header_task( + d_pks, d_sks, d_buf, + inst * header_tasks_per_inst + task); + } else if (task < t1_end) { + batch_keygen_pack_t1_task( + d_pks, d_t1, + inst * t1_tasks_per_inst + (task - header_end)); + } else if (task < s1_end) { + batch_keygen_pack_s1_task( + d_sks, d_s1, + inst * s1_tasks_per_inst + (task - t1_end)); + } else if (task < s2_end) { + batch_keygen_pack_s2_task( + d_sks, d_s2, + inst * s2_tasks_per_inst + (task - s1_end)); + } else { + batch_keygen_pack_t0_task( + d_sks, d_t0, + inst * t0_tasks_per_inst + (task - s2_end)); + } + } + + __threadfence_block(); + __syncthreads(); + + if (tid == 0) { + const size_t pk_base = (size_t)inst * CRYPTO_PUBLICKEYBYTES; + 
const size_t sk_base = (size_t)inst * CRYPTO_SECRETKEYBYTES; + const size_t tr_base = (size_t)inst * TRBYTES; + unsigned char *pk = d_pks + pk_base; + unsigned char *sk_tr = d_sks + sk_base + 2 * SEEDBYTES; + unsigned char *tr_out = d_tr_out ? (d_tr_out + tr_base) : NULL; + polyveck t1_pk; + uint8_t rho[SEEDBYTES]; + + const unsigned char *my_buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + for (int i = 0; i < SEEDBYTES; i++) rho[i] = my_buf[i]; + + for (int i = 0; i < PARAM_K; i++) { + const coeff_t *src1 = d_t1 + (size_t)inst * PARAM_K * PARAM_N + i * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + t1_pk.vec[i].coeffs[c] = src1[c]; + } + } + + pack_pk(pk, rho, &t1_pk); + + uint8_t tr[TRBYTES]; + shake256(tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + for (int i = 0; i < TRBYTES; i++) { + sk_tr[i] = tr[i]; + if (tr_out) + tr_out[i] = tr[i]; + } + } +} + +#ifndef BATCH_KEYGEN_PACK_USE_REFERENCE +#define BATCH_KEYGEN_PACK_USE_REFERENCE 0 +#endif + +#ifndef BATCH_KEYGEN_PACK_PROFILE_SPLIT +#define BATCH_KEYGEN_PACK_PROFILE_SPLIT 0 +#endif + +#ifndef BATCH_KEYGEN_TR_HASH_EXPERIMENTAL +#define BATCH_KEYGEN_TR_HASH_EXPERIMENTAL 0 +#endif + +#ifndef BATCH_KEYGEN_PACK_FUSED_TR +#define BATCH_KEYGEN_PACK_FUSED_TR 0 +#endif + +static inline void launch_batch_keygen_tr_hash( + unsigned char *d_pks, + unsigned char *d_sks, + unsigned char *d_tr, + int batch_count, + hipStream_t stream = 0) +{ +#if BATCH_KEYGEN_TR_HASH_FIXED + const int tpb = 32; +#else + const int tpb = 128; +#endif + const int nblk = (batch_count + tpb - 1) / tpb; +#if BATCH_KEYGEN_TR_HASH_EXPERIMENTAL + batch_keygen_tr_hash_kernel<<>>( + d_pks, d_sks, d_tr, batch_count); +#else + batch_keygen_tr_hash_kernel<<>>( + d_pks, d_sks, d_tr, batch_count); +#endif +} + +static inline void launch_batch_keygen_pack_reference( + unsigned char *d_pks, + unsigned char *d_sks, + const coeff_t *d_t1, + const coeff_t *d_t0, + const coeff_t *d_s1, + const coeff_t *d_s2, + const unsigned char *d_buf, + unsigned char 
*d_tr, + int batch_count, + hipStream_t stream = 0) +{ + int tpb = 32; + int nblk = (batch_count + tpb - 1) / tpb; + batch_keygen_pack_precomputed_kernel<<>>( + d_pks, d_sks, d_t1, d_t0, d_s1, d_s2, d_buf, d_tr, batch_count); +} + +static inline void launch_batch_keygen_pack_standard( + unsigned char *d_pks, + unsigned char *d_sks, + const coeff_t *d_t1, + const coeff_t *d_t0, + const coeff_t *d_s1, + const coeff_t *d_s2, + const unsigned char *d_buf, + unsigned char *d_tr, + int batch_count, + hipStream_t stream = 0, + KeygenProfile *profile = NULL) +{ +#if BATCH_KEYGEN_PACK_USE_REFERENCE + launch_batch_keygen_pack_reference( + d_pks, d_sks, d_t1, d_t0, d_s1, d_s2, d_buf, d_tr, batch_count, stream); +#else + hipEvent_t ev0 = NULL, ev1 = NULL, ev_inner0 = NULL, ev_inner1 = NULL; + if (profile) { + hipEventCreate(&ev0); + hipEventCreate(&ev1); + hipEventCreate(&ev_inner0); + hipEventCreate(&ev_inner1); + hipEventRecord(ev_inner0, stream); + } + + const int tpb = BATCH_TPB; + const int header_tasks_per_inst = 2 * SEEDBYTES; + int total_header = batch_count * header_tasks_per_inst; + +#if POLYT1_PACKED_BITS == 10 + const int t1_groups_per_poly = PARAM_N / 4; +#else + const int t1_groups_per_poly = PARAM_N; +#endif + const int t1_tasks_per_inst = PARAM_K * t1_groups_per_poly; + int total_t1_groups = batch_count * t1_tasks_per_inst; + +#if SETA1BITS == 2 + const int s1_groups_per_poly = PARAM_N / 4; +#elif SETA1BITS == 3 + const int s1_groups_per_poly = PARAM_N / 8; +#else + const int s1_groups_per_poly = PARAM_N / 2; +#endif + const int s1_tasks_per_inst = PARAM_L * s1_groups_per_poly; + int total_s1_groups = batch_count * s1_tasks_per_inst; + +#if SETA2BITS == 3 + const int s2_groups_per_poly = PARAM_N / 8; +#else + const int s2_groups_per_poly = PARAM_N / 2; +#endif + const int s2_tasks_per_inst = PARAM_K * s2_groups_per_poly; + int total_s2_groups = batch_count * s2_tasks_per_inst; + +#if PARAM_D == 13 + const int t0_groups_per_poly = PARAM_N / 8; +#else + const int 
t0_groups_per_poly = PARAM_N / 4; +#endif + const int t0_tasks_per_inst = PARAM_K * t0_groups_per_poly; + int total_t0_groups = batch_count * t0_tasks_per_inst; + +#if BATCH_KEYGEN_PACK_FUSED_TR + if (profile) hipEventRecord(ev0, stream); + batch_keygen_pack_fused_tr_kernel<<>>( + d_pks, d_sks, d_t1, d_t0, d_s1, d_s2, d_buf, d_tr, batch_count, + header_tasks_per_inst, t1_tasks_per_inst, s1_tasks_per_inst, + s2_tasks_per_inst, t0_tasks_per_inst); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_fused_ms, ev0, ev1); + hipEventRecord(ev_inner1, stream); + hipEventSynchronize(ev_inner1); + keygen_profile_add(&profile->pack_inner_ms, ev_inner0, ev_inner1); + hipEventDestroy(ev0); + hipEventDestroy(ev1); + hipEventDestroy(ev_inner0); + hipEventDestroy(ev_inner1); + } +#else + + if (profile && BATCH_KEYGEN_PACK_PROFILE_SPLIT) { + if (profile) hipEventRecord(ev0, stream); + batch_keygen_pack_header_kernel<<<(total_header + tpb - 1) / tpb, tpb, 0, stream>>>( + d_pks, d_sks, d_buf, total_header); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_header_ms, ev0, ev1); + } + + if (profile) hipEventRecord(ev0, stream); + batch_keygen_pack_t1_kernel<<<(total_t1_groups + tpb - 1) / tpb, tpb, 0, stream>>>( + d_pks, d_t1, total_t1_groups); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_t1_ms, ev0, ev1); + } + + if (profile) hipEventRecord(ev0, stream); + batch_keygen_pack_s1_kernel<<<(total_s1_groups + tpb - 1) / tpb, tpb, 0, stream>>>( + d_sks, d_s1, total_s1_groups); + batch_keygen_pack_s2_kernel<<<(total_s2_groups + tpb - 1) / tpb, tpb, 0, stream>>>( + d_sks, d_s2, total_s2_groups); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_eta_ms, ev0, ev1); + } + + if (profile) hipEventRecord(ev0, stream); + 
batch_keygen_pack_t0_kernel<<<(total_t0_groups + tpb - 1) / tpb, tpb, 0, stream>>>( + d_sks, d_t0, total_t0_groups); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_t0_ms, ev0, ev1); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_body_ms, ev_inner0, ev1); + } + } else { + int header_end = total_header; + int t1_end = header_end + total_t1_groups; + int s1_end = t1_end + total_s1_groups; + int s2_end = s1_end + total_s2_groups; + int total_body_tasks = s2_end + total_t0_groups; + + if (profile) hipEventRecord(ev0, stream); + batch_keygen_pack_body_kernel<<<(total_body_tasks + tpb - 1) / tpb, tpb, 0, stream>>>( + d_pks, d_sks, d_t1, d_t0, d_s1, d_s2, d_buf, + total_body_tasks, header_end, t1_end, s1_end, s2_end); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->pack_body_ms, ev0, ev1); + } + } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_keygen_tr_hash(d_pks, d_sks, d_tr, batch_count, stream); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->tr_hash_ms, ev0, ev1); + hipEventRecord(ev_inner1, stream); + hipEventSynchronize(ev_inner1); + keygen_profile_add(&profile->pack_inner_ms, ev_inner0, ev_inner1); + hipEventDestroy(ev0); + hipEventDestroy(ev1); + hipEventDestroy(ev_inner0); + hipEventDestroy(ev_inner1); + } +#endif +#endif +} + +__global__ void batch_keygen_shiftl_copy_kernel( + coeff_t * __restrict__ d_dst, + const coeff_t * __restrict__ d_src, + int total_coeffs, + int shift) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_coeffs) return; + d_dst[idx] = d_src[idx] << shift; +} + +static inline void launch_batch_keygen_shiftl_copy(coeff_t *d_dst, + const coeff_t *d_src, + int total_coeffs, + int shift, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + 
batch_keygen_shiftl_copy_kernel<<>>( + d_dst, d_src, total_coeffs, shift); +} + +static inline void batch_keygen_finalize_material(BatchKeygenBuffers *buf, + int batch_count, + hipStream_t stream = 0) { + const int total_k = batch_count * PARAM_K * PARAM_N; + hipMemcpyAsync(buf->d_s2_ntt, buf->d_s2, + (size_t)total_k * sizeof(coeff_t), + hipMemcpyDeviceToDevice, stream); + hipMemcpyAsync(buf->d_t0_ntt, buf->d_t0, + (size_t)total_k * sizeof(coeff_t), + hipMemcpyDeviceToDevice, stream); + launch_batch_keygen_shiftl_copy(buf->d_t1_hat, buf->d_t1, + total_k, PARAM_D, stream); + launch_batch_ntt(buf->d_s2_ntt, batch_count * PARAM_K, stream); + launch_batch_ntt(buf->d_t0_ntt, batch_count * PARAM_K, stream); + launch_batch_ntt(buf->d_t1_hat, batch_count * PARAM_K, stream); +} + +__global__ void batch_keygen_material_to_precomp_kernel( + precomp_t * __restrict__ pc, + const coeff_t * __restrict__ d_mat, + const coeff_t * __restrict__ d_s1_ntt, + const coeff_t * __restrict__ d_s2_ntt, + const coeff_t * __restrict__ d_t0_ntt, + const unsigned char * __restrict__ d_buf, + const unsigned char * __restrict__ d_tr, + int inst, + int mat_shared) +{ + if (threadIdx.x != 0 || blockIdx.x != 0) return; + + const size_t mat_base = mat_shared ? 
0 : + (size_t)inst * PARAM_K * PARAM_L * PARAM_N; + for (int k = 0; k < PARAM_K; k++) { + for (int l = 0; l < PARAM_L; l++) { + const coeff_t *src = d_mat + mat_base + + (size_t)(k * PARAM_L + l) * PARAM_N; + for (int c = 0; c < PARAM_N; c++) + pc->mat[k].vec[l].coeffs[c] = src[c]; + } + } + + const size_t s1_base = (size_t)inst * PARAM_L * PARAM_N; + for (int l = 0; l < PARAM_L; l++) { + const coeff_t *src = d_s1_ntt + s1_base + (size_t)l * PARAM_N; + for (int c = 0; c < PARAM_N; c++) + pc->s1_ntt.vec[l].coeffs[c] = src[c]; + } + + const size_t k_base = (size_t)inst * PARAM_K * PARAM_N; + for (int k = 0; k < PARAM_K; k++) { + const coeff_t *s2 = d_s2_ntt + k_base + (size_t)k * PARAM_N; + const coeff_t *t0 = d_t0_ntt + k_base + (size_t)k * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + pc->s2_ntt.vec[k].coeffs[c] = s2[c]; + pc->t0_ntt.vec[k].coeffs[c] = t0[c]; + } + } + + const unsigned char *my_buf = d_buf + (size_t)inst * (2 * SEEDBYTES + CRHBYTES); + for (int i = 0; i < SEEDBYTES; i++) pc->key[i] = my_buf[SEEDBYTES + i]; + const unsigned char *tr = d_tr + (size_t)inst * TRBYTES; + for (int i = 0; i < TRBYTES; i++) pc->tr[i] = tr[i]; +} + +__global__ void batch_keygen_material_to_verify_kernel( + coeff_t * __restrict__ d_vmat, + coeff_t * __restrict__ d_vt1_hat, + unsigned char * __restrict__ d_vtr, + const coeff_t * __restrict__ d_mat, + const coeff_t * __restrict__ d_t1_hat, + const unsigned char * __restrict__ d_tr, + int inst, + int mat_shared) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int mat_total = PARAM_K * PARAM_L * PARAM_N; + if (idx < mat_total) { + size_t src_idx = mat_shared ? 
(size_t)idx : + (size_t)inst * PARAM_K * PARAM_L * PARAM_N + idx; + d_vmat[idx] = d_mat[src_idx]; + } + const int t1_total = PARAM_K * PARAM_N; + if (idx < t1_total) { + d_vt1_hat[idx] = d_t1_hat[(size_t)inst * PARAM_K * PARAM_N + idx]; + } + if (idx < TRBYTES) { + d_vtr[idx] = d_tr[(size_t)inst * TRBYTES + idx]; + } +} + +/* ================================================================ + * Host API — 缓冲区分配/释放 + * ================================================================ */ + +static int batch_keygen_alloc(BatchKeygenBuffers *buf, int max_batch) { + memset(buf, 0, sizeof(*buf)); + buf->max_batch = max_batch; + size_t B = max_batch; + size_t N = PARAM_N; + +#define BKG_TRY(ptr, sz) do { \ + if (hipMalloc(&(ptr), (sz)) != hipSuccess) { hipGetLastError(); return -1; } \ +} while(0) + + BKG_TRY(buf->d_mat, B * PARAM_K * PARAM_L * N * sizeof(coeff_t)); + BKG_TRY(buf->d_s1, B * PARAM_L * N * sizeof(coeff_t)); + BKG_TRY(buf->d_s1hat, B * PARAM_L * N * sizeof(coeff_t)); + BKG_TRY(buf->d_s2, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_t, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_t1, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_t0, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_t1_hat, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_s2_ntt, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_t0_ntt, B * PARAM_K * N * sizeof(coeff_t)); + BKG_TRY(buf->d_tr, B * TRBYTES); + BKG_TRY(buf->d_pks, B * CRYPTO_PUBLICKEYBYTES); + BKG_TRY(buf->d_sks, B * CRYPTO_SECRETKEYBYTES); + BKG_TRY(buf->d_buf, B * (2 * SEEDBYTES + CRHBYTES)); + +#undef BKG_TRY + return 0; +} + +static void batch_keygen_free(BatchKeygenBuffers *buf) { + hipFree(buf->d_mat); hipFree(buf->d_s1); + hipFree(buf->d_s1hat); hipFree(buf->d_s2); + hipFree(buf->d_t); hipFree(buf->d_t1); + hipFree(buf->d_t0); hipFree(buf->d_t1_hat); + hipFree(buf->d_s2_ntt); hipFree(buf->d_t0_ntt); + hipFree(buf->d_tr); hipFree(buf->d_pks); + hipFree(buf->d_sks); hipFree(buf->d_buf); + 
memset(buf, 0, sizeof(*buf)); +} + +/* ================================================================ + * 批量密钥生成 pipeline (1 warp/instance 算子级并行采样) + * + * 采样阶段: 1 warp (32 线程) per instance 并行生成所有多项式 + * - Aigis-sig3: 30+5+6=41 个多项式, 约 2 轮 + * - ML-DSA-87: 56+7+8=71 个多项式, 约 3 轮 + * ================================================================ */ +static int batch_keygen_pipeline_warp( + unsigned char *d_pks, + unsigned char *d_sks, + const unsigned char *d_seeds, + BatchKeygenBuffers *buf, + int batch_count, + hipStream_t stream = 0, + int produce_material = 1) +{ + + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + const int N = PARAM_N; + + /* [1] 算子级并行采样: 1 warp per instance, 32× 并行 SHAKE 调用 */ + { + int nwarps = batch_count; /* 每 instance 1 warp */ + int nthreads = nwarps * WP_KG_WARP_SIZE; /* 总线程数 */ + int nblk = (nthreads + WP_KG_TPB - 1) / WP_KG_TPB; + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_s1, buf->d_s2, + buf->d_buf, d_seeds, batch_count); + } + + /* [2] copy s1 → s1hat */ + /* [3] NTT(s1hat) — shared-memory batch */ + hipMemcpyAsync(buf->d_s1hat, buf->d_s1, + (size_t)batch_count * PARAM_L * N * sizeof(coeff_t), + hipMemcpyDeviceToDevice, + stream); + launch_batch_ntt(buf->d_s1hat, batch_count * PARAM_L, stream); + + /* [4] 矩阵向量乘 */ + { + dim3 grid(batch_count, PARAM_K); + batch_keygen_matvec_kernel<<>>( + buf->d_t, buf->d_mat, buf->d_s1hat, batch_count); + } + + /* [5] reduce + INVNTT */ + launch_batch_reduce(buf->d_t, batch_count * PARAM_K * N, stream); + launch_batch_invntt(buf->d_t, batch_count * PARAM_K, stream); + + /* [6] t += s2 */ + launch_batch_add(buf->d_t, buf->d_t, buf->d_s2, + batch_count * PARAM_K * N, stream); + + /* [6.5] normalize */ +#if ALGORITHM == ALGO_MLDSA + launch_batch_caddq(buf->d_t, batch_count * PARAM_K * N, stream); +#elif ALGORITHM == ALGO_AIGIS + launch_batch_freeze_wide(buf->d_t, batch_count * PARAM_K * N, stream); +#endif + + /* [7] 打包 pk, sk */ + 
launch_batch_power2round(buf->d_t1, buf->d_t0, buf->d_t, + batch_count * PARAM_K * N, stream); + launch_batch_keygen_pack_standard(d_pks, d_sks, + buf->d_t1, buf->d_t0, + buf->d_s1, buf->d_s2, + buf->d_buf, buf->d_tr, + batch_count, stream); + if (produce_material) + batch_keygen_finalize_material(buf, batch_count, stream); + + return 0; +} + +static inline void launch_batch_keygen_sample_independent( + BatchKeygenBuffers *buf, + const unsigned char *d_seeds, + int batch_count, + KeygenProfile *profile, + hipEvent_t ev0, + hipEvent_t ev1, + hipStream_t stream) +{ + int nwarps = batch_count; + int nthreads = nwarps * WP_KG_WARP_SIZE; + int nblk = (nthreads + WP_KG_TPB - 1) / WP_KG_TPB; + +#if BATCH_KEYGEN_MATRIX_A_COOP || BATCH_KEYGEN_MATRIX_A_LANEOPT || \ + BATCH_KEYGEN_SECRET_ETA_COOP || BATCH_KEYGEN_MATRIX_A_FAST || \ + BATCH_KEYGEN_SECRET_ETA_FAST || \ + BATCH_KEYGEN_SAMPLE_SPLIT_FAST + if (profile) { + hipEvent_t sample_ev0 = NULL, sample_ev1 = NULL; + hipEventCreate(&sample_ev0); + hipEventCreate(&sample_ev1); + hipEventRecord(sample_ev0, stream); + + hipEventRecord(ev0, stream); + batch_keygen_seed_expand_kernel<<>>( + buf->d_buf, d_seeds, batch_count); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->seed_expand_ms, ev0, ev1); + + hipEventRecord(ev0, stream); + launch_batch_keygen_matrix_a_active( + buf->d_mat, buf->d_buf, batch_count, nblk, stream); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->matrix_a_sample_ms, ev0, ev1); + + hipEventRecord(ev0, stream); + launch_batch_keygen_secret_eta_active_independent( + buf->d_s1, buf->d_s2, buf->d_buf, batch_count, nblk, stream); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->secret_eta_sample_ms, ev0, ev1); + + hipEventRecord(sample_ev1, stream); + hipEventSynchronize(sample_ev1); + keygen_profile_add(&profile->sample_ms, sample_ev0, sample_ev1); + keygen_profile_finalize_sample( + 
profile, + profile->seed_expand_ms + profile->matrix_a_sample_ms + profile->secret_eta_sample_ms); +#if BATCH_KEYGEN_MATRIX_A_COOP + profile->matrix_a_coop_ms = profile->matrix_a_sample_ms; +#if BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP + profile->matrix_a_coop_lanes = BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP_LANES; +#else + profile->matrix_a_coop_lanes = 32; +#endif +#endif +#if BATCH_KEYGEN_SECRET_ETA_COOP + profile->secret_eta_coop_ms = profile->secret_eta_sample_ms; + profile->secret_eta_coop_lanes = BATCH_KEYGEN_SECRET_ETA_COOP_LANES; +#endif + + hipEventDestroy(sample_ev0); + hipEventDestroy(sample_ev1); + } else { + batch_keygen_seed_expand_kernel<<>>( + buf->d_buf, d_seeds, batch_count); + launch_batch_keygen_matrix_a_active( + buf->d_mat, buf->d_buf, batch_count, nblk, stream); + launch_batch_keygen_secret_eta_active_independent( + buf->d_s1, buf->d_s2, buf->d_buf, batch_count, nblk, stream); + } +#else + if (profile) hipEventRecord(ev0, stream); + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_s1, buf->d_s2, + buf->d_buf, d_seeds, batch_count); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->sample_ms, ev0, ev1); + keygen_profile_finalize_sample(profile, profile->sample_ms); + } +#endif +} + +static inline void launch_batch_keygen_sample_paper( + BatchKeygenBuffers *buf, + const unsigned char *d_seeds, + const unsigned char *d_shared_rho, + int batch_count, + KeygenProfile *profile, + hipEvent_t ev0, + hipEvent_t ev1, + hipStream_t stream) +{ + int nwarps = batch_count; + int nthreads = nwarps * WP_KG_WARP_SIZE; + int nblk = (nthreads + WP_KG_TPB - 1) / WP_KG_TPB; + +#if BATCH_KEYGEN_SECRET_ETA_COOP || BATCH_KEYGEN_SECRET_ETA_FAST || \ + BATCH_KEYGEN_SAMPLE_SPLIT_FAST + if (profile) { + hipEvent_t sample_ev0 = NULL, sample_ev1 = NULL; + hipEventCreate(&sample_ev0); + hipEventCreate(&sample_ev1); + hipEventRecord(sample_ev0, stream); + + hipEventRecord(ev0, stream); + 
batch_keygen_seed_expand_kernel<<>>( + buf->d_buf, d_seeds, batch_count); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->seed_expand_ms, ev0, ev1); + + hipEventRecord(ev0, stream); + launch_batch_keygen_paper_rho_active( + buf->d_buf, d_shared_rho, batch_count, stream); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->matrix_a_sample_ms, ev0, ev1); + + hipEventRecord(ev0, stream); + launch_batch_keygen_secret_eta_active_paper( + buf->d_s1, buf->d_s1hat, buf->d_s2, + buf->d_buf, d_shared_rho, batch_count, nblk, stream); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->secret_eta_sample_ms, ev0, ev1); + + hipEventRecord(sample_ev1, stream); + hipEventSynchronize(sample_ev1); + keygen_profile_add(&profile->sample_ms, sample_ev0, sample_ev1); + keygen_profile_finalize_sample( + profile, + profile->seed_expand_ms + profile->matrix_a_sample_ms + profile->secret_eta_sample_ms); +#if BATCH_KEYGEN_SECRET_ETA_COOP + profile->secret_eta_coop_ms = profile->secret_eta_sample_ms; + profile->secret_eta_coop_lanes = BATCH_KEYGEN_SECRET_ETA_COOP_LANES; +#endif + + hipEventDestroy(sample_ev0); + hipEventDestroy(sample_ev1); + } else { + batch_keygen_seed_expand_kernel<<>>( + buf->d_buf, d_seeds, batch_count); + launch_batch_keygen_paper_rho_active( + buf->d_buf, d_shared_rho, batch_count, stream); + launch_batch_keygen_secret_eta_active_paper( + buf->d_s1, buf->d_s1hat, buf->d_s2, + buf->d_buf, d_shared_rho, batch_count, nblk, stream); + } +#else + if (profile) hipEventRecord(ev0, stream); + batch_keygen_paper_secret_sample_kernel<<>>( + buf->d_s1, buf->d_s1hat, buf->d_s2, + buf->d_buf, d_seeds, d_shared_rho, batch_count); + if (profile) { + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->sample_ms, ev0, ev1); + keygen_profile_finalize_sample(profile, profile->sample_ms); + } +#endif +} + +static int 
batch_keygen_sample_only_independent( + BatchKeygenBuffers *buf, + const unsigned char *d_seeds, + int batch_count, + KeygenSampleOnlyProfile *profile, + hipStream_t stream = 0) +{ + if (!profile || batch_count <= 0 || batch_count > buf->max_batch) return -1; + + hipEvent_t ev0 = NULL, ev1 = NULL; + KeygenProfile active_profile; + int nwarps = batch_count; + int nthreads = nwarps * WP_KG_WARP_SIZE; + int nblk = (nthreads + WP_KG_TPB - 1) / WP_KG_TPB; + + keygen_sample_only_profile_clear(profile); + keygen_profile_clear(&active_profile); + hipEventCreate(&ev0); + hipEventCreate(&ev1); + + hipEventRecord(ev0, stream); + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_s1, buf->d_s2, + buf->d_buf, d_seeds, batch_count); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->old_fused_ms, ev0, ev1); + + launch_batch_keygen_sample_independent( + buf, d_seeds, batch_count, &active_profile, ev0, ev1, stream); + profile->split_seed_ms = active_profile.seed_expand_ms; + profile->split_matrix_a_ms = active_profile.matrix_a_sample_ms; + profile->split_eta_ms = active_profile.secret_eta_sample_ms; + profile->split_total_ms = active_profile.sample_ms; + profile->split_launch_gap_ms = active_profile.sample_launch_gap_ms; + profile->split_matrix_a_coop_ms = active_profile.matrix_a_coop_ms; + profile->split_eta_coop_ms = active_profile.secret_eta_coop_ms; + profile->split_matrix_a_coop_lanes = active_profile.matrix_a_coop_lanes; + profile->split_eta_coop_lanes = active_profile.secret_eta_coop_lanes; + + hipEventDestroy(ev0); + hipEventDestroy(ev1); + return 0; +} + +static int batch_keygen_sample_only_paper( + BatchKeygenBuffers *buf, + const unsigned char *d_seeds, + unsigned char *d_shared_rho, + int batch_count, + KeygenSampleOnlyProfile *profile, + hipStream_t stream = 0) +{ + if (!profile || batch_count <= 0 || batch_count > buf->max_batch) return -1; + + hipEvent_t ev0 = NULL, ev1 = NULL; + KeygenProfile active_profile; + int 
nwarps = batch_count; + int nthreads = nwarps * WP_KG_WARP_SIZE; + int nblk = (nthreads + WP_KG_TPB - 1) / WP_KG_TPB; + + keygen_sample_only_profile_clear(profile); + keygen_profile_clear(&active_profile); + hipEventCreate(&ev0); + hipEventCreate(&ev1); + + hipEventRecord(ev0, stream); + batch_keygen_paper_shared_a_kernel<<<1, WP_KG_TPB, 0, stream>>>( + buf->d_mat, d_shared_rho, d_seeds); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->shared_a_ms, ev0, ev1); + + hipEventRecord(ev0, stream); + batch_keygen_paper_secret_sample_kernel<<>>( + buf->d_s1, buf->d_s1hat, buf->d_s2, + buf->d_buf, d_seeds, d_shared_rho, batch_count); + hipEventRecord(ev1, stream); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->old_fused_ms, ev0, ev1); + + launch_batch_keygen_sample_paper( + buf, d_seeds, d_shared_rho, batch_count, + &active_profile, ev0, ev1, stream); + profile->split_seed_ms = active_profile.seed_expand_ms; + profile->split_matrix_a_ms = active_profile.matrix_a_sample_ms; + profile->split_eta_ms = active_profile.secret_eta_sample_ms; + profile->split_total_ms = active_profile.sample_ms; + profile->split_launch_gap_ms = active_profile.sample_launch_gap_ms; + profile->split_matrix_a_coop_ms = active_profile.matrix_a_coop_ms; + profile->split_eta_coop_ms = active_profile.secret_eta_coop_ms; + profile->split_matrix_a_coop_lanes = active_profile.matrix_a_coop_lanes; + profile->split_eta_coop_lanes = active_profile.secret_eta_coop_lanes; + + hipEventDestroy(ev0); + hipEventDestroy(ev1); + return 0; +} + +static int batch_keygen_pipeline_warp_opt( + unsigned char *d_pks, + unsigned char *d_sks, + const unsigned char *d_seeds, + BatchKeygenBuffers *buf, + int batch_count, + KeygenProfile *profile = NULL, + hipStream_t stream = 0, + int produce_material = 1) +{ + + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + const int N = PARAM_N; + hipEvent_t ev0 = NULL, ev1 = NULL; + if (profile) { + 
keygen_profile_clear(profile); + hipEventCreate(&ev0); + hipEventCreate(&ev1); + } + + launch_batch_keygen_sample_independent( + buf, d_seeds, batch_count, profile, ev0, ev1, stream); + + if (profile) hipEventRecord(ev0, stream); + hipMemcpyAsync(buf->d_s1hat, buf->d_s1, + (size_t)batch_count * PARAM_L * N * sizeof(coeff_t), + hipMemcpyDeviceToDevice, + stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->copy_ms, ev0, ev1); } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_ntt(buf->d_s1hat, batch_count * PARAM_L, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->ntt_ms, ev0, ev1); } + + { + dim3 grid(batch_count, PARAM_K); + if (profile) hipEventRecord(ev0, stream); + batch_keygen_matvec_kernel<<>>( + buf->d_t, buf->d_mat, buf->d_s1hat, batch_count); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->matvec_ms, ev0, ev1); } + } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_reduce(buf->d_t, batch_count * PARAM_K * N, stream); + launch_batch_invntt(buf->d_t, batch_count * PARAM_K, stream); + launch_batch_keygen_add_norm(buf->d_t, buf->d_s2, + batch_count * PARAM_K * N, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->post_ms, ev0, ev1); } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_power2round(buf->d_t1, buf->d_t0, buf->d_t, + batch_count * PARAM_K * N, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->p2r_ms, ev0, ev1); } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_keygen_pack_standard(d_pks, d_sks, + buf->d_t1, buf->d_t0, + buf->d_s1, buf->d_s2, + buf->d_buf, buf->d_tr, + batch_count, stream, profile); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->pack_ms, ev0, ev1); 
} + + if (produce_material) { + if (profile) hipEventRecord(ev0, stream); + batch_keygen_finalize_material(buf, batch_count, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->material_ms, ev0, ev1); } + } + + if (profile) { + hipEventDestroy(ev0); + hipEventDestroy(ev1); + } + return 0; +} + +static int batch_keygen_create_shared_rho_a( + BatchKeygenBuffers *buf, + unsigned char *d_shared_rho, + const unsigned char *d_base_seed, + KeygenProfile *profile = NULL) +{ + hipEvent_t ev0 = NULL, ev1 = NULL; + if (profile) { + hipEventCreate(&ev0); + hipEventCreate(&ev1); + hipEventRecord(ev0); + } + batch_keygen_paper_shared_a_kernel<<<1, WP_KG_TPB>>>( + buf->d_mat, d_shared_rho, d_base_seed); + if (profile) { + hipEventRecord(ev1); + hipEventSynchronize(ev1); + keygen_profile_add(&profile->shared_a_ms, ev0, ev1); + hipEventDestroy(ev0); + hipEventDestroy(ev1); + } + return 0; +} + +static int batch_keygen_pipeline_paper_shared_rho_a( + unsigned char *d_pks, + unsigned char *d_sks, + const unsigned char *d_seeds, + const unsigned char *d_shared_rho, + BatchKeygenBuffers *buf, + int batch_count, + KeygenProfile *profile = NULL, + hipStream_t stream = 0, + int produce_material = 1) +{ + + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + const int N = PARAM_N; + hipEvent_t ev0 = NULL, ev1 = NULL; + float shared_keep = profile ? 
profile->shared_a_ms : 0.0f; + if (profile) { + keygen_profile_clear(profile); + profile->shared_a_ms = shared_keep; + hipEventCreate(&ev0); + hipEventCreate(&ev1); + } + + launch_batch_keygen_sample_paper( + buf, d_seeds, d_shared_rho, batch_count, profile, ev0, ev1, stream); + + if (profile) hipEventRecord(ev0, stream); + launch_batch_ntt(buf->d_s1hat, batch_count * PARAM_L, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->ntt_ms, ev0, ev1); } + + { + dim3 grid(batch_count, PARAM_K); + if (profile) hipEventRecord(ev0, stream); + batch_keygen_matvec_shared_a_kernel<<>>( + buf->d_t, buf->d_mat, buf->d_s1hat, batch_count); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->matvec_ms, ev0, ev1); } + } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_reduce(buf->d_t, batch_count * PARAM_K * N, stream); + launch_batch_invntt(buf->d_t, batch_count * PARAM_K, stream); + launch_batch_keygen_add_norm(buf->d_t, buf->d_s2, + batch_count * PARAM_K * N, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->post_ms, ev0, ev1); } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_power2round(buf->d_t1, buf->d_t0, buf->d_t, + batch_count * PARAM_K * N, stream); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->p2r_ms, ev0, ev1); } + + if (profile) hipEventRecord(ev0, stream); + launch_batch_keygen_pack_standard(d_pks, d_sks, + buf->d_t1, buf->d_t0, + buf->d_s1, buf->d_s2, + buf->d_buf, buf->d_tr, + batch_count, stream, profile); + if (profile) { hipEventRecord(ev1, stream); hipEventSynchronize(ev1); keygen_profile_add(&profile->pack_ms, ev0, ev1); } + + if (produce_material) { + if (profile) hipEventRecord(ev0, stream); + batch_keygen_finalize_material(buf, batch_count, stream); + if (profile) { hipEventRecord(ev1, stream); 
hipEventSynchronize(ev1); keygen_profile_add(&profile->material_ms, ev0, ev1); } + } + + if (profile) { + hipEventDestroy(ev0); + hipEventDestroy(ev1); + } + return 0; +} + +static int batch_keygen_compare_device_buffer( + const void *d_ref, + const void *d_cand, + size_t total_bytes, + size_t inst_stride_bytes, + size_t elem_size, + KeygenCompareStage stage, + KeygenCompareResult *out) +{ + if (!d_ref || !d_cand || !out) return -1; + if (total_bytes == 0) return 0; + + unsigned char *h_ref = (unsigned char *)malloc(total_bytes); + unsigned char *h_cand = (unsigned char *)malloc(total_bytes); + if (!h_ref || !h_cand) { + free(h_ref); + free(h_cand); + return -1; + } + + hipError_t err = hipMemcpy(h_ref, d_ref, total_bytes, hipMemcpyDeviceToHost); + if (err == hipSuccess) + err = hipMemcpy(h_cand, d_cand, total_bytes, hipMemcpyDeviceToHost); + if (err != hipSuccess) { + free(h_ref); + free(h_cand); + return -1; + } + + for (size_t byte_idx = 0; byte_idx < total_bytes; ++byte_idx) { + if (h_ref[byte_idx] == h_cand[byte_idx]) continue; + + keygen_compare_result_clear(out); + out->stage = stage; + out->instance = inst_stride_bytes ? (int)(byte_idx / inst_stride_bytes) : 0; + out->byte_offset = inst_stride_bytes ? (byte_idx % inst_stride_bytes) : byte_idx; + out->element_offset = elem_size ? 
(out->byte_offset / elem_size) : out->byte_offset; + + if (elem_size == sizeof(coeff_t)) { + size_t coeff_idx = byte_idx / sizeof(coeff_t); + const coeff_t *ref_coeffs = (const coeff_t *)h_ref; + const coeff_t *cand_coeffs = (const coeff_t *)h_cand; + out->ref_value = ref_coeffs[coeff_idx]; + out->cand_value = cand_coeffs[coeff_idx]; + } else { + out->ref_value = (int64_t)h_ref[byte_idx]; + out->cand_value = (int64_t)h_cand[byte_idx]; + } + + free(h_ref); + free(h_cand); + return 1; + } + + free(h_ref); + free(h_cand); + return 0; +} + +static int batch_keygen_compare_active_path( + const unsigned char *d_seeds, + int batch_count, + int use_paper_shared_a, + int sample_only, + KeygenCompareResult *out, + hipStream_t stream = 0) +{ + (void)d_seeds; + (void)batch_count; + (void)use_paper_shared_a; + (void)sample_only; + (void)out; + (void)stream; + return -1; +} + + +#endif /* BATCH_KEYGEN_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ntt.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ntt.cuh new file mode 100644 index 000000000..c28b6f515 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ntt.cuh @@ -0,0 +1,285 @@ +#include "hip/hip_runtime.h" +/* + * batch_ntt.cuh — 共享内存批量 NTT/INVNTT + * + * 设计: 1 block (128 threads) 处理 1 个 256 系数多项式 + * 全部蝶形运算在 shared memory 中完成 + * 两种算法 (ML-DSA / Aigis) 共用同一 NTT 蝶形结构 + * (因为 new1 统一使用 int32_t, 蝶形 = a[j]±t) + * + * INVNTT: ML-DSA 使用 -ntt_zetas[--k] + * Aigis 使用 ntt_zetas_inv[k++] + */ + +#ifndef BATCH_NTT_CUH +#define BATCH_NTT_CUH + +#include +#include +#include "params.h" +#include "reduce.cuh" + +/* ================================================================ + * Montgomery multiply helper for batch kernels + * ================================================================ */ +static __device__ __forceinline__ coeff_t batch_fqmul_local(coeff_t a, coeff_t b) { + return 
montgomery_reduce((coeff2_t)a * b); +} + +/* ================================================================ + * 前向 NTT kernel — Cooley-Tukey, 8 stages, 128 threads + * + * 两种算法均使用有符号蝶形: + * t = mont(zeta * s[j+len]) + * s[j+len] = s[j] - t + * s[j] = s[j] + t + * ================================================================ */ +__global__ void batch_ntt_kernel(coeff_t *d_polys, int poly_count) { + int poly_idx = blockIdx.x; + if (poly_idx >= poly_count) return; + + int tid = threadIdx.x; /* 0..127 */ + /* +8 padding: 每 32 元素插 1 个填充字, 消除 stride=16/8/4/2 时的 bank conflict */ + __shared__ coeff_t s[PARAM_N + (PARAM_N >> 5)]; +#define SP(i) ((i) + ((i) >> 5)) + + coeff_t *base = d_polys + (size_t)poly_idx * PARAM_N; + s[SP(tid)] = base[tid]; + s[SP(tid + 128)] = base[tid + 128]; + __syncthreads(); + + /* Loop-based NTT: 8 stages (len=128,64,32,16,8,4,2,1) */ + unsigned int k = 0; + #pragma unroll + for (unsigned int len = 128; len >= 1; len >>= 1) { + unsigned int step = len << 1; + unsigned int block_id = tid / len; + unsigned int j = block_id * step + (tid % len); + + coeff_t zeta1 = ntt_zetas[k + 1 + block_id]; + coeff_t sj = s[SP(j)]; + coeff_t t1 = batch_fqmul_local(zeta1, s[SP(j + len)]); + s[SP(j + len)] = sj - t1; + s[SP(j)] = sj + t1; + + /* 128 threads handle N/2=128 butterflies per stage. + * For len<=64, each thread handles 2 butterflies (j and j2). 
*/ + if (len <= 64) { + unsigned int j2_block = (tid + 128) / len; + unsigned int j2 = j2_block * step + ((tid + 128) % len); + if (j2 + len < PARAM_N) { + coeff_t zeta2 = ntt_zetas[k + 1 + j2_block]; + coeff_t sj2 = s[SP(j2)]; + coeff_t t2 = batch_fqmul_local(zeta2, s[SP(j2 + len)]); + s[SP(j2 + len)] = sj2 - t2; + s[SP(j2)] = sj2 + t2; + } + } + + k += (PARAM_N / step); + if (len >= 64 || len == 1) __syncthreads(); else __syncwarp(); + } + + base[tid] = s[SP(tid)]; + base[tid + 128] = s[SP(tid + 128)]; +#undef SP +} + +/* ================================================================ + * 逆 NTT kernel — Gentleman-Sande, 8 stages, 128 threads + * + * ML-DSA: zeta = -ntt_zetas[k-1-block_id], 全 N 系数 *f + * Aigis: zeta = ntt_zetas_inv[k+block_id], 仅前 N/2 系数 *f + * ================================================================ */ +__global__ void batch_invntt_kernel(coeff_t *d_polys, int poly_count) { + int poly_idx = blockIdx.x; + if (poly_idx >= poly_count) return; + + int tid = threadIdx.x; + /* +8 padding: 每 32 元素插 1 个填充字, 消除 bank conflict */ + __shared__ coeff_t s[PARAM_N + (PARAM_N >> 5)]; +#define SP(i) ((i) + ((i) >> 5)) + + coeff_t *base = d_polys + (size_t)poly_idx * PARAM_N; + s[SP(tid)] = base[tid]; + s[SP(tid + 128)] = base[tid + 128]; + __syncthreads(); + +#if ALGORITHM == ALGO_MLDSA + { + unsigned int k = 256; + #pragma unroll + for (unsigned int len = 1; len <= 128; len <<= 1) { + unsigned int step = len << 1; + unsigned int block_id = tid / len; + unsigned int j = block_id * step + (tid % len); + + coeff_t zeta1 = -ntt_zetas[k - 1 - block_id]; + coeff_t t1 = s[SP(j)]; + coeff_t sjlen = s[SP(j + len)]; + s[SP(j)] = t1 + sjlen; + s[SP(j + len)] = batch_fqmul_local(zeta1, t1 - sjlen); + + if (len <= 64) { + unsigned int j2_block = (tid + 128) / len; + unsigned int j2 = j2_block * step + ((tid + 128) % len); + if (j2 + len < PARAM_N) { + coeff_t zeta2 = -ntt_zetas[k - 1 - j2_block]; + coeff_t t2 = s[SP(j2)]; + coeff_t sj2len = s[SP(j2 + len)]; + 
s[SP(j2)] = t2 + sj2len; + s[SP(j2 + len)] = batch_fqmul_local(zeta2, t2 - sj2len); + } + } + + k -= (PARAM_N / step); + if (len >= 32) __syncthreads(); else __syncwarp(); + } + } + + /* Scale by N^{-1} * MONT */ + { + const coeff_t f = INTT_F; + s[SP(tid)] = batch_fqmul_local(f, s[SP(tid)]); + s[SP(tid + 128)] = batch_fqmul_local(f, s[SP(tid + 128)]); + } + __syncthreads(); + +#elif ALGORITHM == ALGO_AIGIS + { + unsigned int ki = 0; + #pragma unroll + for (unsigned int len = 1; len <= 128; len <<= 1) { + unsigned int step = len << 1; + unsigned int num_blocks = PARAM_N / step; + + /* Each of 128 threads handles one butterfly */ + if (tid < PARAM_N / 2) { + unsigned int blk_id = tid / len; + unsigned int pos_ = tid % len; + unsigned int j = blk_id * step + pos_; + coeff_t zeta = ntt_zetas_inv[ki + blk_id]; + coeff_t t = s[SP(j)]; + coeff_t sjlen = s[SP(j + len)]; + s[SP(j)] = t + sjlen; + s[SP(j + len)] = batch_fqmul_local(zeta, t - sjlen); + } + + ki += num_blocks; + if (len >= 32) __syncthreads(); else __syncwarp(); + } + } + + /* Scale: only first N/2 coefficients */ + { + const coeff_t f = INTT_F; + s[SP(tid)] = batch_fqmul_local(f, s[SP(tid)]); + /* s[tid+128] untouched (last-stage twiddle has N^{-1} baked in) */ + } + __syncthreads(); +#endif + + base[tid] = s[SP(tid)]; + base[tid + 128] = s[SP(tid + 128)]; +#undef SP +} + +/* ================================================================ + * Host launch wrappers + * ================================================================ */ +static inline void launch_batch_ntt(coeff_t *d_polys, int count, hipStream_t stream = 0) { + if (count <= 0) return; + batch_ntt_kernel<<>>(d_polys, count); +} + +static inline void launch_batch_invntt(coeff_t *d_polys, int count, hipStream_t stream = 0) { + if (count <= 0) return; + batch_invntt_kernel<<>>(d_polys, count); +} + +/* ================================================================ + * 算子级并行 NTT — 仿照「合并的第五版」pqc_ntt_par 思路 + * + * 设计: 1 warp (32 线程) 协作完成 1 个多项式的 
NTT/INVNTT + * shared memory 由调用者提供: smem[PARAM_N] + * 每个 stage 内 32 线程各处理 N/2/32 = 4 个 butterfly + * stage 间用 __syncwarp() 同步(每个 warp 管理独立的 N 系数块) + * + * 用途: 适合 warp-level 并行的 sign/keygen 内核 + * (e.g. kernel_batch_sign_warp 中每 warp 完成一次签名) + * + * 参数: + * r — 输入/输出系数数组 [PARAM_N],在 smem 中 + * lane — warp 内 lane id (0..31) + * ================================================================ */ +static __device__ __forceinline__ void ntt_warp_par(coeff_t *r, int lane) { + /* 前向 NTT: 8 个 stage, 32 线程各处理 N/2/32 = 4 个 butterfly */ + unsigned int k = 0; + #pragma unroll + for (unsigned int len = 128; len >= 1; len >>= 1) { + unsigned int step = len << 1; + /* 每线程负责 4 个 butterfly, 均匀分配 128 个 butterfly 给 32 线程 */ + for (int b = lane; b < PARAM_N / 2; b += 32) { + unsigned int blk = b / len; + unsigned int pos = b % len; + unsigned int j = blk * step + pos; + coeff_t zeta = ntt_zetas[k + 1 + blk]; + coeff_t t = montgomery_reduce((coeff2_t)zeta * r[j + len]); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } + k += (PARAM_N / step); + __syncwarp(); + } +} + +/* 逆 NTT warp-parallel 版本 */ +static __device__ __forceinline__ void invntt_warp_par(coeff_t *r, int lane) { +#if ALGORITHM == ALGO_MLDSA + unsigned int k = 256; + #pragma unroll + for (unsigned int len = 1; len <= 128; len <<= 1) { + unsigned int step = len << 1; + for (int b = lane; b < PARAM_N / 2; b += 32) { + unsigned int blk = b / len; + unsigned int pos = b % len; + unsigned int j = blk * step + pos; + coeff_t zeta = -ntt_zetas[k - 1 - blk]; + coeff_t t = r[j]; + r[j] = t + r[j + len]; + r[j + len] = montgomery_reduce((coeff2_t)zeta * (t - r[j + len])); + } + k -= (PARAM_N / step); + __syncwarp(); + } + /* Scale by N^{-1} * MONT */ + for (int i = lane; i < PARAM_N; i += 32) + r[i] = montgomery_reduce((coeff2_t)INTT_F * r[i]); + __syncwarp(); + +#elif ALGORITHM == ALGO_AIGIS + unsigned int ki = 0; + #pragma unroll + for (unsigned int len = 1; len <= 128; len <<= 1) { + unsigned int step = len << 1; + unsigned int 
num_blocks = PARAM_N / step; + for (int b = lane; b < PARAM_N / 2; b += 32) { + unsigned int blk = b / len; + unsigned int pos = b % len; + unsigned int j = blk * step + pos; + coeff_t zeta = ntt_zetas_inv[ki + blk]; + coeff_t t = r[j]; + r[j] = t + r[j + len]; + r[j + len] = montgomery_reduce((coeff2_t)zeta * (t - r[j + len])); + } + ki += num_blocks; + __syncwarp(); + } + /* Scale: only first N/2 coefficients (Aigis: last-stage twiddle bakes N^{-1} for upper half) */ + for (int i = lane; i < PARAM_N / 2; i += 32) + r[i] = montgomery_reduce((coeff2_t)INTT_F * r[i]); + __syncwarp(); +#endif +} + +#endif /* BATCH_NTT_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ops.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ops.cuh new file mode 100644 index 000000000..b8cf012a9 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_ops.cuh @@ -0,0 +1,207 @@ +#include "hip/hip_runtime.h" +/* + * batch_ops.cuh — 统一的批量多项式算术 kernel + * + * 逐系数操作: 每个线程处理一个系数, 256 threads/block. + * 通过 coeff_t / coeff_* 包装函数实现算法无关. 
+ */ + +#ifndef BATCH_OPS_CUH +#define BATCH_OPS_CUH + +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "rounding.cuh" + +#define BATCH_TPB 256 + +/* ================================================================ + * 共用 kernel — 两种算法一份代码 + * ================================================================ */ + +__global__ void batch_poly_add_kernel(coeff_t *c, const coeff_t *a, + const coeff_t *b, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) c[idx] = a[idx] + b[idx]; +} + +__global__ void batch_poly_sub_kernel(coeff_t *c, const coeff_t *a, + const coeff_t *b, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) c[idx] = coeff_sub(a[idx], b[idx]); +} + +__global__ void batch_poly_pointwise_kernel(coeff_t *c, const coeff_t *a, + const coeff_t *b, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) c[idx] = coeff_fqmul(a[idx], b[idx]); +} + +__global__ void batch_poly_reduce_kernel(coeff_t *a, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) a[idx] = coeff_reduce(a[idx]); +} + +__global__ void batch_poly_normalize_kernel(coeff_t *a, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) a[idx] = coeff_normalize(a[idx]); +} + +__global__ void batch_poly_freeze_wide_kernel(coeff_t *a, int total) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) a[idx] = coeff_freeze_wide(a[idx]); +} + +__global__ void batch_poly_shiftl_kernel(coeff_t *a, int total, unsigned int k) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total) a[idx] <<= k; +} + +/* ================================================================ + * power2round kernel — 算法差异已在 rounding.cuh 中封装 + * ================================================================ */ +__global__ void batch_power2round_kernel(coeff_t *d_a1, coeff_t *d_a0, + const coeff_t *d_a, int total_coeffs) { + int idx = 
blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_coeffs) return; + +#if ALGORITHM == ALGO_MLDSA + coeff_t val = d_a[idx]; + /* ML-DSA caddq before power2round */ + val += (val >> 31) & PARAM_Q; + int32_t a0_val; + d_a1[idx] = power2round(&a0_val, val); + d_a0[idx] = a0_val; +#elif ALGORITHM == ALGO_AIGIS + int32_t a0_val; + d_a1[idx] = power2round(&a0_val, d_a[idx]); + d_a0[idx] = a0_val; +#endif +} + +/* ================================================================ + * use_hint kernel — 用于 verify pipeline + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +__global__ void batch_use_hint_kernel(coeff_t * __restrict__ d_out, + const coeff_t * __restrict__ d_a, + const coeff_t * __restrict__ d_hint, + int total_coeffs) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_coeffs) return; + + int32_t a = d_a[idx]; + int32_t hint = d_hint[idx]; + + int32_t a1; + a1 = (a + 127) >> 7; +#if PARAM_GAMMA2 == ((PARAM_Q-1)/32) + a1 = (a1*1025 + (1 << 21)) >> 22; + a1 &= 15; +#elif PARAM_GAMMA2 == ((PARAM_Q-1)/88) + a1 = (a1*11275 + (1 << 23)) >> 24; + a1 ^= ((43 - a1) >> 31) & a1; +#endif + int32_t a0 = a - a1 * 2 * PARAM_GAMMA2; + a0 -= (((PARAM_Q-1)/2 - a0) >> 31) & PARAM_Q; + + if (hint == 0) { d_out[idx] = a1; return; } + +#if PARAM_GAMMA2 == ((PARAM_Q-1)/32) + if (a0 > 0) d_out[idx] = (a1 + 1) & 15; + else d_out[idx] = (a1 - 1) & 15; +#elif PARAM_GAMMA2 == ((PARAM_Q-1)/88) + if (a0 > 0) d_out[idx] = (a1 == 43) ? 0 : a1 + 1; + else d_out[idx] = (a1 == 0) ? 
43 : a1 - 1; +#endif +} + +#elif ALGORITHM == ALGO_AIGIS + +__global__ void batch_use_hint_kernel(coeff_t * __restrict__ d_out, + const coeff_t * __restrict__ d_a, + const coeff_t * __restrict__ d_hint, + int total_coeffs) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_coeffs) return; + + int32_t hint = d_hint[idx]; + int32_t a = d_a[idx]; + + /* Aigis use_hint: call rounding.cuh use_hint directly */ + d_out[idx] = use_hint(a, hint); +} + +#endif /* ALGORITHM */ + +/* ================================================================ + * Host launch wrappers + * ================================================================ */ + +static inline void launch_batch_add(coeff_t *c, const coeff_t *a, + const coeff_t *b, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_poly_add_kernel<<>>(c, a, b, total_coeffs); +} + +static inline void launch_batch_sub(coeff_t *c, const coeff_t *a, + const coeff_t *b, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_poly_sub_kernel<<>>(c, a, b, total_coeffs); +} + +static inline void launch_batch_reduce(coeff_t *a, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_poly_reduce_kernel<<>>(a, total_coeffs); +} + +static inline void launch_batch_normalize(coeff_t *a, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_poly_normalize_kernel<<>>(a, total_coeffs); +} + +static inline void launch_batch_freeze_wide(coeff_t *a, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_poly_freeze_wide_kernel<<>>(a, total_coeffs); +} + +static inline void launch_batch_shiftl(coeff_t *a, int total_coeffs, + unsigned int k, hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + 
batch_poly_shiftl_kernel<<>>(a, total_coeffs, k); +} + +static inline void launch_batch_power2round(coeff_t *v1, coeff_t *v0, + const coeff_t *v, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_power2round_kernel<<>>(v1, v0, v, total_coeffs); +} + +static inline void launch_batch_use_hint(coeff_t *w, const coeff_t *u, + const coeff_t *h, int total_coeffs, + hipStream_t stream = 0) { + int nblk = (total_coeffs + BATCH_TPB - 1) / BATCH_TPB; + batch_use_hint_kernel<<>>(w, u, h, total_coeffs); +} + +/* 别名: freeze2q / caddq → normalize, freeze4q → freeze_wide */ +static inline void launch_batch_freeze2q(coeff_t *a, int poly_count, + hipStream_t stream = 0) { + launch_batch_normalize(a, poly_count * PARAM_N, stream); +} + +static inline void launch_batch_caddq(coeff_t *a, int total_coeffs, + hipStream_t stream = 0) { + launch_batch_normalize(a, total_coeffs, stream); +} + +#endif /* BATCH_OPS_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign.cuh new file mode 100644 index 000000000..918bb81c4 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign.cuh @@ -0,0 +1,900 @@ +#include "hip/hip_runtime.h" +/* + * batch_sign.cuh — 分解式批量签名 pipeline (算子级并行) + * + * 核心优化思路: + * 1. y 采样 (expand_gamma1) 全批次并行 — 每实例独立 SHAKE 流 + * 2. NTT(y) 使用 shared-memory 批量 kernel (128 线程/poly) + * 3. w = A·y_hat 使用共享矩阵 2D grid matvec (复用 verify matvec) + * 4. INVNTT + 归一化 + 分解 + 哈希挑战 — 专用批量 kernel + * 5. z/cs2/ct0 计算: 挑战多项式 × 共享向量 → 批量 pointwise + INVNTT + * + * Rejection loop: 每轮全批次并行; 已完成实例通过 d_done 标志跳过. 
+ * + * 所有实例共用: + * mat (矩阵 A, NTT域), s1_ntt, s2_ntt, t0_ntt, mu, rhoprime/key_mu + * Per-instance 差异化: + * y — 通过 per-instance 初始 nonce (inst 号) 保证唯一性 + */ + +#ifndef BATCH_SIGN_CUH +#define BATCH_SIGN_CUH + +#include +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "ntt.cuh" +#include "rounding.cuh" +#include "fips202.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "packing.cuh" +#include "sign.cuh" +#include "symmetric.cuh" + +/* 最大拒绝轮数 — ML-DSA-44 每轮接受率≈23%,P(>50轮)≈0.765^50≈1.5e-6/inst + * B=4096 时期望失败≈0.006,基本为零;Aigis/ML-DSA其他参数类似 */ +#define MAX_SIGN_ROUNDS 200 + +#ifndef BATCH_SIGN_DECOMP_SYNC_EACH_ROUND +#define BATCH_SIGN_DECOMP_SYNC_EACH_ROUND 0 +#endif + +#ifndef BATCH_SIGN_DECOMP_CHECK_INTERVAL +#define BATCH_SIGN_DECOMP_CHECK_INTERVAL 4 +#endif + +#ifndef BATCH_SIGN_SAMPLE_TPB +#define BATCH_SIGN_SAMPLE_TPB 64 +#endif + +#ifndef BATCH_SIGN_HASH_TPB +#define BATCH_SIGN_HASH_TPB 32 +#endif + +#ifndef BATCH_SIGN_CHECK_TPB +#define BATCH_SIGN_CHECK_TPB 32 +#endif + +#ifndef BATCH_SIGN_DECOMP_TAIL_ENABLE +#define BATCH_SIGN_DECOMP_TAIL_ENABLE 0 +#endif + +#ifndef BATCH_SIGN_DECOMP_TAIL_AFTER +#define BATCH_SIGN_DECOMP_TAIL_AFTER 24 +#endif + +#ifndef BATCH_SIGN_DECOMP_TAIL_PENDING_DIV +#define BATCH_SIGN_DECOMP_TAIL_PENDING_DIV 128 +#endif + +#ifndef BATCH_SIGN_DECOMP_TAIL_PENDING_MIN +#define BATCH_SIGN_DECOMP_TAIL_PENDING_MIN 16 +#endif + +#ifndef BATCH_SIGN_SAMPLE_DUP_YHAT +#define BATCH_SIGN_SAMPLE_DUP_YHAT 0 +#endif + +#ifndef BATCH_SIGN_CP_FUSE_ENABLE +#define BATCH_SIGN_CP_FUSE_ENABLE 0 +#endif + +/* ================================================================ + * 共享材料 (来自 precomp_t, 一次提取, 所有实例共用) + * ================================================================ */ +struct BatchSignShared { + coeff_t *d_mat; /* K * L * N — 矩阵 A (NTT 域, SoA, 无 batch 维) */ + coeff_t *d_s1_ntt; /* L * N */ + coeff_t *d_s2_ntt; /* K * N */ + coeff_t *d_t0_ntt; /* K * N */ + 
uint8_t *d_mu; /* CRHBYTES */ + uint8_t *d_rhoprime; /* ML-DSA: CRHBYTES / Aigis: SEEDBYTES+CRHBYTES */ +}; + +/* ================================================================ + * per-instance 工作缓冲区 + * ================================================================ */ +struct BatchSignPipeline { + BatchSignShared sh; + + coeff_t *d_y; /* L * B * N SoA — y (系数域, 全程保留) */ + coeff_t *d_y_hat; /* L * B * N SoA — NTT(y) (供 matvec 使用, 之后可覆盖) */ + coeff_t *d_w; /* K * B * N SoA — A·y 归一化后 (Aigis 保留供 wcs2) */ + coeff_t *d_w0; /* K * B * N SoA — decompose 低位 */ + coeff_t *d_w1; /* K * B * N SoA — decompose 高位 */ + coeff_t *d_cp; /* B * N SoA — 挑战多项式 (NTT 域) */ + coeff_t *d_z; /* L * B * N SoA — INVNTT(cp·s1)+y */ + coeff_t *d_cs2; /* K * B * N SoA — INVNTT(cp·s2) */ + coeff_t *d_ct0; /* K * B * N SoA — INVNTT(cp·t0) */ + uint8_t *d_cbuf; /* B*CTILDEBYTES (ML-DSA) or B*N*4 (Aigis cp poly) */ + uint16_t *d_nonces; /* B — per-instance 当前 nonce */ + int *d_done; /* B — 0=pending, 1=done */ + int *d_done_count; /* B completed signatures */ + uint8_t *d_sigs; /* B * CRYPTO_BYTES — 输出签名 (AoS) */ + + int max_batch; +}; + +struct BatchSignRuntimeOptions { + int cp_fuse_enable; + int check_interval; + int hash_tpb; + int check_tpb; +}; + +static BatchSignRuntimeOptions batch_sign_default_runtime_options(void) { + BatchSignRuntimeOptions opt; + opt.cp_fuse_enable = BATCH_SIGN_CP_FUSE_ENABLE; + opt.check_interval = BATCH_SIGN_DECOMP_CHECK_INTERVAL; + opt.hash_tpb = BATCH_SIGN_HASH_TPB; + opt.check_tpb = BATCH_SIGN_CHECK_TPB; + return opt; +} + +/* ================================================================ + * [0a] 共享材料提取 — precomp_t AoS → flat SoA + * ================================================================ */ +__global__ void batch_sign_setup_kernel( + coeff_t *d_mat_flat, + coeff_t *d_s1_flat, + coeff_t *d_s2_flat, + coeff_t *d_t0_flat, + const precomp_t *d_pc) +{ + for (int k = 0; k < PARAM_K; k++) + for (int l = 0; l < PARAM_L; l++) + for (int c = 0; c < PARAM_N; 
c++) + d_mat_flat[(k * PARAM_L + l) * PARAM_N + c] + = d_pc->mat[k].vec[l].coeffs[c]; + for (int l = 0; l < PARAM_L; l++) + for (int c = 0; c < PARAM_N; c++) + d_s1_flat[l * PARAM_N + c] = d_pc->s1_ntt.vec[l].coeffs[c]; + for (int k = 0; k < PARAM_K; k++) + for (int c = 0; c < PARAM_N; c++) { + d_s2_flat[k * PARAM_N + c] = d_pc->s2_ntt.vec[k].coeffs[c]; + d_t0_flat[k * PARAM_N + c] = d_pc->t0_ntt.vec[k].coeffs[c]; + } +} + +/* ================================================================ + * [0b] 计算 mu 和 rhoprime (单线程) + * ================================================================ */ +__global__ void batch_sign_compute_mu_rhoprime_kernel( + uint8_t *d_mu, + uint8_t *d_rhoprime, + const precomp_t *d_pc, + const uint8_t *d_msg, + size_t mlen, + const uint8_t *d_pre, + size_t prelen, + const uint8_t *d_rnd) +{ + keccak_state state; +#if ALGORITHM == ALGO_MLDSA + shake256_init(&state); + shake256_absorb(&state, d_pc->tr, TRBYTES); + shake256_absorb(&state, d_pre, prelen); + shake256_absorb(&state, d_msg, mlen); + shake256_finalize(&state); + shake256_squeeze(d_mu, CRHBYTES, &state); + + shake256_init(&state); + shake256_absorb(&state, d_pc->key, SEEDBYTES); +#if RNDBYTES > 0 + shake256_absorb(&state, d_rnd, RNDBYTES); +#endif + shake256_absorb(&state, d_mu, CRHBYTES); + shake256_finalize(&state); + shake256_squeeze(d_rhoprime, CRHBYTES, &state); + +#elif ALGORITHM == ALGO_AIGIS + shake256_init(&state); + shake256_absorb(&state, d_pc->tr, TRBYTES); + shake256_absorb(&state, d_msg, mlen); + shake256_finalize(&state); + shake256_squeeze(d_mu, CRHBYTES, &state); + /* key_mu = key || mu */ + for (int i = 0; i < SEEDBYTES; i++) d_rhoprime[i] = d_pc->key[i]; + for (int i = 0; i < CRHBYTES; i++) d_rhoprime[SEEDBYTES+i] = d_mu[i]; +#endif + (void)d_rnd; (void)prelen; +} + +/* ================================================================ + * [0c] 初始化 nonce / done + * ================================================================ */ +__global__ void 
batch_sign_init_kernel(uint16_t *d_nonces, int *d_done, int B) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= B) return; +#if ALGORITHM == ALGO_MLDSA + d_nonces[i] = (uint16_t)i; +#else + d_nonces[i] = (uint16_t)((unsigned)i * PARAM_L); +#endif + d_done[i] = 0; +} + +/* ================================================================ + * [1] 采样 y kernel — 并行 expand_gamma1 + * 每线程: 一个 pending 实例 → L 个多项式 + * ================================================================ */ +__global__ void __launch_bounds__(BATCH_SIGN_SAMPLE_TPB) +batch_sign_sample_y_kernel( + coeff_t *d_y, +#if BATCH_SIGN_SAMPLE_DUP_YHAT + coeff_t *d_y_hat, +#endif + uint16_t *d_nonces, + const int *d_done, + const uint8_t *d_rhoprime, + int B) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= B) return; +#if BATCH_SIGN_SAMPLE_DUP_YHAT + if (d_done[inst]) { + for (int l = 0; l < PARAM_L; l++) { + const coeff_t *src = d_y + (size_t)l * B * PARAM_N + (size_t)inst * PARAM_N; + coeff_t *dst_hat = d_y_hat + (size_t)l * B * PARAM_N + (size_t)inst * PARAM_N; + for (int c = 0; c < PARAM_N; c++) dst_hat[c] = src[c]; + } + return; + } +#else + if (d_done[inst]) return; +#endif + + uint16_t base = d_nonces[inst]; +#if ALGORITHM == ALGO_MLDSA + d_nonces[inst] = (uint16_t)(base + 1); +#else + d_nonces[inst] = (uint16_t)(base + (uint16_t)PARAM_L); +#endif + for (int l = 0; l < PARAM_L; l++) { + poly tmp; + poly_uniform_gamma1(&tmp, d_rhoprime, GAMMA1_NONCE(base, l)); + coeff_t *dst = d_y + (size_t)l * B * PARAM_N + (size_t)inst * PARAM_N; +#if BATCH_SIGN_SAMPLE_DUP_YHAT + coeff_t *dst_hat = d_y_hat + (size_t)l * B * PARAM_N + (size_t)inst * PARAM_N; + for (int c = 0; c < PARAM_N; c++) { + coeff_t v = tmp.coeffs[c]; + dst[c] = v; + dst_hat[c] = v; + } +#else + for (int c = 0; c < PARAM_N; c++) dst[c] = tmp.coeffs[c]; +#endif + } +} + +/* ================================================================ + * [6] 分解 w → (w1, w0) — per-coeff batch kernel + * 
================================================================ */ +__global__ void batch_sign_decompose_kernel( + coeff_t *d_w1, + coeff_t *d_w0, + const coeff_t *d_w_in, + int total) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + int32_t a0; + int32_t a1 = decompose(&a0, d_w_in[idx]); + d_w1[idx] = a1; + d_w0[idx] = a0; +} + +/* ================================================================ + * [7] 哈希挑战 — H(mu || pack(w1)) → cp (并存 c_seed / c_poly 到 d_cbuf) + * ================================================================ */ +__global__ void __launch_bounds__(BATCH_SIGN_HASH_TPB) +batch_sign_hash_cp_kernel( + coeff_t *d_cp, + uint8_t *d_cbuf, + const uint8_t *d_mu, + const coeff_t *d_w1, + const int *d_done, + int B) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= B || d_done[inst]) return; + + /* 打包 w1 */ + uint8_t w1_packed[PARAM_K * POLYW1_PACKEDBYTES]; + for (int ki = 0; ki < PARAM_K; ki++) { + const coeff_t *w1k = d_w1 + (size_t)ki * B * PARAM_N + (size_t)inst * PARAM_N; + uint8_t *r = w1_packed + ki * POLYW1_PACKEDBYTES; +#if ALGORITHM == ALGO_MLDSA + #if PARAM_GAMMA2 == ((PARAM_Q-1)/88) + for (unsigned i = 0; i < PARAM_N/4; i++) { + r[3*i+0] = (uint8_t)(w1k[4*i+0]); + r[3*i+0] |= (uint8_t)(w1k[4*i+1] << 6); + r[3*i+1] = (uint8_t)(w1k[4*i+1] >> 2); + r[3*i+1] |= (uint8_t)(w1k[4*i+2] << 4); + r[3*i+2] = (uint8_t)(w1k[4*i+2] >> 4); + r[3*i+2] |= (uint8_t)(w1k[4*i+3] << 2); + } + #else /* GAMMA2 = (Q-1)/32 */ + for (unsigned i = 0; i < PARAM_N/2; i++) + r[i] = (uint8_t)(w1k[2*i+0] | (w1k[2*i+1] << 4)); + #endif +#elif ALGORITHM == ALGO_AIGIS + for (unsigned i = 0; i < PARAM_N/8; i++) { + r[3*i+0] = (uint8_t)(w1k[8*i+0] | (w1k[8*i+1] << 3) | (w1k[8*i+2] << 6)); + r[3*i+1] = (uint8_t)((w1k[8*i+2] >> 2) | (w1k[8*i+3] << 1) + | (w1k[8*i+4] << 4) | (w1k[8*i+5] << 7)); + r[3*i+2] = (uint8_t)((w1k[8*i+5] >> 1) | (w1k[8*i+6] << 2) | (w1k[8*i+7] << 5)); + } +#endif + } + + coeff_t *cp = d_cp + (size_t)inst 
* PARAM_N; + +#if ALGORITHM == ALGO_MLDSA + { + keccak_state st; + shake256_init(&st); + shake256_absorb(&st, d_mu, CRHBYTES); + shake256_absorb(&st, w1_packed, PARAM_K * POLYW1_PACKEDBYTES); + shake256_finalize(&st); + uint8_t c_seed[CTILDEBYTES]; + shake256_squeeze(c_seed, CTILDEBYTES, &st); + /* 保存 c_seed */ + uint8_t *cbuf = d_cbuf + (size_t)inst * CTILDEBYTES; + for (int i = 0; i < CTILDEBYTES; i++) cbuf[i] = c_seed[i]; + /* SampleInBall → cp (与 batch_verify_challenge_kernel 逻辑相同) */ + uint8_t buf2[SHAKE256_RATE]; + keccak_state st2; + shake256_init(&st2); + shake256_absorb(&st2, c_seed, CTILDEBYTES); + shake256_finalize(&st2); + shake256_squeezeblocks(buf2, 1, &st2); + uint64_t signs = 0; + for (int i = 0; i < 8; i++) signs |= (uint64_t)buf2[i] << (8*i); + unsigned int pos = 8; + for (int i = 0; i < PARAM_N; i++) cp[i] = 0; + for (int i = PARAM_N - PARAM_TAU; i < PARAM_N; i++) { + unsigned int b; + do { + if (pos >= SHAKE256_RATE) { shake256_squeezeblocks(buf2, 1, &st2); pos = 0; } + b = buf2[pos++]; + } while (b > i); + cp[i] = cp[b]; + cp[b] = 1 - 2*(int)(signs & 1); + signs >>= 1; + } + } +#elif ALGORITHM == ALGO_AIGIS + { + poly c_tmp; + poly_challenge(&c_tmp, d_mu, w1_packed, PARAM_K * POLYW1_PACKEDBYTES); + for (int i = 0; i < PARAM_N; i++) cp[i] = c_tmp.coeffs[i]; + /* 保存 cp 供打包签名 */ + coeff_t *cbuf_cp = (coeff_t *)(d_cbuf + (size_t)inst * PARAM_N * sizeof(coeff_t)); + for (int i = 0; i < PARAM_N; i++) cbuf_cp[i] = c_tmp.coeffs[i]; + } +#endif +} + +/* ================================================================ + * [9a] cp_ntt × shared_vec → out (2D grid: (B, poly_count) × N threads) + * out[poly][inst][tid] = mont(cp[inst][tid] * shared[poly][tid]) + * ================================================================ */ +__global__ void batch_sign_pointwise_cp_shared_kernel( + coeff_t *d_out, + const coeff_t *d_cp, + const coeff_t *d_shared, + int poly_count, + int B) +{ + int inst = blockIdx.x; + int poly = blockIdx.y; + int tid = threadIdx.x; + 
if (inst >= B || poly >= poly_count) return;
+    coeff_t c = d_cp[(size_t)inst * PARAM_N + tid];
+    coeff_t s = d_shared[(size_t)poly * PARAM_N + tid];
+    d_out[(size_t)poly * B * PARAM_N + (size_t)inst * PARAM_N + tid]
+        = (coeff_t)montgomery_reduce((coeff2_t)c * s);
+}
+
+/*
+ * Fused pointwise multiply of the challenge by the shared s1 / s2 / t0
+ * polynomials.
+ *
+ * Indexing: inst = blockIdx.x, poly = blockIdx.y, coefficient = threadIdx.x.
+ * Each thread multiplies d_cp[inst][tid] by the matching shared coefficient
+ * (montgomery_reduce of the product) and stores it at [poly][inst][tid]:
+ *   - d_z   when poly < PARAM_L  (cp * s1)
+ *   - d_cs2 and d_ct0 when poly < PARAM_K  (cp * s2, cp * t0)
+ * Threads with inst >= B or tid >= PARAM_N do nothing.
+ */
+__global__ void __launch_bounds__(256)
+batch_sign_pointwise_cp_all_shared_kernel(
+    coeff_t *d_z,
+    coeff_t *d_cs2,
+    coeff_t *d_ct0,
+    const coeff_t *d_cp,
+    const coeff_t *d_s1_ntt,
+    const coeff_t *d_s2_ntt,
+    const coeff_t *d_t0_ntt,
+    int B)
+{
+    int inst = blockIdx.x;
+    int poly = blockIdx.y;
+    int tid = threadIdx.x;
+    if (inst >= B || tid >= PARAM_N) return;
+
+    coeff_t c = d_cp[(size_t)inst * PARAM_N + tid];
+    size_t out_idx = (size_t)poly * B * PARAM_N + (size_t)inst * PARAM_N + tid;
+
+    if (poly < PARAM_L) {
+        coeff_t s1 = d_s1_ntt[(size_t)poly * PARAM_N + tid];
+        d_z[out_idx] = (coeff_t)montgomery_reduce((coeff2_t)c * s1);
+    }
+    if (poly < PARAM_K) {
+        coeff_t s2 = d_s2_ntt[(size_t)poly * PARAM_N + tid];
+        coeff_t t0 = d_t0_ntt[(size_t)poly * PARAM_N + tid];
+        d_cs2[out_idx] = (coeff_t)montgomery_reduce((coeff2_t)c * s2);
+        d_ct0[out_idx] = (coeff_t)montgomery_reduce((coeff2_t)c * t0);
+    }
+}
+
+/* ================================================================
+ * [9b] z += y (in-place add: d_z[i] += d_y[i])
+ *
+ * One thread per element; `total` is the number of elements in both
+ * arrays.
+ * ================================================================ */
+__global__ void batch_sign_add_y_kernel(
+    coeff_t *d_z,
+    const coeff_t *d_y,
+    int total)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < total) d_z[idx] += d_y[idx];
+}
+
+/* ================================================================
+ * [12] Norm checks + hint computation + signature packing
+ *      (per instance, single thread)
+ *
+ * Thread `inst` handles one signing instance.  Instances whose
+ * d_done flag is already set are skipped.  Any failed check returns
+ * early WITHOUT touching d_done, so the instance stays pending; only
+ * after pack_sig() has written d_sigs + inst * CRYPTO_BYTES is
+ * d_done[inst] set to 1.
+ *
+ * All polynomial inputs are indexed [poly][inst][coeff].
+ * d_cbuf holds CTILDEBYTES per instance for ML-DSA and
+ * PARAM_N * sizeof(coeff_t) per instance for Aigis.
+ * ================================================================ */
+__global__ void __launch_bounds__(BATCH_SIGN_CHECK_TPB)
+batch_sign_check_pack_kernel(
+    int *d_done,
+    uint8_t *d_sigs,
+    const coeff_t *d_z,
+    const coeff_t *d_w, /* normalized w (used for the Aigis wcs2) */
+    const coeff_t *d_w0, /* low part from decompose */
+    const coeff_t *d_w1, /* high part from decompose */
+    const coeff_t *d_cs2,
+    const coeff_t *d_ct0,
+    const uint8_t *d_cbuf,
+    int B)
+{
+    int inst = blockIdx.x * blockDim.x + threadIdx.x;
+    if (inst >= B || d_done[inst]) return;
+
+    /* load z */
+    polyvecl z_loc;
+    for (int l = 0; l < PARAM_L; l++)
+        for (int c = 0; c < PARAM_N; c++)
+            z_loc.vec[l].coeffs[c] =
+                d_z[(size_t)l * B * PARAM_N + (size_t)inst * PARAM_N + c];
+#if ALGORITHM == ALGO_MLDSA
+    polyvecl_reduce(&z_loc);
+    if (polyvecl_chknorm(&z_loc, PARAM_GAMMA1 - PARAM_BETA1)) return;
+#else
+    polyvecl_freeze4q(&z_loc);
+    if (polyvecl_chknorm(&z_loc, PARAM_GAMMA1 - PARAM_BETA1)) return;
+#endif
+
+#if ALGORITHM == ALGO_MLDSA
+    /* r0 = w0 - cs2 */
+    polyveck r0, ct0_loc, w1_loc, h_loc;
+    for (int k = 0; k < PARAM_K; k++)
+        for (int c = 0; c < PARAM_N; c++) {
+            int32_t w0v = d_w0 [(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+            int32_t cs2v = d_cs2[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+            r0.vec[k].coeffs[c] = coeff_sub(w0v, cs2v);
+        }
+    polyveck_reduce(&r0);
+    if (polyveck_chknorm(&r0, PARAM_GAMMA2 - PARAM_BETA2)) return;
+
+    for (int k = 0; k < PARAM_K; k++)
+        for (int c = 0; c < PARAM_N; c++)
+            ct0_loc.vec[k].coeffs[c] =
+                d_ct0[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+    polyveck_reduce(&ct0_loc);
+    if (polyveck_chknorm(&ct0_loc, PARAM_GAMMA2)) return;
+
+    for (int k = 0; k < PARAM_K; k++)
+        for (int c = 0; c < PARAM_N; c++)
+            w1_loc.vec[k].coeffs[c] =
+                d_w1[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+
+    polyveck_add(&r0, &r0, &ct0_loc); /* w0_adj = r0 + ct0 */
+    unsigned int n = polyveck_make_hint(&h_loc, &r0, &w1_loc);
+    if (n > PARAM_OMEGA) return;
+
+    uint8_t *sig_out = d_sigs + (size_t)inst * CRYPTO_BYTES;
+    const uint8_t *c_seed = d_cbuf + (size_t)inst * CTILDEBYTES;
+    pack_sig(sig_out, c_seed, &z_loc, &h_loc);
+
+#elif ALGORITHM == ALGO_AIGIS
+    /* wcs2 = w - cs2 */
+    polyveck wcs2, ct0_loc, h_loc;
+    for (int k = 0; k < PARAM_K; k++)
+        for (int c = 0; c < PARAM_N; c++) {
+            int32_t wv = d_w [(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+            int32_t cs2v = d_cs2[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+            wcs2.vec[k].coeffs[c] = coeff_sub(wv, cs2v);
+        }
+    polyveck_freeze4q(&wcs2);
+
+    /* w1 consistency: decompose(wcs2)[high] == d_w1 */
+    {
+        polyveck wcs2_high, wcs2_low;
+        polyveck_decompose(&wcs2_high, &wcs2_low, &wcs2);
+        polyveck_freeze2q(&wcs2_low);
+        for (int k = 0; k < PARAM_K; k++)
+            for (int c = 0; c < PARAM_N; c++) {
+                int32_t w1v = d_w1[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+                if (wcs2_high.vec[k].coeffs[c] != w1v) return;
+            }
+        if (polyveck_chknorm(&wcs2_low, PARAM_GAMMA2 - PARAM_BETA2)) return;
+    }
+
+    for (int k = 0; k < PARAM_K; k++)
+        for (int c = 0; c < PARAM_N; c++)
+            ct0_loc.vec[k].coeffs[c] =
+                d_ct0[(size_t)k * B * PARAM_N + (size_t)inst * PARAM_N + c];
+    polyveck_freeze2q(&ct0_loc);
+    if (polyveck_chknorm(&ct0_loc, PARAM_GAMMA2)) return;
+
+    /* make_hint(wcs2+ct0, -ct0) */
+    polyveck tmp_loc, neg_ct0;
+    polyveck_add(&tmp_loc, &wcs2, &ct0_loc);
+    neg_ct0 = ct0_loc;
+    polyveck_neg(&neg_ct0);
+    polyveck_freeze2q(&tmp_loc);
+    unsigned int n = polyveck_make_hint(&h_loc, &tmp_loc, &neg_ct0);
+    if (n > PARAM_OMEGA) return;
+
+    /* For Aigis the per-instance d_cbuf slot is read back as PARAM_N
+     * challenge coefficients rather than as a seed. */
+    poly c_poly;
+    const coeff_t *cbuf_cp = (const coeff_t *)(d_cbuf + (size_t)inst * PARAM_N * sizeof(coeff_t));
+    for (int i = 0; i < PARAM_N; i++) c_poly.coeffs[i] = cbuf_cp[i];
+    uint8_t *sig_out = d_sigs + (size_t)inst * CRYPTO_BYTES;
+    pack_sig(sig_out, &z_loc, &h_loc, &c_poly);
+#endif
+
+    d_done[inst] = 1;
+}
+
+/*
+ * Count the set entries of d_done[0..B-1] and ADD the result to
+ * *d_done_count (the caller is responsible for zeroing it first).
+ * Each block accumulates into a shared counter, then thread 0 issues a
+ * single global atomicAdd for the block.
+ */
+__global__ void batch_sign_count_done_kernel(const int *d_done, int *d_done_count, int B) {
+    __shared__ int local_count;
+    if (threadIdx.x == 0) local_count = 0;
+    __syncthreads();
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < B && d_done[idx]) atomicAdd(&local_count, 1);
+    __syncthreads();
+
+    if (threadIdx.x == 0) atomicAdd(d_done_count, local_count);
+}
+
+/* Finish the small rejection tail without launching 
full-batch NTT/matvec rounds. */ +__global__ void __launch_bounds__(64, 1) +batch_sign_tail_precomp_kernel( + int *d_done, + uint8_t *d_sigs, + const uint16_t *d_nonces, + const precomp_t *d_pc, + const uint8_t *d_mu, + const uint8_t *d_rhoprime, + int B) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= B || d_done[inst]) return; + + size_t siglen = 0; +#if ALGORITHM == ALGO_MLDSA + int r = crypto_sign_signature_precomp_cached( + d_sigs + (size_t)inst * CRYPTO_BYTES, &siglen, + d_mu, d_rhoprime, d_pc, d_nonces[inst]); +#else + int r = crypto_sign_signature_precomp_cached( + d_sigs + (size_t)inst * CRYPTO_BYTES, &siglen, + d_mu, d_rhoprime, d_pc, d_nonces[inst]); +#endif + if (r == 0 && siglen == CRYPTO_BYTES) d_done[inst] = 1; +} + +/* ================================================================ + * Host API — 缓冲区分配/释放 + * ================================================================ */ +static int batch_sign_alloc(BatchSignPipeline *p, int max_batch) { + memset(p, 0, sizeof(*p)); + p->max_batch = max_batch; + size_t B = max_batch, N = PARAM_N; + +#define BS_TRY(ptr, sz) do { \ + if (hipMalloc(&(ptr), (sz)) != hipSuccess) { hipGetLastError(); return -1; } \ +} while(0) + + BS_TRY(p->sh.d_mat, (size_t)PARAM_K * PARAM_L * N * sizeof(coeff_t)); + BS_TRY(p->sh.d_s1_ntt, (size_t)PARAM_L * N * sizeof(coeff_t)); + BS_TRY(p->sh.d_s2_ntt, (size_t)PARAM_K * N * sizeof(coeff_t)); + BS_TRY(p->sh.d_t0_ntt, (size_t)PARAM_K * N * sizeof(coeff_t)); + BS_TRY(p->sh.d_mu, CRHBYTES); +#if ALGORITHM == ALGO_MLDSA + BS_TRY(p->sh.d_rhoprime, CRHBYTES); +#else + BS_TRY(p->sh.d_rhoprime, SEEDBYTES + CRHBYTES); +#endif + BS_TRY(p->d_y, (size_t)PARAM_L * B * N * sizeof(coeff_t)); + BS_TRY(p->d_y_hat, (size_t)PARAM_L * B * N * sizeof(coeff_t)); + BS_TRY(p->d_w, (size_t)PARAM_K * B * N * sizeof(coeff_t)); + BS_TRY(p->d_w0, (size_t)PARAM_K * B * N * sizeof(coeff_t)); + BS_TRY(p->d_w1, (size_t)PARAM_K * B * N * sizeof(coeff_t)); + BS_TRY(p->d_cp, (size_t)B * N * 
sizeof(coeff_t)); + BS_TRY(p->d_z, (size_t)PARAM_L * B * N * sizeof(coeff_t)); + BS_TRY(p->d_cs2, (size_t)PARAM_K * B * N * sizeof(coeff_t)); + BS_TRY(p->d_ct0, (size_t)PARAM_K * B * N * sizeof(coeff_t)); + BS_TRY(p->d_nonces, B * sizeof(uint16_t)); + BS_TRY(p->d_done, B * sizeof(int)); + BS_TRY(p->d_done_count, sizeof(int)); + BS_TRY(p->d_sigs, B * CRYPTO_BYTES); +#if ALGORITHM == ALGO_MLDSA + BS_TRY(p->d_cbuf, B * CTILDEBYTES); +#else + BS_TRY(p->d_cbuf, B * N * sizeof(coeff_t)); +#endif +#undef BS_TRY + return 0; +} + +static void batch_sign_free(BatchSignPipeline *p) { + hipFree(p->sh.d_mat); hipFree(p->sh.d_s1_ntt); + hipFree(p->sh.d_s2_ntt); hipFree(p->sh.d_t0_ntt); + hipFree(p->sh.d_mu); hipFree(p->sh.d_rhoprime); + hipFree(p->d_y); hipFree(p->d_y_hat); hipFree(p->d_w); + hipFree(p->d_w0); hipFree(p->d_w1); hipFree(p->d_cp); + hipFree(p->d_z); hipFree(p->d_cs2); hipFree(p->d_ct0); + hipFree(p->d_nonces); hipFree(p->d_done); hipFree(p->d_done_count); + hipFree(p->d_sigs); hipFree(p->d_cbuf); + memset(p, 0, sizeof(*p)); +} + +static int batch_sign_count_done_host(BatchSignPipeline *p, int B) { + int done_now = 0; + int tpb = 256, nblk = (B + tpb - 1) / tpb; + hipMemsetAsync(p->d_done_count, 0, sizeof(int)); + batch_sign_count_done_kernel<<>>(p->d_done, p->d_done_count, B); + hipMemcpy(&done_now, p->d_done_count, sizeof(int), hipMemcpyDeviceToHost); + return done_now; +} + +static int batch_sign_tail_finish( + BatchSignPipeline *p, + int B, + const precomp_t *d_pc) +{ + int tpb = 64, nblk = (B + tpb - 1) / tpb; + batch_sign_tail_precomp_kernel<<>>( + p->d_done, p->d_sigs, p->d_nonces, d_pc, + p->sh.d_mu, p->sh.d_rhoprime, B); + hipError_t e = hipGetLastError(); + if (e != hipSuccess) return -1; + return batch_sign_count_done_host(p, B); +} + +/* ================================================================ + * 批量签名 pipeline 主函数 + * + * 前置条件: + * p->max_batch >= batch_count + * d_pc 已通过 kernel_create_precomp 构造 (device 端) + * d_msg / d_rnd / d_pre 均在 device 端 
+ *
+ * Output:
+ *   p->d_sigs[0 .. batch_count*CRYPTO_BYTES - 1] — signatures in AoS
+ *   layout (CRYPTO_BYTES per instance)
+ *
+ * Optional outputs: *h_rounds receives the number of rounds executed and
+ * *h_done the number of finished instances.  Returns -1 for an invalid
+ * batch_count or a failed tail launch, otherwise 0 — note that 0 is also
+ * returned when some instances are still unfinished, so callers that
+ * need every signature must compare *h_done with batch_count.
+ *
+ * NOTE(review): several kernel launches below read `<<>>`; the launch
+ * configuration appears to have been lost from this copy of the file and
+ * must be restored from the original source before this compiles.
+ * ================================================================ */
+static int batch_sign_pipeline_ex(
+    BatchSignPipeline *p,
+    int batch_count,
+    const precomp_t *d_pc,
+    const uint8_t *d_msg,
+    size_t mlen,
+    const uint8_t *d_pre,
+    size_t prelen,
+    const uint8_t *d_rnd,
+    const BatchSignRuntimeOptions *runtime_opt,
+    int *h_rounds,
+    int *h_done)
+{
+    if (batch_count <= 0 || batch_count > p->max_batch) return -1;
+    int B = batch_count, N = PARAM_N;
+    /* A null runtime_opt selects the compiled-in defaults. */
+    BatchSignRuntimeOptions opt = runtime_opt
+        ? *runtime_opt
+        : batch_sign_default_runtime_options();
+    int runtime_cp_fuse = opt.cp_fuse_enable != 0;
+    int runtime_check_interval = opt.check_interval > 0
+        ? opt.check_interval
+        : BATCH_SIGN_DECOMP_CHECK_INTERVAL;
+    /* Non-positive block sizes fall back to the compile-time value, and
+     * requests above it are clamped down to it. */
+    int runtime_hash_tpb = opt.hash_tpb > 0 ? opt.hash_tpb : BATCH_SIGN_HASH_TPB;
+    int runtime_check_tpb = opt.check_tpb > 0 ? opt.check_tpb : BATCH_SIGN_CHECK_TPB;
+    if (runtime_hash_tpb > BATCH_SIGN_HASH_TPB) runtime_hash_tpb = BATCH_SIGN_HASH_TPB;
+    if (runtime_check_tpb > BATCH_SIGN_CHECK_TPB) runtime_check_tpb = BATCH_SIGN_CHECK_TPB;
+
+    /* [0a] extract the shared material */
+    batch_sign_setup_kernel<<<1, 1>>>(
+        p->sh.d_mat, p->sh.d_s1_ntt, p->sh.d_s2_ntt, p->sh.d_t0_ntt, d_pc);
+
+    /* [0b] compute mu + rhoprime */
+    batch_sign_compute_mu_rhoprime_kernel<<<1, 1>>>(
+        p->sh.d_mu, p->sh.d_rhoprime, d_pc, d_msg, mlen, d_pre, prelen, d_rnd);
+
+    /* [0c] initialise nonce / done */
+    {
+        int tpb = 256, nblk = (B + tpb - 1) / tpb;
+        batch_sign_init_kernel<<>>(p->d_nonces, p->d_done, B);
+    }
+    /* ============================================================
+     * Rejection loop — at most MAX_SIGN_ROUNDS rounds, with no
+     * synchronisation between the kernels of a round.  Kernels issued to
+     * the same CUDA stream run in order, so data dependencies are
+     * satisfied automatically.  Finished instances early-exit inside each
+     * kernel via the d_done flag.
+     * P(an instance out of B=4096 still unfinished after round 15)
+     * < 1e-10 for ML-DSA-44.
+     * (The done-count poll further down does block on the device
+     * whenever it runs.)
+     * ============================================================ */
+    for (int round = 0; round < MAX_SIGN_ROUNDS; round++) {
+
+        /* [1] sample y */
+        {
+            int tpb = BATCH_SIGN_SAMPLE_TPB, nblk = (B + tpb - 1) / tpb;
+            batch_sign_sample_y_kernel<<>>(
+                p->d_y,
+#if BATCH_SIGN_SAMPLE_DUP_YHAT
+                p->d_y_hat,
+#endif
+                p->d_nonces, p->d_done, p->sh.d_rhoprime, B);
+        }
+
+        /* Back up y into d_y_hat (the NTT below overwrites d_y_hat in
+         * place; d_y stays in the coefficient domain) */
+#if !BATCH_SIGN_SAMPLE_DUP_YHAT
+        hipMemcpyAsync(p->d_y_hat, p->d_y,
+                       (size_t)PARAM_L * B * N * sizeof(coeff_t),
+                       hipMemcpyDeviceToDevice);
+#endif
+
+        /* [2] NTT(y_hat) in place */
+        launch_batch_ntt(p->d_y_hat, B * PARAM_L);
+
+        /* [3] w = A · y_hat (shared-matrix 2D-grid matvec, reusing the verify kernel) */
+        {
+            dim3 grid(B, PARAM_K);
+            batch_verify_matvec_kernel<<>>(
+                p->d_w, p->sh.d_mat, p->d_y_hat, B);
+        }
+
+        /* [4] reduce + INVNTT(w) */
+        launch_batch_reduce(p->d_w, B * PARAM_K * N);
+        launch_batch_invntt(p->d_w, B * PARAM_K);
+
+        /* [5] normalise w */
+#if ALGORITHM == ALGO_MLDSA
+        launch_batch_reduce(p->d_w, B * PARAM_K * N);
+        launch_batch_caddq(p->d_w, B * PARAM_K * N);
+#else
+        /* NOTE(review): the count here is PARAM_K * B (no N factor), unlike
+         * the reduce/caddq calls — presumably a polynomial count; confirm
+         * against launch_batch_freeze2q. */
+        launch_batch_freeze2q(p->d_w, PARAM_K * B);
+#endif
+
+        /* [6] decompose(w) → (w1, w0)
+         * d_w is left unchanged (needed for the Aigis wcs2 = w - cs2 step) */
+        {
+            int total = PARAM_K * B * N;
+            int tpb = BATCH_TPB, nblk = (total + tpb - 1) / tpb;
+            batch_sign_decompose_kernel<<>>(p->d_w1, p->d_w0, p->d_w, total);
+        }
+
+        /* [7] challenge hash: H(mu || pack(w1)) → d_cp + d_cbuf */
+        {
+            int tpb = runtime_hash_tpb, nblk = (B + tpb - 1) / tpb;
+            batch_sign_hash_cp_kernel<<>>(
+                p->d_cp, p->d_cbuf, p->sh.d_mu, p->d_w1, p->d_done, B);
+        }
+
+        /* [8] NTT(cp) in place */
+        launch_batch_ntt(p->d_cp, B);
+
+        /* Fused path: one launch produces cp*s1, cp*s2 and cp*t0; the grid's
+         * second dimension covers the larger of PARAM_L and PARAM_K. */
+        if (runtime_cp_fuse) {
+            const int max_shared = (PARAM_L > PARAM_K) ? PARAM_L : PARAM_K;
+            dim3 grid_all(B, max_shared);
+            batch_sign_pointwise_cp_all_shared_kernel<<>>(
+                p->d_z, p->d_cs2, p->d_ct0, p->d_cp,
+                p->sh.d_s1_ntt, p->sh.d_s2_ntt, p->sh.d_t0_ntt, B);
+        }
+
+        /* [9] z = INVNTT(cp_ntt · s1_ntt) + y */
+        if (!runtime_cp_fuse) {
+            dim3 grid_l(B, PARAM_L);
+            batch_sign_pointwise_cp_shared_kernel<<>>(
+                p->d_z, p->d_cp, p->sh.d_s1_ntt, PARAM_L, B);
+        }
+        launch_batch_invntt(p->d_z, B * PARAM_L);
+        {
+            int total = PARAM_L * B * N;
+            int tpb = BATCH_TPB, nblk = (total + tpb - 1) / tpb;
+            batch_sign_add_y_kernel<<>>(p->d_z, p->d_y, total);
+        }
+
+        /* [10] cs2 = INVNTT(cp_ntt · s2_ntt) */
+        if (!runtime_cp_fuse) {
+            dim3 grid_k(B, PARAM_K);
+            batch_sign_pointwise_cp_shared_kernel<<>>(
+                p->d_cs2, p->d_cp, p->sh.d_s2_ntt, PARAM_K, B);
+        }
+        launch_batch_invntt(p->d_cs2, B * PARAM_K);
+
+        /* [11] ct0 = INVNTT(cp_ntt · t0_ntt) */
+        if (!runtime_cp_fuse) {
+            dim3 grid_k(B, PARAM_K);
+            batch_sign_pointwise_cp_shared_kernel<<>>(
+                p->d_ct0, p->d_cp, p->sh.d_t0_ntt, PARAM_K, B);
+        }
+        launch_batch_invntt(p->d_ct0, B * PARAM_K);
+
+        /* [12] norm checks + hints + signature packing */
+        {
+            int tpb = runtime_check_tpb, nblk = (B + tpb - 1) / tpb;
+            batch_sign_check_pack_kernel<<>>(
+                p->d_done, p->d_sigs,
+                p->d_z, p->d_w, p->d_w0, p->d_w1,
+                p->d_cs2, p->d_ct0, p->d_cbuf, B);
+        }
+        /* Poll the done count every round (if so configured), every
+         * runtime_check_interval rounds, and always after the last round. */
+        if (BATCH_SIGN_DECOMP_SYNC_EACH_ROUND ||
+            ((round + 1) % runtime_check_interval) == 0 ||
+            round + 1 == MAX_SIGN_ROUNDS) {
+            int done_now = batch_sign_count_done_host(p, B);
+            if (done_now >= B) {
+                if (h_rounds) *h_rounds = round + 1;
+                if (h_done) *h_done = done_now;
+                return 0;
+            }
+#if BATCH_SIGN_DECOMP_TAIL_ENABLE
+            /* Once enough rounds have run and few instances remain, finish
+             * them with the per-instance tail kernel instead of another
+             * full-batch round.  Returns 0 even if the tail left instances
+             * unfinished; *h_done carries the actual count. */
+            int remaining = B - done_now;
+            int tail_limit = B / BATCH_SIGN_DECOMP_TAIL_PENDING_DIV;
+            if (tail_limit < BATCH_SIGN_DECOMP_TAIL_PENDING_MIN)
+                tail_limit = BATCH_SIGN_DECOMP_TAIL_PENDING_MIN;
+            if ((round + 1) >= BATCH_SIGN_DECOMP_TAIL_AFTER &&
+                remaining <= tail_limit) {
+                int tail_done = batch_sign_tail_finish(p, B, d_pc);
+                if (tail_done < 0) return -1;
+                if (h_rounds) *h_rounds = round + 1;
+                if (h_done) *h_done = tail_done;
+                return 0;
+            }
+#endif
+        }
+    }
+
+    /* Single final synchronisation — wait for all rounds to finish */
+    hipDeviceSynchronize();
+    if (h_rounds) *h_rounds = MAX_SIGN_ROUNDS;
+#if BATCH_SIGN_DECOMP_TAIL_ENABLE
+    {
+        int done_now = batch_sign_count_done_host(p, B);
+        if (done_now < B)
+            done_now = batch_sign_tail_finish(p, B, d_pc);
+        /* NOTE(review): on a tail failure *h_done receives the negative
+         * value before -1 is returned. */
+        if (h_done) *h_done = done_now;
+        if (done_now < 0) return -1;
+    }
+#else
+    if (h_done) *h_done = batch_sign_count_done_host(p, B);
+#endif
+    return 0;
+}
+
+/*
+ * Convenience wrapper: batch_sign_pipeline_ex() with default runtime
+ * options and without the round / done-count outputs.
+ */
+static int batch_sign_pipeline(
+    BatchSignPipeline *p,
+    int batch_count,
+    const precomp_t *d_pc,
+    const uint8_t *d_msg,
+    size_t mlen,
+    const uint8_t *d_pre,
+    size_t prelen,
+    const uint8_t *d_rnd)
+{
+    return batch_sign_pipeline_ex(p, batch_count, d_pc, d_msg, mlen,
+                                  d_pre, prelen, d_rnd, nullptr, nullptr, nullptr);
+}
+
+#endif /* BATCH_SIGN_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign_warp.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign_warp.cuh
new file mode 100644
index 000000000..6a8dd1a8d
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_sign_warp.cuh
@@ -0,0 +1,835 @@
+#include "hip/hip_runtime.h"
+#ifndef BATCH_SIGN_WARP_CUH
+#define BATCH_SIGN_WARP_CUH
+
+/* NOTE(review): the three header names below are missing in this copy of
+ * the file (likely stripped together with their angle brackets). */
+#include
+#include
+#include
+
+#include "params.h"
+#include "sign.cuh"
+#include "batch_ntt.cuh"
+
+#ifndef BATCH_SIGN_WARP_ENABLE
+#define BATCH_SIGN_WARP_ENABLE 1
+#endif
+
+#ifndef BATCH_SIGN_WARP_PROFILE
+#define BATCH_SIGN_WARP_PROFILE 0
+#endif
+
+/* Lanes per logical signing warp. */
+#define WP_SIGN_WARP_SIZE 32
+
+/* Logical warps per thread block. */
+#define WP_SIGN_WARPS_BLOCK 4
+
+
+#define WP_SIGN_TPB (WP_SIGN_WARP_SIZE * WP_SIGN_WARPS_BLOCK)
+
+/* Size of the per-warp seed used to sample y. */
+#if ALGORITHM == ALGO_MLDSA
+#define WP_SIGN_SEED_BYTES CRHBYTES
+#else
+#define WP_SIGN_SEED_BYTES (SEEDBYTES + CRHBYTES)
+#endif
+
+/* Size of the per-warp scratch buffer that receives the gamma1 stream. */
+#if ALGORITHM == ALGO_AIGIS
+#define WP_SIGN_GAMMA1_BUF_BYTES (STREAM256_BLOCKBYTES + 4)
+#else
+#define 
WP_SIGN_GAMMA1_BUF_BYTES (POLY_UNIFORM_GAMMA1_NBLOCKS * STREAM256_BLOCKBYTES)
+#endif
+
+/* Indices of the statistics counters; the names suggest attempts, one
+ * rejection counter per check, and successes.  WP_SIGN_STAT_COUNT is the
+ * number of counters. */
+enum {
+    WP_SIGN_STAT_ATTEMPTS = 0,
+    WP_SIGN_STAT_REJ_S2 = 1,
+    WP_SIGN_STAT_REJ_Z = 2,
+    WP_SIGN_STAT_REJ_T0 = 3,
+    WP_SIGN_STAT_REJ_HINT = 4,
+    WP_SIGN_STAT_OK = 5,
+    WP_SIGN_STAT_COUNT = 6
+};
+
+/* Precomputed per-message values consumed by wp_sign_prepare_cached():
+ * mu plus the y-sampling seed (rhoprime for ML-DSA, key||mu otherwise). */
+typedef struct {
+#if ALGORITHM == ALGO_MLDSA
+    uint8_t mu[CRHBYTES];
+    uint8_t rhoprime[CRHBYTES];
+#else
+    uint8_t mu[CRHBYTES];
+    uint8_t key_mu[SEEDBYTES + CRHBYTES];
+#endif
+} wp_sign_cache_t;
+
+/* Pointers into one warp's slice of the block's scratch memory, filled in
+ * by wp_sign_smem_init(). */
+typedef struct {
+    coeff_t *y;          /* PARAM_L polynomials */
+    coeff_t *w;          /* PARAM_K polynomials */
+    coeff_t *cp;         /* one polynomial (challenge) */
+    coeff_t *tmp;        /* one polynomial of scratch */
+    uint8_t *packed_w1;  /* PARAM_K * POLYW1_PACKEDBYTES bytes */
+    uint8_t *mu;         /* CRHBYTES bytes */
+    uint8_t *seed;       /* WP_SIGN_SEED_BYTES bytes */
+    uint8_t *work;       /* WP_SIGN_GAMMA1_BUF_BYTES bytes */
+} wp_sign_smem_t;
+
+/* Round x up to a multiple of a (assumes a is a power of two). */
+static __host__ __device__ __forceinline__ size_t wp_sign_align(size_t x, size_t a) {
+    return (x + a - 1u) & ~(a - 1u);
+}
+
+/*
+ * Bytes of scratch memory one warp needs; the terms follow the field
+ * order laid out by wp_sign_smem_init().
+ * NOTE(review): unlike wp_sign_smem_init() there is no 16-byte alignment
+ * step before the packed_w1 term; the two agree only while
+ * PARAM_N * sizeof(coeff_t) is a multiple of 16 — confirm.
+ */
+static __host__ __device__ __forceinline__ size_t wp_sign_shared_bytes_per_warp(void) {
+    size_t off = 0;
+    off = wp_sign_align(off, 16);
+    off += (size_t)PARAM_L * PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    off += (size_t)PARAM_K * PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    off += (size_t)PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    off += (size_t)PARAM_N * sizeof(coeff_t);
+    off += (size_t)PARAM_K * POLYW1_PACKEDBYTES;
+    off = wp_sign_align(off, 16);
+    off += CRHBYTES;
+    off = wp_sign_align(off, 16);
+    off += WP_SIGN_SEED_BYTES;
+    off = wp_sign_align(off, 16);
+    off += WP_SIGN_GAMMA1_BUF_BYTES;
+    return wp_sign_align(off, 16);
+}
+
+/* Scratch bytes needed by a whole block of WP_SIGN_WARPS_BLOCK warps. */
+static inline size_t batch_sign_warp_smem_bytes(void) {
+    return wp_sign_shared_bytes_per_warp() * WP_SIGN_WARPS_BLOCK;
+}
+
+static inline hipError_t batch_sign_warp_set_smem_attributes(void);
+
+/*
+ * Carve the slice belonging to warp `warp_slot` out of `base` and store
+ * the resulting field pointers in *s.  Every field starts on a 16-byte
+ * offset relative to `base`.
+ */
+static __device__ __forceinline__ void wp_sign_smem_init(
+    wp_sign_smem_t *s, unsigned char *base, int warp_slot)
+{
+    size_t off = (size_t)warp_slot * wp_sign_shared_bytes_per_warp();
+    off = wp_sign_align(off, 16);
+    s->y = (coeff_t *)(base + off);
+    off += (size_t)PARAM_L * PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    s->w = (coeff_t *)(base + off);
+    off += (size_t)PARAM_K * PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    s->cp = (coeff_t *)(base + off);
+    off += (size_t)PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    s->tmp = (coeff_t *)(base + off);
+    off += (size_t)PARAM_N * sizeof(coeff_t);
+    off = wp_sign_align(off, 16);
+    s->packed_w1 = base + off;
+    off += (size_t)PARAM_K * POLYW1_PACKEDBYTES;
+    off = wp_sign_align(off, 16);
+    s->mu = base + off;
+    off += CRHBYTES;
+    off = wp_sign_align(off, 16);
+    s->seed = base + off;
+    off += WP_SIGN_SEED_BYTES;
+    off = wp_sign_align(off, 16);
+    s->work = base + off;
+}
+
+/* Store signature byte `off` of instance `inst` in the SoA output:
+ * byte offset is the major index, instance the minor one (N instances). */
+static __device__ __forceinline__ void wp_sign_store_sig(
+    uint8_t *sig_soa, int inst, int N, unsigned int off, uint8_t v)
+{
+    sig_soa[(size_t)off * (size_t)N + (size_t)inst] = v;
+}
+
+/*
+ * Non-zero when `pred` is non-zero on any lane covered by the ballot.
+ * NOTE(review): the mask names 32 lanes; on hardware with 64-wide
+ * wavefronts confirm that the ballot does not mix two logical warps.
+ */
+static __device__ __forceinline__ int wp_sign_any(int pred) {
+    return __ballot_sync(0xffffffffull, pred) != 0u;
+}
+
+/*
+ * Per-coefficient norm test: non-zero when the coefficient reaches bound B.
+ * ML-DSA: bounds above (PARAM_Q - 1) / 8 always fail; otherwise the
+ * absolute value of `a` is compared with B.
+ * Otherwise: the bit-trick below folds `a` around (PARAM_Q - 1) / 2 before
+ * the comparison (presumably the centred magnitude of a value in [0, Q)).
+ */
+static __device__ __forceinline__ int wp_sign_coeff_chknorm(coeff_t a, int32_t B) {
+#if ALGORITHM == ALGO_MLDSA
+    if (B > (PARAM_Q - 1) / 8) return 1;
+    int32_t t = a >> 31;
+    t = a - (t & 2 * a);
+    return t >= B;
+#else
+    int32_t t = (PARAM_Q - 1) / 2 - a;
+    t ^= (t >> 31);
+    t = (PARAM_Q - 1) / 2 - t;
+    return t >= B;
+#endif
+}
+
+/* Warp-wide norm test of one polynomial: lanes stride over the
+ * coefficients and the per-lane results are combined with wp_sign_any(). */
+static __device__ __forceinline__ int wp_sign_poly_chknorm(
+    const coeff_t *a, int32_t B, int lane)
+{
+    int bad = 0;
+    for (int i = lane; i < PARAM_N; i += WP_SIGN_WARP_SIZE)
+        bad |= wp_sign_coeff_chknorm(a[i], B);
+    return wp_sign_any(bad);
+}
+
+/*
+ * Sample one y polynomial into `dst` from (seed, nonce).
+ * Lane 0 drives the SHAKE stream into `buf`; after a __syncwarp() all
+ * lanes unpack coefficients in parallel.
+ * ML-DSA: the whole stream is squeezed at once and unpacked as 18-bit
+ * (gamma1 = 2^17) or 20-bit (gamma1 = 2^19) fields.
+ * Otherwise: the stream is consumed block by block in 5-byte groups; the
+ * bytes left over after the last whole group are parked at
+ * buf[STREAM256_BLOCKBYTES ...] and logically prepended to the next block.
+ */
+static __device__ __noinline__ void wp_sign_sample_y_poly(
+    coeff_t *dst, const uint8_t *seed, uint16_t nonce, int lane, uint8_t *buf)
+{
+#if ALGORITHM == ALGO_MLDSA
+    if (lane == 0) {
+        stream256_state state;
+        stream256_init(&state, seed, nonce);
+        stream256_squeezeblocks(buf, POLY_UNIFORM_GAMMA1_NBLOCKS, &state);
+    }
+    __syncwarp();
+#if PARAM_GAMMA1 == (1 << 17)
+    /* 9 bytes -> 4 coefficients of 18 bits each */
+    for (int i = lane; i < PARAM_N / 4; i += WP_SIGN_WARP_SIZE) {
+        uint32_t t0 = ((uint32_t)buf[9 * i + 0] |
+                       ((uint32_t)buf[9 * i + 1] << 8) |
+                       ((uint32_t)buf[9 * i + 2] << 16)) & 0x3ffffu;
+        uint32_t t1 = (((uint32_t)buf[9 * i + 2] >> 2) |
+                       ((uint32_t)buf[9 * i + 3] << 6) |
+                       ((uint32_t)buf[9 * i + 4] << 14)) & 0x3ffffu;
+        uint32_t t2 = (((uint32_t)buf[9 * i + 4] >> 4) |
+                       ((uint32_t)buf[9 * i + 5] << 4) |
+                       ((uint32_t)buf[9 * i + 6] << 12)) & 0x3ffffu;
+        uint32_t t3 = (((uint32_t)buf[9 * i + 6] >> 6) |
+                       ((uint32_t)buf[9 * i + 7] << 2) |
+                       ((uint32_t)buf[9 * i + 8] << 10)) & 0x3ffffu;
+        dst[4 * i + 0] = PARAM_GAMMA1 - (int32_t)t0;
+        dst[4 * i + 1] = PARAM_GAMMA1 - (int32_t)t1;
+        dst[4 * i + 2] = PARAM_GAMMA1 - (int32_t)t2;
+        dst[4 * i + 3] = PARAM_GAMMA1 - (int32_t)t3;
+    }
+#elif PARAM_GAMMA1 == (1 << 19)
+    /* 5 bytes -> 2 coefficients of 20 bits each */
+    for (int i = lane; i < PARAM_N / 2; i += WP_SIGN_WARP_SIZE) {
+        uint32_t t0 = ((uint32_t)buf[5 * i + 0] |
+                       ((uint32_t)buf[5 * i + 1] << 8) |
+                       ((uint32_t)buf[5 * i + 2] << 16)) & 0xfffffu;
+        uint32_t t1 = (((uint32_t)buf[5 * i + 2] >> 4) |
+                       ((uint32_t)buf[5 * i + 3] << 4) |
+                       ((uint32_t)buf[5 * i + 4] << 12)) & 0xfffffu;
+        dst[2 * i + 0] = PARAM_GAMMA1 - (int32_t)t0;
+        dst[2 * i + 1] = PARAM_GAMMA1 - (int32_t)t1;
+    }
+#endif
+#else
+    stream256_state state;
+    if (lane == 0) {
+        aigis_shake256_gamma1_init(&state, seed, nonce);
+    }
+    __syncwarp();
+    for (int blk = 0; blk < POLY_UNIFORM_GAMMA1_NBLOCKS; ++blk) {
+        /* tail  = leftover bytes carried in from the previous block
+         * avail = bytes usable this round, groups = whole 5-byte groups,
+         * produced = groups already emitted by earlier blocks */
+        int tail = (blk * STREAM256_BLOCKBYTES) % 5;
+        int avail = tail + STREAM256_BLOCKBYTES;
+        int groups = avail / 5;
+        int produced = (blk * STREAM256_BLOCKBYTES - tail) / 5;
+        int todo = groups;
+        if (produced + todo > PARAM_N / 2)
+            todo = PARAM_N / 2 - produced;
+
+        if (lane == 0)
+            stream256_squeezeblocks(buf, 1, &state);
+        __syncwarp();
+
+        for (int i = lane; i < todo; i += WP_SIGN_WARP_SIZE) {
+            /* Logical byte `pos` comes from the carried tail when
+             * pos < tail, otherwise from the fresh block. */
+            unsigned int pos = 5u * (unsigned int)i;
+            uint8_t b0 = (pos + 0 < (unsigned int)tail)
+                ? buf[STREAM256_BLOCKBYTES + pos + 0]
+                : buf[pos + 0 - tail];
+            uint8_t b1 = (pos + 1 < (unsigned int)tail)
+                ? buf[STREAM256_BLOCKBYTES + pos + 1]
+                : buf[pos + 1 - tail];
+            uint8_t b2 = (pos + 2 < (unsigned int)tail)
+                ? buf[STREAM256_BLOCKBYTES + pos + 2]
+                : buf[pos + 2 - tail];
+            uint8_t b3 = (pos + 3 < (unsigned int)tail)
+                ? buf[STREAM256_BLOCKBYTES + pos + 3]
+                : buf[pos + 3 - tail];
+            uint8_t b4 = (pos + 4 < (unsigned int)tail)
+                ? buf[STREAM256_BLOCKBYTES + pos + 4]
+                : buf[pos + 4 - tail];
+            uint32_t t0 = b0;
+            t0 |= (uint32_t)b1 << 8;
+            t0 |= (uint32_t)b2 << 16;
+            uint32_t t1 = b2 >> 4;
+            t1 |= (uint32_t)b3 << 4;
+            t1 |= (uint32_t)b4 << 12;
+            t0 &= 0x3ffffu;
+            t1 &= 0x3ffffu;
+            int out = produced + i;
+            dst[2 * out + 0] = PARAM_Q + PARAM_GAMMA1 - 1 - (int32_t)t0;
+            dst[2 * out + 1] = PARAM_Q + PARAM_GAMMA1 - 1 - (int32_t)t1;
+        }
+        __syncwarp();
+
+        /* Park the unused trailing bytes for the next block. */
+        if (lane == 0 && blk + 1 < POLY_UNIFORM_GAMMA1_NBLOCKS) {
+            int used = groups * 5;
+            int new_tail = avail - used;
+            for (int t = 0; t < new_tail; ++t) {
+                int pos = used + t;
+                buf[STREAM256_BLOCKBYTES + t] = (pos < tail)
+                    ? buf[STREAM256_BLOCKBYTES + pos]
+                    : buf[pos - tail];
+            }
+        }
+        __syncwarp();
+    }
+#endif
+    __syncwarp();
+}
+
+/*
+ * Bit-pack polynomial `a` (a z component) into the SoA signature of
+ * instance `inst`, starting at byte offset `off`.  Lanes stride over
+ * groups of 4 coefficients -> 9 bytes (gamma1 = 2^17) or
+ * 2 coefficients -> 5 bytes (gamma1 = 2^19).
+ */
+static __device__ __noinline__ void wp_sign_pack_z_soa(
+    uint8_t *sig_soa, int inst, int N, unsigned int off, const coeff_t *a, int lane)
+{
+#if PARAM_GAMMA1 == (1 << 17)
+    for (int i = lane; i < PARAM_N / 4; i += WP_SIGN_WARP_SIZE) {
+        uint32_t t0 = (uint32_t)(Z_BIAS - a[4 * i + 0]); Z_FIXUP(t0);
+        uint32_t t1 = (uint32_t)(Z_BIAS - a[4 * i + 1]); Z_FIXUP(t1);
+        uint32_t t2 = (uint32_t)(Z_BIAS - a[4 * i + 2]); Z_FIXUP(t2);
+        uint32_t t3 = (uint32_t)(Z_BIAS - a[4 * i + 3]); Z_FIXUP(t3);
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 0, (uint8_t)t0);
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 1, (uint8_t)(t0 >> 8));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 2, (uint8_t)((t0 >> 16) | (t1 << 2)));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 3, (uint8_t)(t1 >> 6));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 4, (uint8_t)((t1 >> 14) | (t2 << 4)));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 5, (uint8_t)(t2 >> 4));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 6, (uint8_t)((t2 >> 12) | (t3 << 6)));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 7, (uint8_t)(t3 >> 2));
+        wp_sign_store_sig(sig_soa, inst, N, off + 9 * i + 8, (uint8_t)(t3 >> 10));
+    }
+#elif PARAM_GAMMA1 == (1 << 19)
+    for (int i = lane; i < PARAM_N / 2; i += WP_SIGN_WARP_SIZE) {
+        uint32_t t0 = (uint32_t)(Z_BIAS - a[2 * i + 0]); Z_FIXUP(t0);
+        uint32_t t1 = (uint32_t)(Z_BIAS - a[2 * i + 1]); Z_FIXUP(t1);
+        wp_sign_store_sig(sig_soa, inst, N, off + 5 * i + 0, (uint8_t)t0);
+        wp_sign_store_sig(sig_soa, inst, N, off + 5 * i + 1, (uint8_t)(t0 >> 8));
+        wp_sign_store_sig(sig_soa, inst, N, off + 5 * i + 2, (uint8_t)((t0 >> 16) | (t1 << 4)));
+        wp_sign_store_sig(sig_soa, inst, N, off + 5 * i + 3, (uint8_t)(t1 >> 4));
+        wp_sign_store_sig(sig_soa, inst, N, off + 5 * i + 4, (uint8_t)(t1 >> 12));
+    }
+#endif
+}
+
+/*
+ * Read coefficient j of polynomial k back out of the packed w1 buffer;
+ * the inverse of wp_sign_pack_w1_poly_from_tmp() for the same
+ * PARAM_GAMMA2 (6-, 4- or 3-bit fields).  Returns 0 for any other
+ * PARAM_GAMMA2.
+ */
+static __device__ __forceinline__ int32_t wp_sign_get_w1_hi(
+    const uint8_t *packed, int k, int j)
+{
+    const uint8_t *r = packed + (size_t)k * POLYW1_PACKEDBYTES;
+#if PARAM_GAMMA2 == (PARAM_Q - 1) / 88
+    int g = j >> 2;
+    int p = j & 3;
+    uint8_t b0 = r[3 * g + 0];
+    uint8_t b1 = r[3 * g + 1];
+    uint8_t b2 = r[3 * g + 2];
+    if (p == 0) return (int32_t)(b0 & 0x3fu);
+    if (p == 1) return (int32_t)(((b0 >> 6) | ((b1 & 0x0fu) << 2)) & 0x3fu);
+    if (p == 2) return (int32_t)(((b1 >> 4) | ((b2 & 0x03u) << 4)) & 0x3fu);
+    return (int32_t)((b2 >> 2) & 0x3fu);
+#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 32
+    uint8_t b = r[j >> 1];
+    return (int32_t)((j & 1) ? (b >> 4) : (b & 0x0fu));
+#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 12
+    int g = j >> 3;
+    int p = j & 7;
+    uint8_t b0 = r[3 * g + 0];
+    uint8_t b1 = r[3 * g + 1];
+    uint8_t b2 = r[3 * g + 2];
+    if (p == 0) return (int32_t)(b0 & 0x07u);
+    if (p == 1) return (int32_t)((b0 >> 3) & 0x07u);
+    if (p == 2) return (int32_t)(((b0 >> 6) | ((b1 & 0x01u) << 2)) & 0x07u);
+    if (p == 3) return (int32_t)((b1 >> 1) & 0x07u);
+    if (p == 4) return (int32_t)((b1 >> 4) & 0x07u);
+    if (p == 5) return (int32_t)(((b1 >> 7) | ((b2 & 0x03u) << 1)) & 0x07u);
+    if (p == 6) return (int32_t)((b2 >> 2) & 0x07u);
+    return (int32_t)((b2 >> 5) & 0x07u);
+#else
+    return 0;
+#endif
+}
+
+/*
+ * Bit-pack the high parts in `hi` (one polynomial) into `r`; lanes stride
+ * over output groups.  Ends with a __syncwarp() so the packed bytes are
+ * visible to the whole warp.
+ */
+static __device__ __noinline__ void wp_sign_pack_w1_poly_from_tmp(
+    uint8_t *r, const coeff_t *hi, int lane)
+{
+#if PARAM_GAMMA2 == (PARAM_Q - 1) / 88
+    for (int i = lane; i < PARAM_N / 4; i += WP_SIGN_WARP_SIZE) {
+        uint32_t a0 = (uint32_t)hi[4 * i + 0];
+        uint32_t a1 = (uint32_t)hi[4 * i + 1];
+        uint32_t a2 = (uint32_t)hi[4 * i + 2];
+        uint32_t a3 = (uint32_t)hi[4 * i + 3];
+        r[3 * i + 0] = (uint8_t)(a0 | (a1 << 6));
+        r[3 * i + 1] = (uint8_t)((a1 >> 2) | (a2 << 4));
+        r[3 * i + 2] = (uint8_t)((a2 >> 4) | (a3 << 2));
+    }
+#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 32
+    for (int i = lane; i < PARAM_N / 2; i += WP_SIGN_WARP_SIZE) {
+        uint32_t a0 = (uint32_t)hi[2 * i + 0];
+        uint32_t a1 = (uint32_t)hi[2 * i + 1];
+        r[i] = (uint8_t)(a0 | (a1 << 4));
+    }
+#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 12
+    for (int i = lane; i < PARAM_N / 8; i += WP_SIGN_WARP_SIZE) {
+        uint32_t a0 = (uint32_t)hi[8 * i + 0];
+        uint32_t a1 = (uint32_t)hi[8 * i + 1];
+        uint32_t a2 = (uint32_t)hi[8 * i + 2];
+        uint32_t a3 = (uint32_t)hi[8 * i + 3];
+        uint32_t a4 = (uint32_t)hi[8 * i + 4];
+        uint32_t a5 = (uint32_t)hi[8 * i + 5];
+        uint32_t a6 = (uint32_t)hi[8 * i + 6];
+        uint32_t a7 = (uint32_t)hi[8 * i + 7];
+        r[3 * i + 0] = (uint8_t)(a0 | (a1 << 3) | (a2 << 6));
+        r[3 * i + 1] = (uint8_t)((a2 >> 2) | (a3 << 1) |
+                                 (a4 << 4) | (a5 << 7));
+        r[3 * i + 2] = (uint8_t)((a5 >> 1) | (a6 << 2) | (a7 << 5));
+    }
+#endif
+    __syncwarp();
+}
+
+/*
+ * Derive mu and the y-sampling seed from scratch (lane 0 only, followed
+ * by a __syncwarp()).
+ * ML-DSA: mu = SHAKE256(tr || pre || msg);
+ *         seed = SHAKE256(key || [rnd] || mu), rnd only if RNDBYTES > 0.
+ * Otherwise: mu = SHAKE256(tr || msg) and seed = key || mu; `pre` and
+ * `rnd` are not used.
+ */
+static __device__ __noinline__ void wp_sign_prepare_uncached(
+    wp_sign_smem_t *s,
+    const uint8_t *msg, size_t mlen,
+    const uint8_t *pre, size_t prelen,
+    const uint8_t *rnd,
+    const precomp_t *pc,
+    int lane)
+{
+    if (lane == 0) {
+        keccak_state state;
+#if ALGORITHM == ALGO_MLDSA
+        shake256_init(&state);
+        shake256_absorb(&state, pc->tr, TRBYTES);
+        shake256_absorb(&state, pre, prelen);
+        shake256_absorb(&state, msg, mlen);
+        shake256_finalize(&state);
+        shake256_squeeze(s->mu, CRHBYTES, &state);
+
+        shake256_init(&state);
+        shake256_absorb(&state, pc->key, SEEDBYTES);
+#if RNDBYTES > 0
+        shake256_absorb(&state, rnd, RNDBYTES);
+#endif
+        shake256_absorb(&state, s->mu, CRHBYTES);
+        shake256_finalize(&state);
+        shake256_squeeze(s->seed, CRHBYTES, &state);
+#else
+        shake256_init(&state);
+        shake256_absorb(&state, pc->tr, TRBYTES);
+        shake256_absorb(&state, msg, mlen);
+        shake256_finalize(&state);
+        shake256_squeeze(s->mu, CRHBYTES, &state);
+
+        for (int i = 0; i < SEEDBYTES; ++i) s->seed[i] = pc->key[i];
+        for (int i = 0; i < CRHBYTES; ++i) s->seed[SEEDBYTES + i] = s->mu[i];
+#endif
+    }
+    __syncwarp();
+}
+
+/* Copy precomputed mu and seed out of a wp_sign_cache_t blob (lane 0
+ * only, followed by a __syncwarp()). */
+static __device__ __noinline__ void wp_sign_prepare_cached(
+    wp_sign_smem_t *s, const uint8_t *cache_raw, int lane)
+{
+    if (lane == 0) {
+        const wp_sign_cache_t *cache = (const wp_sign_cache_t *)cache_raw;
+        for (int i = 0; i < CRHBYTES; ++i) s->mu[i] = cache->mu[i];
+#if ALGORITHM == ALGO_MLDSA
+        for (int i = 0; i < CRHBYTES; ++i) s->seed[i] = cache->rhoprime[i];
+#else
+        for (int i = 0; i < SEEDBYTES + CRHBYTES; ++i) s->seed[i] = cache->key_mu[i];
+#endif
+    }
+    __syncwarp();
+}
+
+/*
+ * Sample y, compute w = INVNTT(A * NTT(y)) and pack its high parts.
+ *
+ * On return: s->y holds y in the coefficient domain, s->packed_w1 the
+ * packed high parts of w, and s->w the low parts (ML-DSA) or the
+ * normalised w itself (otherwise).  s->tmp and s->work are clobbered.
+ */
+static __device__ __noinline__ void wp_sign_matrix_y(
+    wp_sign_smem_t *s, const precomp_t *pc, uint16_t nonce_base, int lane)
+{
+    /* clear the accumulator w */
+    for (int k = 0; k < PARAM_K; ++k)
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE)
+            s->w[(size_t)k * PARAM_N + j] = 0;
+    __syncwarp();
+
+    /* per column l: sample y_l, NTT a copy, accumulate A[k][l] * y_l */
+    for (int l = 0; l < PARAM_L; ++l) {
+        coeff_t *yl = s->y + (size_t)l * PARAM_N;
+        wp_sign_sample_y_poly(yl, s->seed, GAMMA1_NONCE(nonce_base, l), lane, s->work);
+
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE)
+            s->tmp[j] = yl[j];
+        __syncwarp();
+        ntt_warp_par(s->tmp, lane);
+
+        for (int k = 0; k < PARAM_K; ++k) {
+            coeff_t *wk = s->w + (size_t)k * PARAM_N;
+            const coeff_t *akl = pc->mat[k].vec[l].coeffs;
+            for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+                coeff_t prod = montgomery_reduce((coeff2_t)akl[j] * s->tmp[j]);
+                wk[j] += prod;
+            }
+        }
+        __syncwarp();
+    }
+
+    /* per row k: reduce, inverse NTT, normalise, decompose, pack w1 */
+    for (int k = 0; k < PARAM_K; ++k) {
+        coeff_t *wk = s->w + (size_t)k * PARAM_N;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+#if ALGORITHM == ALGO_MLDSA
+            wk[j] = reduce32(wk[j]);
+#else
+            wk[j] = barrat_reduce(wk[j]);
+#endif
+        }
+        __syncwarp();
+        invntt_warp_par(wk, lane);
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+#if ALGORITHM == ALGO_MLDSA
+            int32_t a = caddq(reduce32(wk[j]));
+            int32_t lo;
+            int32_t hi = decompose(&lo, a);
+            wk[j] = lo;
+            s->tmp[j] = hi;
+#else
+            int32_t a = freeze2q(wk[j]);
+            int32_t lo;
+            int32_t hi = decompose(&lo, a);
+            wk[j] = a;
+            s->tmp[j] = hi;
+#endif
+        }
+        __syncwarp();
+        wp_sign_pack_w1_poly_from_tmp(
+            s->packed_w1 + (size_t)k * POLYW1_PACKEDBYTES, s->tmp, lane);
+    }
+}
+
+/*
+ * Derive the challenge polynomial into s->cp, write its encoding into the
+ * signature, then transform s->cp to the NTT domain.
+ * ML-DSA: lane 0 hashes mu || packed_w1 to CTILDEBYTES, stores them at the
+ * start of the signature and expands them with poly_challenge().
+ * Otherwise: lane 0 calls poly_challenge() on (mu, packed_w1) and stores a
+ * PARAM_N-bit non-zero bitmap followed by 8 bytes of sign bits after the
+ * z and hint sections.
+ */
+static __device__ __noinline__ void wp_sign_make_challenge(
+    wp_sign_smem_t *s, uint8_t *sig_soa, int inst, int N, int lane)
+{
+    if (lane == 0) {
+#if ALGORITHM == ALGO_MLDSA
+        keccak_state state;
+        shake256_init(&state);
+        shake256_absorb(&state, s->mu, CRHBYTES);
+        shake256_absorb(&state, s->packed_w1, PARAM_K * POLYW1_PACKEDBYTES);
+        shake256_finalize(&state);
+        shake256_squeeze(s->work, CTILDEBYTES, &state);
+        for (unsigned int i = 0; i < CTILDEBYTES; ++i)
+            wp_sign_store_sig(sig_soa, inst, N, i, s->work[i]);
+        poly_challenge((poly *)s->cp, s->work);
+#else
+        poly_challenge((poly *)s->cp, s->mu, s->packed_w1,
+                       PARAM_K * POLYW1_PACKEDBYTES);
+        unsigned int offset = PARAM_L * POLYZ_PACKEDBYTES + PARAM_OMEGA + PARAM_K;
+        uint64_t signs = 0;
+        uint64_t mask = 1;
+        for (unsigned int i = 0; i < PARAM_N / 8; ++i) {
+            uint8_t b = 0;
+            for (unsigned int j = 0; j < 8; ++j) {
+                coeff_t c = s->cp[8 * i + j];
+                if (c != 0) {
+                    b |= (uint8_t)(1u << j);
+                    /* one sign bit per non-zero coefficient; set when the
+                     * coefficient equals PARAM_Q - 1 */
+                    if (c == (PARAM_Q - 1)) signs |= mask;
+                    mask <<= 1;
+                }
+            }
+            wp_sign_store_sig(sig_soa, inst, N, offset + i, b);
+        }
+        offset += PARAM_N / 8;
+        for (unsigned int i = 0; i < 8; ++i)
+            wp_sign_store_sig(sig_soa, inst, N, offset + i, (uint8_t)(signs >> (8 * i)));
+#endif
+    }
+    __syncwarp();
+    ntt_warp_par(s->cp, lane);
+}
+
+/*
+ * For each row k compute cs2 = INVNTT(cp * s2[k]) in s->tmp, replace
+ * s->w[k] by the reduced difference w[k] - cs2 and test it.
+ * ML-DSA: norm bound PARAM_GAMMA2 - PARAM_BETA2 on the difference.
+ * Otherwise: the high part of the difference must equal the packed w1
+ * value and its low part must satisfy the same bound.
+ * Returns 1 on the first failing row (warp-uniform), 0 otherwise.
+ */
+static __device__ __noinline__ int wp_sign_check_s2(
+    wp_sign_smem_t *s, const precomp_t *pc, int lane)
+{
+    for (int k = 0; k < PARAM_K; ++k) {
+        const coeff_t *sk = pc->s2_ntt.vec[k].coeffs;
+        coeff_t *wk = s->w + (size_t)k * PARAM_N;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE)
+            s->tmp[j] = montgomery_reduce((coeff2_t)s->cp[j] * sk[j]);
+        __syncwarp();
+        invntt_warp_par(s->tmp, lane);
+
+        int bad = 0;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+#if ALGORITHM == ALGO_MLDSA
+            int32_t v = reduce32(wk[j] - s->tmp[j]);
+            wk[j] = v;
+            bad |= wp_sign_coeff_chknorm(v, PARAM_GAMMA2 - PARAM_BETA2);
+#else
+            int32_t v = freeze4q(wk[j] - s->tmp[j]);
+            int32_t lo;
+            int32_t hi = decompose(&lo, v);
+            lo = freeze2q(lo);
+            wk[j] = v;
+            bad |= (hi != wp_sign_get_w1_hi(s->packed_w1, k, j));
+            bad |= wp_sign_coeff_chknorm(lo, PARAM_GAMMA2 - PARAM_BETA2);
+#endif
+        }
+        if (wp_sign_any(bad)) return 1;
+        __syncwarp();
+    }
+    return 0;
+}
+
+/*
+ * For each column l compute z = INVNTT(cp * s1[l]) + y[l], check it
+ * against PARAM_GAMMA1 - PARAM_BETA1 and, if it passes, pack it straight
+ * into the signature.  Returns 1 on the first failing column (columns
+ * packed before the failure stay written), 0 otherwise.
+ */
+static __device__ __noinline__ int wp_sign_check_pack_z(
+    wp_sign_smem_t *s, const precomp_t *pc, uint8_t *sig_soa, int inst, int N, int lane)
+{
+    for (int l = 0; l < PARAM_L; ++l) {
+        const coeff_t *sl = pc->s1_ntt.vec[l].coeffs;
+        const coeff_t *yl = s->y + (size_t)l * PARAM_N;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE)
+            s->tmp[j] = montgomery_reduce((coeff2_t)s->cp[j] * sl[j]);
+        __syncwarp();
+        invntt_warp_par(s->tmp, lane);
+
+        int bad = 0;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+#if ALGORITHM == ALGO_MLDSA
+            int32_t z = reduce32(s->tmp[j] + yl[j]);
+#else
+            int32_t z = freeze4q(s->tmp[j] + yl[j]);
+#endif
+            s->tmp[j] = z;
+            bad |= wp_sign_coeff_chknorm(z, PARAM_GAMMA1 - PARAM_BETA1);
+        }
+        if (wp_sign_any(bad)) return 1;
+
+        /* z follows the challenge seed for ML-DSA, otherwise it comes first */
+#if ALGORITHM == ALGO_MLDSA
+        unsigned int off = CTILDEBYTES + (unsigned int)l * POLYZ_PACKEDBYTES;
+#else
+        unsigned int off = (unsigned int)l * POLYZ_PACKEDBYTES;
+#endif
+        wp_sign_pack_z_soa(sig_soa, inst, N, off, s->tmp, lane);
+        __syncwarp();
+    }
+    return 0;
+}
+
+/*
+ * For each row k compute ct0 = INVNTT(cp * t0[k]), check it against
+ * PARAM_GAMMA2 and add it onto s->w[k].
+ * For Aigis the hint section of the signature is also produced here:
+ * it is zeroed first, then lane 0 writes the hint positions and the
+ * running per-row count.
+ * Returns 1 on a norm failure, 2 when more than PARAM_OMEGA hints were
+ * needed (Aigis only), 0 otherwise.
+ */
+static __device__ __noinline__ int wp_sign_check_t0_accumulate(
+    wp_sign_smem_t *s, const precomp_t *pc,
+    uint8_t *sig_soa, int inst, int N, int lane)
+{
+#if ALGORITHM == ALGO_AIGIS
+    const unsigned int hint_off = PARAM_L * POLYZ_PACKEDBYTES;
+    for (unsigned int i = lane; i < PARAM_OMEGA + PARAM_K; i += WP_SIGN_WARP_SIZE)
+        wp_sign_store_sig(sig_soa, inst, N, hint_off + i, 0);
+    __syncwarp();
+
+    unsigned int hint_count = 0;
+    int hint_overflow = 0;
+#endif
+
+    for (int k = 0; k < PARAM_K; ++k) {
+        const coeff_t *tk = pc->t0_ntt.vec[k].coeffs;
+        coeff_t *wk = s->w + (size_t)k * PARAM_N;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE)
+            s->tmp[j] = montgomery_reduce((coeff2_t)s->cp[j] * tk[j]);
+        __syncwarp();
+        invntt_warp_par(s->tmp, lane);
+
+        int bad = 0;
+        for (int j = lane; j < PARAM_N; j += WP_SIGN_WARP_SIZE) {
+#if ALGORITHM == ALGO_MLDSA
+            int32_t ct0 = reduce32(s->tmp[j]);
+            bad |= wp_sign_coeff_chknorm(ct0, PARAM_GAMMA2);
+            wk[j] = wk[j] + ct0;
+#else
+            int32_t ct0 = freeze2q(s->tmp[j]);
+            bad |= wp_sign_coeff_chknorm(ct0, PARAM_GAMMA2);
+            wk[j] = freeze2q(wk[j] + ct0);
+#endif
+        }
+        if (wp_sign_any(bad)) return 1;
+        __syncwarp();
+#if ALGORITHM == ALGO_AIGIS
+        /* hints are emitted sequentially by lane 0 to keep them ordered */
+        if (lane == 0) {
+            for (unsigned int j = 0; j < PARAM_N; ++j) {
+                int32_t ct0 = freeze2q(s->tmp[j]);
+                int h = make_hint(wk[j], 2 * PARAM_Q - ct0);
+                if (h) {
+                    if (hint_count < PARAM_OMEGA)
+                        wp_sign_store_sig(sig_soa, inst, N,
+                                          hint_off + hint_count, (uint8_t)j);
+                    hint_count++;
+                }
+            }
+            if (hint_count <= PARAM_OMEGA)
+                wp_sign_store_sig(sig_soa, inst, N,
+                                  hint_off + PARAM_OMEGA + k,
+                                  (uint8_t)hint_count);
+        }
+        __syncwarp();
+#endif
+    }
+#if ALGORITHM == ALGO_AIGIS
+    /* only lane 0 knows the count; broadcast the verdict to the warp */
+    if (lane == 0)
+        hint_overflow = (hint_count > PARAM_OMEGA);
+    hint_overflow = __shfl_sync(0xffffffffull, hint_overflow, 0);
+    if (hint_overflow) return 2;
+#endif
+    return 0;
+}
+
+/*
+ * ML-DSA only: zero the hint section, then let lane 0 compute the hints
+ * from s->w and the packed w1 values and write positions plus running
+ * per-row counts.  Returns non-zero (broadcast from lane 0) when more
+ * than PARAM_OMEGA hints were needed.  For other algorithms this is a
+ * no-op returning 0.
+ */
+static __device__ __noinline__ int wp_sign_pack_hints(
+    wp_sign_smem_t *s, uint8_t *sig_soa, int inst, int N, int lane)
+{
+#if ALGORITHM == ALGO_MLDSA
+    const unsigned int hint_off = CTILDEBYTES + PARAM_L * POLYZ_PACKEDBYTES;
+    for (unsigned int i = lane; i < PARAM_OMEGA + PARAM_K; i += WP_SIGN_WARP_SIZE)
+        wp_sign_store_sig(sig_soa, inst, N, hint_off + i, 0);
+    __syncwarp();
+
+    unsigned int count = 0;
+    int overflow = 0;
+    if (lane == 0) {
+        for (unsigned int k = 0; k < PARAM_K; ++k) {
+            coeff_t *wk = s->w + (size_t)k * PARAM_N;
+            for (unsigned int j = 0; j < PARAM_N; ++j) {
+                int h = make_hint(wk[j], wp_sign_get_w1_hi(s->packed_w1, k, j));
+                if (h) {
+                    if (count < PARAM_OMEGA)
+                        wp_sign_store_sig(sig_soa, inst, N, hint_off + count, (uint8_t)j);
+                    count++;
+                }
+            }
+            if (count <= PARAM_OMEGA)
+                wp_sign_store_sig(sig_soa, inst, N,
+                                  hint_off + PARAM_OMEGA + k, (uint8_t)count);
+        }
+        overflow = (count > PARAM_OMEGA);
+    }
+    overflow = __shfl_sync(0xffffffffull, overflow, 0);
+    return overflow;
+#else
+    return 0;
+#endif
+}
+
+static __device__ __noinline__ int wp_sign_core(
+    uint8_t *sig_soa, size_t *siglen_arr,
+    const uint8_t *msg, size_t mlen,
+    const uint8_t *pre, size_t prelen,
+    const uint8_t *rnd,
+    const uint8_t *cache_raw,
+    const precomp_t *pc,
+    int *results, int N, int inst, int cached,
+    unsigned long long *stats,
+    wp_sign_smem_t *s,
+    int lane)
+{
+    if (cached)
+        wp_sign_prepare_cached(s, cache_raw, lane);
+    else
+        wp_sign_prepare_uncached(s, msg, 
if (h) { + if (hint_count < PARAM_OMEGA) + wp_sign_store_sig(sig_soa, inst, N, + hint_off + hint_count, (uint8_t)j); + hint_count++; + } + } + if (hint_count <= PARAM_OMEGA) + wp_sign_store_sig(sig_soa, inst, N, + hint_off + PARAM_OMEGA + k, + (uint8_t)hint_count); + } + __syncwarp(); +#endif + } +#if ALGORITHM == ALGO_AIGIS + if (lane == 0) + hint_overflow = (hint_count > PARAM_OMEGA); + hint_overflow = __shfl_sync(0xffffffffull, hint_overflow, 0); + if (hint_overflow) return 2; +#endif + return 0; +} + +static __device__ __noinline__ int wp_sign_pack_hints( + wp_sign_smem_t *s, uint8_t *sig_soa, int inst, int N, int lane) +{ +#if ALGORITHM == ALGO_MLDSA + const unsigned int hint_off = CTILDEBYTES + PARAM_L * POLYZ_PACKEDBYTES; + for (unsigned int i = lane; i < PARAM_OMEGA + PARAM_K; i += WP_SIGN_WARP_SIZE) + wp_sign_store_sig(sig_soa, inst, N, hint_off + i, 0); + __syncwarp(); + + unsigned int count = 0; + int overflow = 0; + if (lane == 0) { + for (unsigned int k = 0; k < PARAM_K; ++k) { + coeff_t *wk = s->w + (size_t)k * PARAM_N; + for (unsigned int j = 0; j < PARAM_N; ++j) { + int h = make_hint(wk[j], wp_sign_get_w1_hi(s->packed_w1, k, j)); + if (h) { + if (count < PARAM_OMEGA) + wp_sign_store_sig(sig_soa, inst, N, hint_off + count, (uint8_t)j); + count++; + } + } + if (count <= PARAM_OMEGA) + wp_sign_store_sig(sig_soa, inst, N, + hint_off + PARAM_OMEGA + k, (uint8_t)count); + } + overflow = (count > PARAM_OMEGA); + } + overflow = __shfl_sync(0xffffffffull, overflow, 0); + return overflow; +#else + return 0; +#endif +} + +static __device__ __noinline__ int wp_sign_core( + uint8_t *sig_soa, size_t *siglen_arr, + const uint8_t *msg, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd, + const uint8_t *cache_raw, + const precomp_t *pc, + int *results, int N, int inst, int cached, + unsigned long long *stats, + wp_sign_smem_t *s, + int lane) +{ + if (cached) + wp_sign_prepare_cached(s, cache_raw, lane); + else + wp_sign_prepare_uncached(s, msg, 
mlen, pre, prelen, rnd, pc, lane); + +#if BATCH_SIGN_NONCE_DIVERSIFY +#if ALGORITHM == ALGO_AIGIS + uint16_t nonce = (uint16_t)(((unsigned int)inst * PARAM_L) & 0xffffu); +#else + uint16_t nonce = (uint16_t)inst; +#endif +#else + uint16_t nonce = 0; +#endif + + for (;;) { + uint16_t nonce_base = nonce; +#if ALGORITHM == ALGO_AIGIS + nonce = (uint16_t)(nonce + PARAM_L); +#else + nonce = (uint16_t)(nonce + 1); +#endif + if (lane == 0 && stats) atomicAdd(&stats[WP_SIGN_STAT_ATTEMPTS], 1ull); + + wp_sign_matrix_y(s, pc, nonce_base, lane); + wp_sign_make_challenge(s, sig_soa, inst, N, lane); + + if (wp_sign_check_s2(s, pc, lane)) { + if (lane == 0 && stats) atomicAdd(&stats[WP_SIGN_STAT_REJ_S2], 1ull); + continue; + } + if (wp_sign_check_pack_z(s, pc, sig_soa, inst, N, lane)) { + if (lane == 0 && stats) atomicAdd(&stats[WP_SIGN_STAT_REJ_Z], 1ull); + continue; + } + int t0_status = wp_sign_check_t0_accumulate(s, pc, sig_soa, inst, N, lane); + if (t0_status) { + if (lane == 0 && stats) { + atomicAdd(&stats[(t0_status == 2) + ? 
WP_SIGN_STAT_REJ_HINT + : WP_SIGN_STAT_REJ_T0], 1ull); + } + continue; + } +#if ALGORITHM == ALGO_MLDSA + if (wp_sign_pack_hints(s, sig_soa, inst, N, lane)) { + if (lane == 0 && stats) atomicAdd(&stats[WP_SIGN_STAT_REJ_HINT], 1ull); + continue; + } +#endif + + /* Attempt accepted: lane 0 records the signature length and the success code. */ + if (lane == 0) { + siglen_arr[inst] = CRYPTO_BYTES; + results[inst] = 0; + if (stats) atomicAdd(&stats[WP_SIGN_STAT_OK], 1ull); + } + return 0; + } +} + +/* + * Warp-per-instance signing kernel, uncached path: forwards msg/pre/rnd to + * wp_sign_core with cached = 0 and no cache pointer. + * + * inst = base_idx + blockIdx.x * WP_SIGN_WARPS_BLOCK + warp_slot; a warp whose + * inst is >= N returns before any shared-memory setup. + */ +__global__ void __launch_bounds__(WP_SIGN_TPB, 1) +kernel_batch_sign_warp_precomp( + uint8_t *sig_soa, size_t *siglen_arr, + const uint8_t *msg, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd, + const precomp_t *pc, + int *results, int N, int base_idx, + unsigned long long *stats) +{ + extern __shared__ unsigned char smem[]; + int lane = threadIdx.x & (WP_SIGN_WARP_SIZE - 1); + /* Derived from the same constant as the lane mask instead of a hard-coded shift by 5. */ + int warp_slot = threadIdx.x / WP_SIGN_WARP_SIZE; + int inst = base_idx + (int)blockIdx.x * WP_SIGN_WARPS_BLOCK + warp_slot; + if (inst >= N) return; + + wp_sign_smem_t s; + wp_sign_smem_init(&s, smem, warp_slot); + wp_sign_core(sig_soa, siglen_arr, msg, mlen, pre, prelen, rnd, NULL, + pc, results, N, inst, 0, stats, &s, lane); +} + +/* + * Warp-per-instance signing kernel, cached path: forwards cache_raw to + * wp_sign_core with cached = 1 and no msg/pre/rnd inputs. + * Instance indexing is identical to kernel_batch_sign_warp_precomp. + */ +__global__ void __launch_bounds__(WP_SIGN_TPB, 1) +kernel_batch_sign_warp_precomp_cached( + uint8_t *sig_soa, size_t *siglen_arr, + const uint8_t *cache_raw, + const precomp_t *pc, + int *results, int N, int base_idx, + unsigned long long *stats) +{ + extern __shared__ unsigned char smem[]; + int lane = threadIdx.x & (WP_SIGN_WARP_SIZE - 1); + /* Derived from the same constant as the lane mask instead of a hard-coded shift by 5. */ + int warp_slot = threadIdx.x / WP_SIGN_WARP_SIZE; + int inst = base_idx + (int)blockIdx.x * WP_SIGN_WARPS_BLOCK + warp_slot; + if (inst >= N) return; + + wp_sign_smem_t s; + wp_sign_smem_init(&s, smem, warp_slot); + wp_sign_core(sig_soa, siglen_arr, NULL, 0, NULL, 0, NULL, cache_raw, + pc, results, N, inst, 1, stats, &s, lane); +} + +__global__ void kernel_wp_sign_sig_soa_to_aos( + uint8_t *sig_aos, const uint8_t *sig_soa, int N) +{ + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + size_t total = (size_t)N * 
CRYPTO_BYTES; + if (idx >= total) return; + int inst = (int)(idx / CRYPTO_BYTES); + int byte = (int)(idx - (size_t)inst * CRYPTO_BYTES); + sig_aos[idx] = sig_soa[(size_t)byte * (size_t)N + (size_t)inst]; +} + +static inline hipError_t batch_sign_warp_set_smem_attributes(void) { + size_t smem = batch_sign_warp_smem_bytes(); + hipError_t e = hipFuncSetAttribute(reinterpret_cast(kernel_batch_sign_warp_precomp), + hipFuncAttributeMaxDynamicSharedMemorySize, + (int)smem); + if (e != hipSuccess) return e; + return hipFuncSetAttribute(reinterpret_cast(kernel_batch_sign_warp_precomp_cached), + hipFuncAttributeMaxDynamicSharedMemorySize, + (int)smem); +} + +#endif /* BATCH_SIGN_WARP_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_verify.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_verify.cuh new file mode 100644 index 000000000..10d084b31 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/batch_verify.cuh @@ -0,0 +1,793 @@ +#include "hip/hip_runtime.h" +/* + * batch_verify.cuh — 分解式批量验证 pipeline + * + * 核心优化: + * 1. 矩阵 A 和 t1_hat 只计算一次, 所有 instance 共享 (precompute) + * 2. NTT(z) 使用 shared-memory 批量 kernel + * 3. 矩阵向量乘使用 2D grid (batch × K) + * 4. SoA 内存布局: z[poly_idx][inst][coeff] + * 5. 
Per-coefficient operations use batch kernels with 256 threads/block + * + * Pipeline: + * [0] precompute: pk → A (NTT domain), t1 << D (NTT domain), tr = H(pk) + * [1] unpack signatures → z, h, challenge data + * [2] norm check on z + * [3] NTT(z) + * [4] w = A · z_hat + * [5] build the challenge polynomial cp + * [6] NTT(cp) + * [7] w -= cp · t1_hat + * [8] reduce + INVNTT + normalize + * [9] w1 = use_hint(w, h) + * [10] mu = H(tr || [pre ||] m) + * [11] compare H(mu || pack(w1)) with the original challenge + * + * NOTE(review): the tail of this comment, the include guard and the targets + * of the system #include lines were lost in extraction. Steps [1]-[11] are + * restated from batch_verify_pipeline_core, the guard name comes from the + * closing "#endif", and the three system headers are a best guess + * (size_t, fixed-width integers, memset) — confirm against the upstream file. + */ +#ifndef BATCH_VERIFY_CUH +#define BATCH_VERIFY_CUH + +#include <stddef.h> +#include <stdint.h> +#include <string.h> +#include "params.h" +#include "reduce.cuh" +#include "ntt.cuh" +#include "rounding.cuh" +#include "fips202.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "packing.cuh" +#include "sign.cuh" +#include "symmetric.cuh" + +/* ================================================================ + * Buffer structure + * ================================================================ */ +struct BatchVerifyBuffers { + /* Precomputed shared material (derived from the public key, computed once) */ + coeff_t *d_mat; /* K * L * N — matrix A (NTT domain) */ + coeff_t *d_t1_hat; /* K * N — t1 << D (NTT domain) */ + unsigned char *d_tr; /* TRBYTES — H(pk) */ + + /* Per-batch working buffers — SoA: [poly_idx][batch][coeff] */ + coeff_t *d_z; /* L * B * N */ + coeff_t *d_h; /* K * B * N */ + coeff_t *d_cp; /* B * N — challenge polynomial */ + coeff_t *d_w; /* K * B * N */ + coeff_t *d_w1; /* K * B * N */ + unsigned char *d_mu; /* B * CRHBYTES */ + int *d_results; /* B */ + unsigned char *d_raw_sigs;/* B * CRYPTO_BYTES */ + + /* Raw challenge data: + * Aigis: B * N * sizeof(coeff_t) — full challenge polynomial + * ML-DSA: B * CTILDEBYTES — challenge seed */ + unsigned char *d_cbuf; + + int max_batch; +}; + +/* ================================================================ + * Precompute kernel — derives matrix A and t1_hat from the public key + * + * NOTE(review): there is no thread-index guard, so every launched thread + * repeats the same writes; presumably launched with a single thread — + * confirm at the call site. + * ================================================================ */ +__global__ void batch_verify_precompute_kernel( + coeff_t * __restrict__ d_mat, + coeff_t * __restrict__ d_t1_hat, + unsigned char * __restrict__ d_tr, + const unsigned char * __restrict__ pk) +{ + unsigned char rho[SEEDBYTES]; + polyveck t1; + polyvecl mat[PARAM_K]; + + unpack_pk(rho, &t1, pk); + polyvec_matrix_expand(mat, rho); + + /* Store matrix A into the flat array */ + for (int i = 0; i < PARAM_K; i++) + for (int j = 0; j < PARAM_L; j++) + for (int c = 0; c < PARAM_N; c++) + d_mat[(i * PARAM_L + j) * PARAM_N + c] + = mat[i].vec[j].coeffs[c]; + + /* t1 << D, then NTT */ + for (int i = 0; i < PARAM_K; 
i++) { + for (int c = 0; c < PARAM_N; c++) + d_t1_hat[i * PARAM_N + c] = t1.vec[i].coeffs[c] << PARAM_D; + ntt(d_t1_hat + i * PARAM_N); + } + + /* tr = H(pk) */ + shake256(d_tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); +} + +/* ================================================================ + * 解包签名 kernel — SoA 输出 + * + * 输出: d_z[poly][inst][coeff], d_h[poly][inst][coeff], d_cbuf + * ================================================================ */ + +#if ALGORITHM == ALGO_MLDSA + +__global__ void __launch_bounds__(64) +batch_verify_unpack_kernel( + coeff_t * __restrict__ d_z, + coeff_t * __restrict__ d_h, + unsigned char * __restrict__ d_cbuf, + int * __restrict__ d_results, + const unsigned char * __restrict__ d_raw_sigs, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_count) return; + + const uint8_t *sig = d_raw_sigs + (size_t)idx * CRYPTO_BYTES; + uint8_t *c_out = d_cbuf + (size_t)idx * CTILDEBYTES; + + /* 解包 ctilde */ + for (int i = 0; i < CTILDEBYTES; i++) + c_out[i] = sig[i]; + const uint8_t *sp = sig + CTILDEBYTES; + + /* 解包 z → SoA */ +#if PARAM_GAMMA1 == (1 << 17) + for (int l = 0; l < PARAM_L; l++) { + const uint8_t *a = sp + l * POLYZ_PACKEDBYTES; + coeff_t *r = d_z + (size_t)l * batch_count * PARAM_N + (size_t)idx * PARAM_N; + for (unsigned int i = 0; i < PARAM_N/4; i++) { + r[4*i+0] = a[9*i+0]; + r[4*i+0] |= (uint32_t)a[9*i+1] << 8; + r[4*i+0] |= (uint32_t)a[9*i+2] << 16; + r[4*i+0] &= 0x3FFFF; + r[4*i+1] = a[9*i+2] >> 2; + r[4*i+1] |= (uint32_t)a[9*i+3] << 6; + r[4*i+1] |= (uint32_t)a[9*i+4] << 14; + r[4*i+1] &= 0x3FFFF; + r[4*i+2] = a[9*i+4] >> 4; + r[4*i+2] |= (uint32_t)a[9*i+5] << 4; + r[4*i+2] |= (uint32_t)a[9*i+6] << 12; + r[4*i+2] &= 0x3FFFF; + r[4*i+3] = a[9*i+6] >> 6; + r[4*i+3] |= (uint32_t)a[9*i+7] << 2; + r[4*i+3] |= (uint32_t)a[9*i+8] << 10; + r[4*i+3] &= 0x3FFFF; + r[4*i+0] = PARAM_GAMMA1 - r[4*i+0]; + r[4*i+1] = PARAM_GAMMA1 - r[4*i+1]; + r[4*i+2] = PARAM_GAMMA1 - r[4*i+2]; + r[4*i+3] = 
PARAM_GAMMA1 - r[4*i+3]; + } + } +#elif PARAM_GAMMA1 == (1 << 19) + for (int l = 0; l < PARAM_L; l++) { + const uint8_t *a = sp + l * POLYZ_PACKEDBYTES; + coeff_t *r = d_z + (size_t)l * batch_count * PARAM_N + (size_t)idx * PARAM_N; + for (unsigned int i = 0; i < PARAM_N/2; i++) { + r[2*i+0] = a[5*i+0]; + r[2*i+0] |= (uint32_t)a[5*i+1] << 8; + r[2*i+0] |= (uint32_t)a[5*i+2] << 16; + r[2*i+0] &= 0xFFFFF; + r[2*i+1] = a[5*i+2] >> 4; + r[2*i+1] |= (uint32_t)a[5*i+3] << 4; + r[2*i+1] |= (uint32_t)a[5*i+4] << 12; + r[2*i+0] = PARAM_GAMMA1 - r[2*i+0]; + r[2*i+1] = PARAM_GAMMA1 - r[2*i+1]; + } + } +#endif + sp += PARAM_L * POLYZ_PACKEDBYTES; + + /* 解包 hint → SoA */ + unsigned int k = 0; + int valid = 1; + for (int i = 0; i < PARAM_K; i++) { + if (sp[PARAM_OMEGA + i] < k || sp[PARAM_OMEGA + i] > PARAM_OMEGA) { valid = 0; break; } + for (unsigned int j = k; j < sp[PARAM_OMEGA + i]; j++) { + if (j > k && sp[j] <= sp[j-1]) { valid = 0; break; } + d_h[(size_t)i * batch_count * PARAM_N + (size_t)idx * PARAM_N + sp[j]] = 1; + } + if (!valid) break; + k = sp[PARAM_OMEGA + i]; + } + if (valid) { + for (unsigned int j = k; j < PARAM_OMEGA; j++) + if (sp[j]) { valid = 0; break; } + } + if (!valid) d_results[idx] = -1; +} + +#elif ALGORITHM == ALGO_AIGIS + +__global__ void __launch_bounds__(64) +batch_verify_unpack_kernel( + coeff_t * __restrict__ d_z, + coeff_t * __restrict__ d_h, + unsigned char * __restrict__ d_cbuf, + int * __restrict__ d_results, + const unsigned char * __restrict__ d_raw_sigs, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + const unsigned char *sig = d_raw_sigs + (size_t)inst * CRYPTO_BYTES; + + /* 解包 z → SoA (Aigis: 18-bit GAMMA1=2^17) */ + for (unsigned int i = 0; i < PARAM_L; i++) { + const unsigned char *src = sig + i * POLYZ_PACKEDBYTES; + coeff_t *dst = d_z + (size_t)i * batch_count * PARAM_N + (size_t)inst * PARAM_N; + for (unsigned int ii = 0; ii < PARAM_N / 4; ++ii) { + const unsigned 
char *a = src + 9 * ii; + int32_t r0, r1, r2, r3; + r0 = a[0]; r0 |= (uint32_t)a[1] << 8; r0 |= (uint32_t)(a[2] & 0x03) << 16; + r0 = PARAM_GAMMA1 - 1 - r0; r0 += ((int32_t)r0 >> 31) & PARAM_Q; + r1 = a[2] >> 2; r1 |= (uint32_t)a[3] << 6; r1 |= (uint32_t)(a[4] & 0x0F) << 14; + r1 = PARAM_GAMMA1 - 1 - r1; r1 += ((int32_t)r1 >> 31) & PARAM_Q; + r2 = a[4] >> 4; r2 |= (uint32_t)a[5] << 4; r2 |= (uint32_t)(a[6] & 0x3F) << 12; + r2 = PARAM_GAMMA1 - 1 - r2; r2 += ((int32_t)r2 >> 31) & PARAM_Q; + r3 = a[6] >> 6; r3 |= (uint32_t)a[7] << 2; r3 |= (uint32_t)a[8] << 10; + r3 = PARAM_GAMMA1 - 1 - r3; r3 += ((int32_t)r3 >> 31) & PARAM_Q; + dst[4*ii+0] = r0; dst[4*ii+1] = r1; dst[4*ii+2] = r2; dst[4*ii+3] = r3; + } + } + sig += PARAM_L * POLYZ_PACKEDBYTES; + + /* 解包 hint → SoA */ + unsigned int k = 0; + int fail = 0; + for (unsigned int i = 0; i < PARAM_K; i++) { + coeff_t *hdst = d_h + (size_t)i * batch_count * PARAM_N + (size_t)inst * PARAM_N; + if (sig[PARAM_OMEGA + i] < k || sig[PARAM_OMEGA + i] > PARAM_OMEGA) { fail = 1; break; } + for (unsigned int j = k; j < sig[PARAM_OMEGA + i]; j++) { + if (j > k && sig[j] <= sig[j - 1]) { fail = 1; break; } + hdst[sig[j]] = 1; + } + if (fail) break; + k = sig[PARAM_OMEGA + i]; + } + if (!fail) { + for (unsigned int j = k; j < PARAM_OMEGA; j++) + if (sig[j]) { fail = 1; break; } + } + sig += PARAM_OMEGA + PARAM_K; + + /* 解包挑战多项式 (Aigis: 位图 + 符号) */ + coeff_t *cdst = (coeff_t *)(d_cbuf + (size_t)inst * PARAM_N * sizeof(coeff_t)); + for (unsigned int i = 0; i < PARAM_N; i++) cdst[i] = 0; + if (!fail) { + uint64_t signs = 0; + for (unsigned int i = 0; i < 8; i++) + signs |= (uint64_t)sig[PARAM_N / 8 + i] << (8 * i); + uint64_t mask = 1; + for (unsigned int i = 0; i < PARAM_N / 8; i++) { + for (unsigned int j = 0; j < 8; j++) { + if ((sig[i] >> j) & 0x01) { + cdst[8 * i + j] = (signs & mask) ? PARAM_Q - 1 : 1; + mask <<= 1; + } + } + } + } + + d_results[inst] = fail ? 
-1 : 0; +} + +#endif /* ALGORITHM unpack kernel */ + +/* ================================================================ + * 范数检查 z kernel — 共用 + * ================================================================ */ +__global__ void batch_verify_chknorm_z_kernel( + int * __restrict__ d_results, + const coeff_t * __restrict__ d_z, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + if (d_results[inst] != 0) return; + +#if ALGORITHM == ALGO_AIGIS + const int32_t bound = PARAM_GAMMA1 - PARAM_BETA1; + for (unsigned int i = 0; i < PARAM_L; i++) { + const coeff_t *zp = d_z + (size_t)i * batch_count * PARAM_N + (size_t)inst * PARAM_N; + for (unsigned int j = 0; j < PARAM_N; j++) { + int32_t t = (PARAM_Q - 1) / 2 - (int32_t)zp[j]; + t ^= (t >> 31); + t = (PARAM_Q - 1) / 2 - t; + if (t >= bound) { d_results[inst] = -1; return; } + } + } +#else /* ALGO_MLDSA */ + const int32_t bound = PARAM_GAMMA1 - PARAM_BETA1; + for (int l = 0; l < PARAM_L; l++) { + const coeff_t *zp = d_z + (size_t)l * batch_count * PARAM_N + (size_t)inst * PARAM_N; + for (int j = 0; j < PARAM_N; j++) { + int32_t t = zp[j]; + int32_t mask = t >> 31; + t = t - (mask & 2*t); + if (t >= bound) { d_results[inst] = -1; return; } + } + } +#endif +} + +/* ================================================================ + * 矩阵向量乘 kernel — 共用 (共享矩阵 A) + * + * w[row] = Σ_{col} A[row][col] · z_hat[col] + * 注意: A 不含 batch 维度, z 是 SoA 布局 + * grid: (batch_count, PARAM_K) + * block: PARAM_N threads + * ================================================================ */ +__global__ void batch_verify_matvec_kernel( + coeff_t * __restrict__ d_w, + const coeff_t * __restrict__ d_mat, + const coeff_t * __restrict__ d_z_ntt, + int batch_count) +{ + int inst = blockIdx.x; + int row = blockIdx.y; + if (inst >= batch_count) return; + int tid = threadIdx.x; + + coeff_t acc = 0; + for (int col = 0; col < PARAM_L; col++) { + /* A 共享: d_mat[(row * L + col) * N + tid] */ + 
coeff_t a = d_mat[(row * PARAM_L + col) * PARAM_N + tid]; + /* z SoA: d_z_ntt[col * B * N + inst * N + tid] */ + coeff_t b = d_z_ntt[(size_t)col * batch_count * PARAM_N + + (size_t)inst * PARAM_N + tid]; + acc += (coeff_t)montgomery_reduce((coeff2_t)a * b); + } + /* w SoA output: [row * B * N + inst * N + tid] */ + d_w[(size_t)row * batch_count * PARAM_N + (size_t)inst * PARAM_N + tid] = coeff_reduce(acc); +} + +/* ================================================================ + * 融合核: w -= cp · t1_hat (2D grid) + * ================================================================ */ +__global__ void batch_verify_sub_cp_t1_kernel( + coeff_t * __restrict__ d_w, + const coeff_t * __restrict__ d_cp, + const coeff_t * __restrict__ d_t1_hat, + int batch_count) +{ + int inst = blockIdx.x; + int k = blockIdx.y; + if (inst >= batch_count) return; + int tid = threadIdx.x; + + coeff_t c = d_cp[(size_t)inst * PARAM_N + tid]; + coeff_t t = d_t1_hat[k * PARAM_N + tid]; + coeff_t prod = (coeff_t)montgomery_reduce((coeff2_t)c * t); + size_t idx = (size_t)k * batch_count * PARAM_N + (size_t)inst * PARAM_N + tid; + d_w[idx] = coeff_sub(d_w[idx], prod); +} + +/* ================================================================ + * 挑战多项式生成 kernel + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +__global__ void __launch_bounds__(64) +batch_verify_challenge_kernel( + coeff_t * __restrict__ d_cp, + const unsigned char * __restrict__ d_cbuf, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_count) return; + + const uint8_t *c_seed = d_cbuf + (size_t)idx * CTILDEBYTES; + coeff_t *cp = d_cp + (size_t)idx * PARAM_N; + + unsigned int i, b, pos; + uint64_t signs; + uint8_t buf[SHAKE256_RATE]; + keccak_state state; + + shake256_init(&state); + shake256_absorb(&state, c_seed, CTILDEBYTES); + shake256_finalize(&state); + shake256_squeezeblocks(buf, 1, &state); + + signs = 0; + for (i = 0; i < 8; i++) 
+ signs |= (uint64_t)buf[i] << (8*i); + pos = 8; + + for (i = 0; i < PARAM_N; i++) cp[i] = 0; + for (i = PARAM_N - PARAM_TAU; i < PARAM_N; i++) { + do { + if (pos >= SHAKE256_RATE) { + shake256_squeezeblocks(buf, 1, &state); + pos = 0; + } + b = buf[pos++]; + } while (b > i); + cp[i] = cp[b]; + cp[b] = 1 - 2*(signs & 1); + signs >>= 1; + } +} + +#endif /* ALGO_MLDSA challenge kernel */ + +/* ================================================================ + * mu 计算 kernel + * ================================================================ */ +__global__ void batch_verify_compute_mu_kernel( + unsigned char * __restrict__ d_mu, + const unsigned char * __restrict__ d_tr, + const unsigned char * __restrict__ d_msgs, + size_t mlen, + const unsigned char * __restrict__ d_pre, + size_t prelen, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + +#if ALGORITHM == ALGO_AIGIS + /* Aigis: mu = shake256(tr || m_i) */ + keccak_state state; + shake256_init(&state); + shake256_absorb(&state, d_tr, CRHBYTES); + shake256_absorb(&state, d_msgs + (size_t)inst * mlen, mlen); + shake256_finalize(&state); + shake256_squeeze(d_mu + (size_t)inst * CRHBYTES, CRHBYTES, &state); +#else + /* ML-DSA: mu = shake256(tr || pre || m) */ + keccak_state state; + shake256_init(&state); + shake256_absorb(&state, d_tr, TRBYTES); + shake256_absorb(&state, d_pre, prelen); + shake256_absorb(&state, d_msgs + (size_t)inst * mlen, mlen); + shake256_finalize(&state); + shake256_squeeze(d_mu + (size_t)inst * CRHBYTES, CRHBYTES, &state); +#endif +} + +/* ================================================================ + * 最终比较 kernel + * ================================================================ */ +#if ALGORITHM == ALGO_AIGIS + +__global__ void __launch_bounds__(32) +batch_verify_compare_kernel( + int * __restrict__ d_results, + const unsigned char * __restrict__ d_mu, + const coeff_t * __restrict__ d_w1, + const unsigned char * __restrict__ 
d_cbuf, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + if (d_results[inst] != 0) return; + + /* 打包 w1 并构造 hash 输入 */ + unsigned char inbuf[CRHBYTES + PARAM_K * POLYW1_PACKEDBYTES]; + const unsigned char *my_mu = d_mu + (size_t)inst * CRHBYTES; + for (unsigned int i = 0; i < CRHBYTES; i++) inbuf[i] = my_mu[i]; + + for (unsigned int ki = 0; ki < PARAM_K; ki++) { + const coeff_t *w1_poly = d_w1 + (size_t)ki * batch_count * PARAM_N + + (size_t)inst * PARAM_N; + unsigned char *r = inbuf + CRHBYTES + ki * POLYW1_PACKEDBYTES; + /* Aigis w1 packing: 3 bits per coeff, 8 coeffs per 3 bytes */ + for (unsigned int i = 0; i < PARAM_N / 8; i++) { + r[3*i+0] = w1_poly[8*i+0] | (w1_poly[8*i+1] << 3) | (w1_poly[8*i+2] << 6); + r[3*i+1] = (w1_poly[8*i+2] >> 2) | (w1_poly[8*i+3] << 1) + | (w1_poly[8*i+4] << 4) | (w1_poly[8*i+5] << 7); + r[3*i+2] = (w1_poly[8*i+5] >> 1) | (w1_poly[8*i+6] << 2) | (w1_poly[8*i+7] << 5); + } + } + + /* 重算挑战多项式: SampleInBall(H(mu || w1_packed)) */ + unsigned char outbuf[SHAKE256_RATE]; + keccak_state state; + shake256_absorb_once(&state, inbuf, CRHBYTES + PARAM_K * POLYW1_PACKEDBYTES); + shake256_squeezeblocks(outbuf, 1, &state); + + uint64_t signs = 0; + for (unsigned int i = 0; i < 8; i++) signs |= (uint64_t)outbuf[i] << (8 * i); + unsigned int pos = 8; + uint64_t mask = 1; + + coeff_t cp[PARAM_N]; + for (unsigned int i = 0; i < PARAM_N; i++) cp[i] = 0; + + for (unsigned int i = 196; i < 256; i++) { + unsigned int b; + do { + if (pos >= SHAKE256_RATE) { + shake256_squeezeblocks(outbuf, 1, &state); + pos = 0; + } + b = outbuf[pos++]; + } while (b > i); + cp[i] = cp[b]; + cp[b] = (signs & mask) ? 
PARAM_Q - 1 : 1; + mask <<= 1; + } + + /* 与解包的原始挑战多项式逐系数比较 */ + const coeff_t *c_orig = (const coeff_t *)(d_cbuf + (size_t)inst * PARAM_N * sizeof(coeff_t)); + for (unsigned int i = 0; i < PARAM_N; i++) { + if (c_orig[i] != cp[i]) { + d_results[inst] = -1; return; + } + } +} + +#elif ALGORITHM == ALGO_MLDSA + +__global__ void __launch_bounds__(128) +batch_verify_compare_kernel( + int * __restrict__ d_results, + const unsigned char * __restrict__ d_mu, + const coeff_t * __restrict__ d_w1, + const unsigned char * __restrict__ d_cbuf, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_count) return; + if (d_results[idx] != 0) return; /* 已失败的实例跳过 */ + + const uint8_t *c_orig = d_cbuf + (size_t)idx * CTILDEBYTES; + + keccak_state state; + shake256_init(&state); + shake256_absorb(&state, d_mu + (size_t)idx * CRHBYTES, CRHBYTES); + + uint8_t w1_pack[POLYW1_PACKEDBYTES]; + for (int k = 0; k < PARAM_K; k++) { + const coeff_t *w1k = d_w1 + (size_t)k * batch_count * PARAM_N + (size_t)idx * PARAM_N; +#if PARAM_GAMMA2 == (PARAM_Q-1)/88 + for (unsigned int i = 0; i < PARAM_N/4; i++) { + w1_pack[3*i+0] = w1k[4*i+0]; + w1_pack[3*i+0] |= w1k[4*i+1] << 6; + w1_pack[3*i+1] = w1k[4*i+1] >> 2; + w1_pack[3*i+1] |= w1k[4*i+2] << 4; + w1_pack[3*i+2] = w1k[4*i+2] >> 4; + w1_pack[3*i+2] |= w1k[4*i+3] << 2; + } +#elif PARAM_GAMMA2 == (PARAM_Q-1)/32 + for (unsigned int i = 0; i < PARAM_N/2; i++) + w1_pack[i] = w1k[2*i+0] | (w1k[2*i+1] << 4); +#endif + shake256_absorb(&state, w1_pack, POLYW1_PACKEDBYTES); + } + + shake256_finalize(&state); + + uint8_t c2[CTILDEBYTES]; + shake256_squeeze(c2, CTILDEBYTES, &state); + + int result = 0; + for (unsigned int i = 0; i < CTILDEBYTES; i++) { + if (c_orig[i] != c2[i]) { result = -1; break; } + } + d_results[idx] = result; +} + +#endif /* ALGORITHM compare kernel */ + +/* ================================================================ + * Host API — 缓冲区分配/释放 + * 
================================================================ */ + +static int batch_verify_alloc(BatchVerifyBuffers *buf, int max_batch) { + memset(buf, 0, sizeof(*buf)); + buf->max_batch = max_batch; + size_t B = max_batch; + size_t N = PARAM_N; + +#define BV_TRY(ptr, sz) do { \ + if (hipMalloc(&(ptr), (sz)) != hipSuccess) { hipGetLastError(); return -1; } \ +} while(0) + + BV_TRY(buf->d_mat, PARAM_K * PARAM_L * N * sizeof(coeff_t)); + BV_TRY(buf->d_t1_hat, PARAM_K * N * sizeof(coeff_t)); + BV_TRY(buf->d_tr, TRBYTES); + BV_TRY(buf->d_z, PARAM_L * B * N * sizeof(coeff_t)); + BV_TRY(buf->d_h, PARAM_K * B * N * sizeof(coeff_t)); + BV_TRY(buf->d_cp, B * N * sizeof(coeff_t)); + BV_TRY(buf->d_w, PARAM_K * B * N * sizeof(coeff_t)); + BV_TRY(buf->d_w1, PARAM_K * B * N * sizeof(coeff_t)); + BV_TRY(buf->d_mu, B * CRHBYTES); + BV_TRY(buf->d_results, B * sizeof(int)); + BV_TRY(buf->d_raw_sigs, B * CRYPTO_BYTES); + +#if ALGORITHM == ALGO_AIGIS + BV_TRY(buf->d_cbuf, B * N * sizeof(coeff_t)); +#else + BV_TRY(buf->d_cbuf, B * CTILDEBYTES); +#endif + +#undef BV_TRY + return 0; +} + +static void batch_verify_free(BatchVerifyBuffers *buf) { + hipFree(buf->d_mat); hipFree(buf->d_t1_hat); + hipFree(buf->d_tr); hipFree(buf->d_z); + hipFree(buf->d_h); hipFree(buf->d_cp); + hipFree(buf->d_w); hipFree(buf->d_w1); + hipFree(buf->d_mu); hipFree(buf->d_results); + hipFree(buf->d_raw_sigs); hipFree(buf->d_cbuf); + memset(buf, 0, sizeof(*buf)); +} + +/* ================================================================ + * 批量验证核心 pipeline + * + * 前置条件: buf->d_raw_sigs 已经准备好 (device 端 AoS 签名数据) + * ================================================================ */ +static int batch_verify_pipeline_core( + BatchVerifyBuffers *buf, + const unsigned char *d_msgs, + size_t mlen, + const unsigned char *d_pre, + size_t prelen, + int batch_count, + int *h_results, + hipStream_t stream = 0) +{ + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + + /* 初始化 results 和 hint 缓冲区 */ + 
hipMemsetAsync(buf->d_results, 0, batch_count * sizeof(int), stream); + hipMemsetAsync(buf->d_h, 0, + (size_t)batch_count * PARAM_K * PARAM_N * sizeof(coeff_t), + stream); + + /* NOTE(review): every launch configuration below was lost in extraction ("<<>>"). The 1D launches use the nblk/tpb computed next to them, the 2D launches use the grid built next to them with PARAM_N threads per block (the matvec kernel indexes one coefficient per thread), and all are issued on `stream` with no dynamic shared memory — confirm against the upstream file. */ + + /* [1] Unpack signatures → z, h, challenge data */ + { + int tpb = 64, nblk = (batch_count + tpb - 1) / tpb; + batch_verify_unpack_kernel<<<nblk, tpb, 0, stream>>>( + buf->d_z, buf->d_h, buf->d_cbuf, buf->d_results, + buf->d_raw_sigs, batch_count); + } + + /* [2] Norm check on z */ + { + int tpb = 64, nblk = (batch_count + tpb - 1) / tpb; + batch_verify_chknorm_z_kernel<<<nblk, tpb, 0, stream>>>( + buf->d_results, buf->d_z, batch_count); + } + + /* [3] NTT(z) — shared-memory batch */ + launch_batch_ntt(buf->d_z, batch_count * PARAM_L, stream); + + /* [4] w = A · z_hat (shared matrix, 2D grid) */ + { + dim3 grid(batch_count, PARAM_K); + batch_verify_matvec_kernel<<<grid, PARAM_N, 0, stream>>>( + buf->d_w, buf->d_mat, buf->d_z, batch_count); + } + + /* [5] Build the challenge polynomial cp */ +#if ALGORITHM == ALGO_AIGIS + /* Aigis: d_cbuf already holds the full challenge polynomial; copy it directly */ + hipMemcpyAsync(buf->d_cp, buf->d_cbuf, + (size_t)batch_count * PARAM_N * sizeof(coeff_t), + hipMemcpyDeviceToDevice, stream); +#else + /* ML-DSA: rebuild the challenge polynomial from the ctilde seed */ + { + int tpb = 64, nblk = (batch_count + tpb - 1) / tpb; + batch_verify_challenge_kernel<<<nblk, tpb, 0, stream>>>( + buf->d_cp, buf->d_cbuf, batch_count); + } +#endif + + /* [6] NTT(cp) */ + launch_batch_ntt(buf->d_cp, batch_count, stream); + + /* [7] w -= cp · t1_hat (2D grid) */ + { + dim3 grid(batch_count, PARAM_K); + batch_verify_sub_cp_t1_kernel<<<grid, PARAM_N, 0, stream>>>( + buf->d_w, buf->d_cp, buf->d_t1_hat, batch_count); + } + + /* [8] reduce + INVNTT + normalize */ + launch_batch_reduce(buf->d_w, batch_count * PARAM_K * PARAM_N, stream); + launch_batch_invntt(buf->d_w, batch_count * PARAM_K, stream); +#if ALGORITHM == ALGO_AIGIS + launch_batch_freeze2q(buf->d_w, PARAM_K * batch_count, stream); +#else + launch_batch_reduce(buf->d_w, batch_count * PARAM_K * PARAM_N, stream); + launch_batch_caddq(buf->d_w, batch_count * PARAM_K * PARAM_N, stream); +#endif + + /* [9] w1 = use_hint(w, h) */ + launch_batch_use_hint(buf->d_w1, buf->d_w, 
buf->d_h, + batch_count * PARAM_K * PARAM_N, stream); + + /* [10] mu = H(tr || [pre ||] m) */ + { + int tpb = 32, nblk = (batch_count + tpb - 1) / tpb; + batch_verify_compute_mu_kernel<<<nblk, tpb, 0, stream>>>( + buf->d_mu, buf->d_tr, d_msgs, mlen, + d_pre, prelen, + batch_count); + } + + /* [11] Final comparison: H(mu || pack(w1)) vs the original challenge */ + { + /* Block sizes match the __launch_bounds__ of the per-algorithm compare kernel. */ +#if ALGORITHM == ALGO_AIGIS + int tpb = 32; +#else + int tpb = 128; +#endif + int nblk = (batch_count + tpb - 1) / tpb; + batch_verify_compare_kernel<<<nblk, tpb, 0, stream>>>( + buf->d_results, buf->d_mu, buf->d_w1, buf->d_cbuf, batch_count); + } + + /* Copy the results back to the host */ + hipMemcpyAsync(h_results, buf->d_results, + batch_count * sizeof(int), hipMemcpyDeviceToHost, + stream); + /* h_results is only meaningful if the stream drained without error. */ + if (hipStreamSynchronize(stream) != hipSuccess) { + hipGetLastError(); + return -1; + } + + return 0; +} + +/* ================================================================ + * Batch verification pipeline (signatures supplied from host memory) + * + * Copies batch_count * CRYPTO_BYTES bytes from h_sigs into buf->d_raw_sigs + * on `stream`, then runs batch_verify_pipeline_core. + * ================================================================ */ +static int batch_verify_pipeline( + BatchVerifyBuffers *buf, + const unsigned char *h_sigs, + const unsigned char *d_msgs, + size_t mlen, + const unsigned char *d_pre, + size_t prelen, + int batch_count, + int *h_results, + hipStream_t stream = 0) +{ + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + + hipMemcpyAsync(buf->d_raw_sigs, h_sigs, + (size_t)batch_count * CRYPTO_BYTES, + hipMemcpyHostToDevice, + stream); + + return batch_verify_pipeline_core(buf, d_msgs, mlen, d_pre, prelen, + batch_count, h_results, stream); +} + +/* ================================================================ + * Batch verification pipeline (signatures already in device memory) + * + * The device-to-device copy is skipped when d_sigs is buf->d_raw_sigs itself. + * ================================================================ */ +static int batch_verify_pipeline_device_sigs( + BatchVerifyBuffers *buf, + const unsigned char *d_sigs, + const unsigned char *d_msgs, + size_t mlen, + const unsigned char *d_pre, + size_t prelen, + int batch_count, + int *h_results, + hipStream_t stream = 0) +{ + if (batch_count <= 0 || batch_count > buf->max_batch) return -1; + + if (d_sigs != buf->d_raw_sigs) { + 
hipMemcpyAsync(buf->d_raw_sigs, d_sigs, + (size_t)batch_count * CRYPTO_BYTES, + hipMemcpyDeviceToDevice, + stream); + } + + return batch_verify_pipeline_core(buf, d_msgs, mlen, d_pre, prelen, + batch_count, h_results, stream); +} + +#endif /* BATCH_VERIFY_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/build_sig_amd.sh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/build_sig_amd.sh new file mode 100644 index 000000000..0d82b299f --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/build_sig_amd.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +COMMON=( + -O2 + -std=c++17 + -x hip + --offload-arch=gfx1100 + -DBLOCK_SIZE=1 + -DBATCH_KEYGEN_INTERNAL_MATERIAL=1 + -DBATCH_SIGN_WARP_ENABLE=0 + -DBATCH_SIGN_MONO_ENABLE=0 + -DBATCH_SIGN_PRECOMP_REUSE=0 + -DBATCH_SIGN_LARGE_STRATEGY_ENABLE=0 + -DBATCH_SIGN_DECOMP_ENABLE=1 + -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 + -DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 + -DBATCH_SIGN_CP_FUSE_ENABLE=0 + -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 + -DBATCH_KEYGEN_SAMPLE_SPLIT_FAST=1 +) + +mkdir -p amd_results/build + +build_one() { + local alg="$1" + local mode="$2" + local out="$3" + echo "[build] ${out}" + hipcc "${COMMON[@]}" -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" main.cu -o "${out}" \ + 2>&1 | tee "amd_results/build/${out}.log" +} + +build_one 1 2 mldsa44_amd +build_one 1 3 mldsa65_amd +build_one 1 5 mldsa87_amd +build_one 2 1 aigis1_amd +build_one 2 2 aigis2_amd +build_one 2 3 aigis3_amd + +echo "[build] done" diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/config.h b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/config.h new file mode 100644 index 000000000..b569dbb64 --- /dev/null +++ 
b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/config.h @@ -0,0 +1,31 @@ +/* + * config.h — 算法选择 + * + * 设置方法: 编译时传入 -DALGORITHM=ALGO_MLDSA 或 -DALGORITHM=ALGO_AIGIS + * 以及 -DPARAM_MODE=2/3/5 (ML-DSA) 或 -DPARAM_MODE=1/2/3 (Aigis) + */ + +#ifndef CONFIG_H +#define CONFIG_H + +#define ALGO_MLDSA 1 +#define ALGO_AIGIS 2 + +#ifndef ALGORITHM +#define ALGORITHM ALGO_MLDSA +#endif + +#ifndef PARAM_MODE +#if ALGORITHM == ALGO_MLDSA +#define PARAM_MODE 5 /* ML-DSA-87 */ +#else +#define PARAM_MODE 3 /* Aigis-sig3 */ +#endif +#endif + +/* 编译期检查 */ +#if ALGORITHM != ALGO_MLDSA && ALGORITHM != ALGO_AIGIS +#error "ALGORITHM must be ALGO_MLDSA or ALGO_AIGIS" +#endif + +#endif /* CONFIG_H */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/fips202.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/fips202.cuh new file mode 100644 index 000000000..87a844d8b --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/fips202.cuh @@ -0,0 +1,354 @@ +#include "hip/hip_runtime.h" +#ifndef FIPS202_CUH +#define FIPS202_CUH + +#include +#include +#include + +#define SHAKE128_RATE 168 +#define SHAKE256_RATE 136 +#define SHA3_256_RATE 136 +#define SHA3_512_RATE 72 + +typedef struct { + uint64_t s[25]; + unsigned int pos; +} keccak_state; + +#define NROUNDS 24 +#define ROL(a, offset) ((a << offset) ^ (a >> (64-offset))) + +static __device__ __forceinline__ uint64_t load64(const uint8_t x[8]) { + uint64_t r; + memcpy(&r, x, 8); + return r; +} + +static __device__ __forceinline__ void store64(uint8_t x[8], uint64_t u) { + memcpy(x, &u, 8); +} + +__constant__ uint64_t gpu_KeccakF_RoundConstants[NROUNDS] = { + (uint64_t)0x0000000000000001ULL, + (uint64_t)0x0000000000008082ULL, + (uint64_t)0x800000000000808aULL, + (uint64_t)0x8000000080008000ULL, + (uint64_t)0x000000000000808bULL, + (uint64_t)0x0000000080000001ULL, + (uint64_t)0x8000000080008081ULL, + 
(uint64_t)0x8000000000008009ULL, + (uint64_t)0x000000000000008aULL, + (uint64_t)0x0000000000000088ULL, + (uint64_t)0x0000000080008009ULL, + (uint64_t)0x000000008000000aULL, + (uint64_t)0x000000008000808bULL, + (uint64_t)0x800000000000008bULL, + (uint64_t)0x8000000000008089ULL, + (uint64_t)0x8000000000008003ULL, + (uint64_t)0x8000000000008002ULL, + (uint64_t)0x8000000000000080ULL, + (uint64_t)0x000000000000800aULL, + (uint64_t)0x800000008000000aULL, + (uint64_t)0x8000000080008081ULL, + (uint64_t)0x8000000000008080ULL, + (uint64_t)0x0000000080000001ULL, + (uint64_t)0x8000000080008008ULL +}; + +/* GPU-optimized Keccak-f[1600] — compact single-round loop + * Based on Kyber batch_keccak.cu pattern: + * - 24-element cycle for Rho+Pi in-place (eliminates B[25] temp array) + * - Row-by-row Chi with 5 temporaries + * - ~60 registers vs ~120 for 2-round unrolled version + * - Better occupancy on GPU → higher hash throughput + */ +static __device__ __noinline__ void KeccakF1600_StatePermute(uint64_t state[25]) +{ + uint64_t Cx[5], Dx[5]; + + for (int round = 0; round < NROUNDS; round++) { + /* Theta */ + Cx[0] = state[0] ^ state[5] ^ state[10] ^ state[15] ^ state[20]; + Cx[1] = state[1] ^ state[6] ^ state[11] ^ state[16] ^ state[21]; + Cx[2] = state[2] ^ state[7] ^ state[12] ^ state[17] ^ state[22]; + Cx[3] = state[3] ^ state[8] ^ state[13] ^ state[18] ^ state[23]; + Cx[4] = state[4] ^ state[9] ^ state[14] ^ state[19] ^ state[24]; + + Dx[0] = Cx[4] ^ ROL(Cx[1], 1); + Dx[1] = Cx[0] ^ ROL(Cx[2], 1); + Dx[2] = Cx[1] ^ ROL(Cx[3], 1); + Dx[3] = Cx[2] ^ ROL(Cx[4], 1); + Dx[4] = Cx[3] ^ ROL(Cx[0], 1); + + state[ 0] ^= Dx[0]; state[ 5] ^= Dx[0]; state[10] ^= Dx[0]; state[15] ^= Dx[0]; state[20] ^= Dx[0]; + state[ 1] ^= Dx[1]; state[ 6] ^= Dx[1]; state[11] ^= Dx[1]; state[16] ^= Dx[1]; state[21] ^= Dx[1]; + state[ 2] ^= Dx[2]; state[ 7] ^= Dx[2]; state[12] ^= Dx[2]; state[17] ^= Dx[2]; state[22] ^= Dx[2]; + state[ 3] ^= Dx[3]; state[ 8] ^= Dx[3]; state[13] ^= Dx[3]; state[18] ^= 
Dx[3]; state[23] ^= Dx[3]; + state[ 4] ^= Dx[4]; state[ 9] ^= Dx[4]; state[14] ^= Dx[4]; state[19] ^= Dx[4]; state[24] ^= Dx[4]; + + /* Rho + Pi in-place via 24-element cycle (state[0] is fixed point) */ + { + uint64_t tmp = ROL(state[1], 1); + state[ 1] = ROL(state[ 6], 44); + state[ 6] = ROL(state[ 9], 20); + state[ 9] = ROL(state[22], 61); + state[22] = ROL(state[14], 39); + state[14] = ROL(state[20], 18); + state[20] = ROL(state[ 2], 62); + state[ 2] = ROL(state[12], 43); + state[12] = ROL(state[13], 25); + state[13] = ROL(state[19], 8); + state[19] = ROL(state[23], 56); + state[23] = ROL(state[15], 41); + state[15] = ROL(state[ 4], 27); + state[ 4] = ROL(state[24], 14); + state[24] = ROL(state[21], 2); + state[21] = ROL(state[ 8], 55); + state[ 8] = ROL(state[16], 45); + state[16] = ROL(state[ 5], 36); + state[ 5] = ROL(state[ 3], 28); + state[ 3] = ROL(state[18], 21); + state[18] = ROL(state[17], 15); + state[17] = ROL(state[11], 10); + state[11] = ROL(state[ 7], 6); + state[ 7] = ROL(state[10], 3); + state[10] = tmp; + } + + /* Chi — row by row with 5 temporaries */ + { + uint64_t t0, t1, t2, t3, t4; + + t0=state[0]; t1=state[1]; t2=state[2]; t3=state[3]; t4=state[4]; + state[0]=t0^((~t1)&t2); state[1]=t1^((~t2)&t3); + state[2]=t2^((~t3)&t4); state[3]=t3^((~t4)&t0); + state[4]=t4^((~t0)&t1); + + t0=state[5]; t1=state[6]; t2=state[7]; t3=state[8]; t4=state[9]; + state[5]=t0^((~t1)&t2); state[6]=t1^((~t2)&t3); + state[7]=t2^((~t3)&t4); state[8]=t3^((~t4)&t0); + state[9]=t4^((~t0)&t1); + + t0=state[10]; t1=state[11]; t2=state[12]; t3=state[13]; t4=state[14]; + state[10]=t0^((~t1)&t2); state[11]=t1^((~t2)&t3); + state[12]=t2^((~t3)&t4); state[13]=t3^((~t4)&t0); + state[14]=t4^((~t0)&t1); + + t0=state[15]; t1=state[16]; t2=state[17]; t3=state[18]; t4=state[19]; + state[15]=t0^((~t1)&t2); state[16]=t1^((~t2)&t3); + state[17]=t2^((~t3)&t4); state[18]=t3^((~t4)&t0); + state[19]=t4^((~t0)&t1); + + t0=state[20]; t1=state[21]; t2=state[22]; t3=state[23]; t4=state[24]; 
+ state[20]=t0^((~t1)&t2); state[21]=t1^((~t2)&t3); + state[22]=t2^((~t3)&t4); state[23]=t3^((~t4)&t0); + state[24]=t4^((~t0)&t1); + } + + /* Iota */ + state[0] ^= gpu_KeccakF_RoundConstants[round]; + } +} + +static __device__ void keccak_init(uint64_t s[25]) { + unsigned int i; + for(i=0;i<25;i++) s[i] = 0; +} + +/* Word-aligned absorb: uses 64-bit loads when pos is 8-byte aligned + * All SHAKE/SHA3 rates are multiples of 8, and ML-DSA verify keeps + * pos 8-byte aligned throughout (mu=64B, w1_pack=128B). */ +static __device__ __noinline__ unsigned int keccak_absorb(uint64_t s[25], + unsigned int pos, + unsigned int r, + const uint8_t *in, + size_t inlen) +{ + unsigned int i; + while(pos+inlen >= r) { + if (!(pos & 7)) { + /* Fast path: 64-bit word loads (8x fewer iterations) */ + for(i = pos >> 3; i < r >> 3; i++, in += 8) + s[i] ^= load64(in); + } else { + for(i=pos;i 0) { + for(i = pos >> 3; i < (pos + (unsigned int)inlen) >> 3; i++, in += 8) + s[i] ^= load64(in); + return pos + (unsigned int)inlen; + } + for(i=pos;i= 8) { + /* Fast path: word-aligned squeeze */ + unsigned int end = (pos + (unsigned int)outlen < r) ? 
(pos + (unsigned int)outlen) : r; + for(i = pos >> 3; i < end >> 3; i++, out += 8) + store64(out, s[i]); + unsigned int extracted = (i << 3) - pos; + outlen -= extracted; + pos += extracted; + } else { + for(i=pos;i < r && i < pos+outlen; i++) + *out++ = s[i/8] >> 8*(i%8); + outlen -= i-pos; + pos = i; + } + } + return pos; +} + +static __device__ __noinline__ void keccak_absorb_once(uint64_t s[25], + unsigned int r, + const uint8_t *in, + size_t inlen, + uint8_t p) +{ + unsigned int i; + for(i=0;i<25;i++) s[i] = 0; + while(inlen >= r) { + for(i=0;is); state->pos = 0; +} + +static __device__ void shake128_absorb(keccak_state *state, const uint8_t *in, size_t inlen) { + state->pos = keccak_absorb(state->s, state->pos, SHAKE128_RATE, in, inlen); +} + +static __device__ void shake128_finalize(keccak_state *state) { + keccak_finalize(state->s, state->pos, SHAKE128_RATE, 0x1F); + state->pos = SHAKE128_RATE; +} + +static __device__ void shake128_squeeze(uint8_t *out, size_t outlen, keccak_state *state) { + state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE128_RATE); +} + +static __device__ void shake128_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen) { + keccak_absorb_once(state->s, SHAKE128_RATE, in, inlen, 0x1F); + state->pos = SHAKE128_RATE; +} + +static __device__ void shake128_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state) { + keccak_squeezeblocks(out, nblocks, state->s, SHAKE128_RATE); +} + +/* SHAKE256 */ +static __device__ void shake256_init(keccak_state *state) { + keccak_init(state->s); state->pos = 0; +} + +static __device__ void shake256_absorb(keccak_state *state, const uint8_t *in, size_t inlen) { + state->pos = keccak_absorb(state->s, state->pos, SHAKE256_RATE, in, inlen); +} + +static __device__ void shake256_finalize(keccak_state *state) { + keccak_finalize(state->s, state->pos, SHAKE256_RATE, 0x1F); + state->pos = SHAKE256_RATE; +} + +static __device__ void shake256_squeeze(uint8_t *out, size_t 
outlen, keccak_state *state) { + state->pos = keccak_squeeze(out, outlen, state->s, state->pos, SHAKE256_RATE); +} + +static __device__ void shake256_absorb_once(keccak_state *state, const uint8_t *in, size_t inlen) { + keccak_absorb_once(state->s, SHAKE256_RATE, in, inlen, 0x1F); + state->pos = SHAKE256_RATE; +} + +static __device__ void shake256_squeezeblocks(uint8_t *out, size_t nblocks, keccak_state *state) { + keccak_squeezeblocks(out, nblocks, state->s, SHAKE256_RATE); +} + +/* Non-incremental API */ +static __device__ __noinline__ void shake128(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen) { + size_t nblocks; + keccak_state state; + shake128_absorb_once(&state, in, inlen); + nblocks = outlen/SHAKE128_RATE; + shake128_squeezeblocks(out, nblocks, &state); + outlen -= nblocks*SHAKE128_RATE; + out += nblocks*SHAKE128_RATE; + shake128_squeeze(out, outlen, &state); +} + +static __device__ __noinline__ void shake256(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen) { + size_t nblocks; + keccak_state state; + shake256_absorb_once(&state, in, inlen); + nblocks = outlen/SHAKE256_RATE; + shake256_squeezeblocks(out, nblocks, &state); + outlen -= nblocks*SHAKE256_RATE; + out += nblocks*SHAKE256_RATE; + shake256_squeeze(out, outlen, &state); +} + +static __device__ void sha3_256(uint8_t h[32], const uint8_t *in, size_t inlen) { + unsigned int i; + uint64_t s[25]; + keccak_absorb_once(s, SHA3_256_RATE, in, inlen, 0x06); + KeccakF1600_StatePermute(s); + for(i=0;i<4;i++) store64(h+8*i,s[i]); +} + +static __device__ void sha3_512(uint8_t h[64], const uint8_t *in, size_t inlen) { + unsigned int i; + uint64_t s[25]; + keccak_absorb_once(s, SHA3_512_RATE, in, inlen, 0x06); + KeccakF1600_StatePermute(s); + for(i=0;i<8;i++) store64(h+8*i,s[i]); +} + +#endif diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/main.cu b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/main.cu new file mode 
100644 index 000000000..dd7a9509a --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/main.cu @@ -0,0 +1,2579 @@ +#include "hip/hip_runtime.h" +/* + * main.cu — GPU 批量数字签名基准测试 + * + * 支持 ML-DSA (44/65/87) 和 Aigis-sig (1/2/3) 共 6 种参数集 + * 流程: + * Phase 1: 单实例正确性验证 (随机测试向量, 输出全部输入/输出值) + * Phase 2: GPU批量 keygen / sign / verify 吞吐率测试 (输出 Instance 0 具体值) + * + * 批量数据采用 SoA (Structure-of-Arrays) 内存布局: + * pk_soa[byte_offset * N + instance_idx] + * precomp 批量: 单密钥对多实例 (soa_load/soa_store + 预计算分解 pipeline) + * + * 用法: + * exe [--batch N] [--sweep] [--quiet] + * --batch N 批次大小 (默认按参数集自动选择) + * --sweep 扫描多种批次大小: 64..32768 + * --quiet 省略 Phase 1 的 hex 输出 + */ + +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#else +#include +#endif +#include "config.h" +#include "params.h" +#include "sign.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" +#include "batch_keygen.cuh" +#include "batch_verify.cuh" +#include "batch_sign.cuh" +#include "batch_sign_warp.cuh" + +/* ================================================================ + * 常量定义 + * ================================================================ */ +/* Conservative defaults plus larger Ada/4090 defaults. 
*/ +#if ALGORITHM == ALGO_MLDSA + #if PARAM_MODE == 2 /* ML-DSA-44: pk=1312 sk=2560 sig=2420 ~6KB/inst */ + #define DEFAULT_BATCH 4096 + #define DEFAULT_BATCH_4090 16384 + #elif PARAM_MODE == 3 /* ML-DSA-65: pk=1952 sk=4032 sig=3309 ~9KB/inst */ + #define DEFAULT_BATCH 2048 + #define DEFAULT_BATCH_4090 32768 + #elif PARAM_MODE == 5 /* ML-DSA-87: pk=2592 sk=4896 sig=4627 ~12KB/inst */ + #define DEFAULT_BATCH 1024 + #define DEFAULT_BATCH_4090 16384 + #endif +#elif ALGORITHM == ALGO_AIGIS + #if PARAM_MODE == 1 /* Aigis-1: pk=1056 sk=2448 sig=1852 ~5KB/inst */ + #define DEFAULT_BATCH 4096 + #define DEFAULT_BATCH_4090 16384 + #elif PARAM_MODE == 2 /* Aigis-2: pk=1312 sk=3376 sig=2445 ~7KB/inst */ + #define DEFAULT_BATCH 2048 + #define DEFAULT_BATCH_4090 16384 + #elif PARAM_MODE == 3 /* Aigis-3: pk=1568 sk=3888 sig=3046 ~8KB/inst */ + #define DEFAULT_BATCH 2048 + #define DEFAULT_BATCH_4090 16384 + #endif +#endif +#ifndef DEFAULT_BATCH_4090 +#define DEFAULT_BATCH_4090 DEFAULT_BATCH +#endif +#ifndef CUDA_TARGET_ARCH +#define CUDA_TARGET_ARCH 0 +#endif +#ifndef BLOCK_SIZE +#define BLOCK_SIZE 64 +#endif +#define NUM_STREAMS 4 + +/* ================================================================ + * 命令行选项 + * ================================================================ */ +typedef struct { + int batch_size; + int batch_auto; + int sweep; + int quiet; + int throughput; + int sample_only; + int keygen_compare; + int bench_paper; + int bench_independent; + int profile; + int skip_keygen_oracle; +} Options; + +static int g_profile = 0; +static int g_bench_independent = 0; + +static int read_file_all(const char *path, uint8_t **out, size_t *out_len) { + FILE *f = fopen(path, "rb"); + long n; + uint8_t *buf; + if (!f) { + fprintf(stderr, "open failed: %s\n", path); + return -1; + } + if (fseek(f, 0, SEEK_END) != 0) { + fclose(f); + return -1; + } + n = ftell(f); + if (n < 0) { + fclose(f); + return -1; + } + if (fseek(f, 0, SEEK_SET) != 0) { + fclose(f); + return -1; + } + buf 
= (uint8_t *)malloc((size_t)n + 1u); + if (!buf) { + fclose(f); + return -1; + } + if ((size_t)n > 0 && fread(buf, 1, (size_t)n, f) != (size_t)n) { + free(buf); + fclose(f); + return -1; + } + fclose(f); + buf[n] = 0; + *out = buf; + *out_len = (size_t)n; + return 0; +} + +static int read_file_exact_host(const char *path, uint8_t *buf, size_t len) { + uint8_t *tmp = NULL; + size_t n = 0; + int rc = read_file_all(path, &tmp, &n); + if (rc != 0) return rc; + if (n != len) { + fprintf(stderr, "size mismatch: %s expected %zu got %zu\n", path, len, n); + free(tmp); + return -1; + } + memcpy(buf, tmp, len); + free(tmp); + return 0; +} + +static int write_file_all(const char *path, const uint8_t *buf, size_t len) { + FILE *f = fopen(path, "wb"); + if (!f) { + fprintf(stderr, "write open failed: %s\n", path); + return -1; + } + if (len > 0 && fwrite(buf, 1, len, f) != len) { + fclose(f); + return -1; + } + fclose(f); + return 0; +} + +static void fill_random_host(uint8_t *buf, size_t len) { + FILE *f = fopen("/dev/urandom", "rb"); + if (f) { + size_t n = fread(buf, 1, len, f); + fclose(f); + if (n == len) return; + } + srand((unsigned)time(NULL)); + for (size_t i = 0; i < len; i++) buf[i] = (uint8_t)(rand() & 0xff); +} + +#ifndef BATCH_KEYGEN_INTERNAL_MATERIAL +#define BATCH_KEYGEN_INTERNAL_MATERIAL 0 +#endif + +#ifndef BATCH_SIGN_PRECOMP_REUSE +#define BATCH_SIGN_PRECOMP_REUSE 0 +#endif + +#ifndef BATCH_SIGN_MONO_ENABLE +#define BATCH_SIGN_MONO_ENABLE 1 +#endif + +#ifndef BATCH_SIGN_DECOMP_ENABLE +#define BATCH_SIGN_DECOMP_ENABLE 1 +#endif + +#ifndef BATCH_SIGN_WARP_ENABLE +#define BATCH_SIGN_WARP_ENABLE 1 +#endif + +#ifndef BATCH_SIGN_WARP_PROFILE +#define BATCH_SIGN_WARP_PROFILE 0 +#endif + +#ifndef BATCH_SIGN_LARGE_STRATEGY_ENABLE +#define BATCH_SIGN_LARGE_STRATEGY_ENABLE 1 +#endif + +#ifndef BATCH_SIGN_LARGE_BATCH_THRESHOLD +#define BATCH_SIGN_LARGE_BATCH_THRESHOLD 4096 +#endif + +#ifndef BATCH_SIGN_NONCE_DIVERSIFY +#define BATCH_SIGN_NONCE_DIVERSIFY 0 +#endif + 
+#ifndef BATCH_SIGN_DECOMP_TAIL_ENABLE +#define BATCH_SIGN_DECOMP_TAIL_ENABLE 0 +#endif + +#ifndef BATCH_SIGN_CP_FUSE_ENABLE +#define BATCH_SIGN_CP_FUSE_ENABLE 0 +#endif + +#ifndef BATCH_SIGN_SAMPLE_DUP_YHAT +#define BATCH_SIGN_SAMPLE_DUP_YHAT 0 +#endif + +#ifndef BATCH_SIGN_DECOMP_CHECK_INTERVAL +#define BATCH_SIGN_DECOMP_CHECK_INTERVAL 4 +#endif + +#ifndef BATCH_SIGN_SAMPLE_TPB +#define BATCH_SIGN_SAMPLE_TPB 64 +#endif + +#ifndef BATCH_SIGN_HASH_TPB +#define BATCH_SIGN_HASH_TPB 32 +#endif + +#ifndef BATCH_SIGN_CHECK_TPB +#define BATCH_SIGN_CHECK_TPB 32 +#endif + +#ifndef BATCH_SIGN_DECOMP_ADAPTIVE_ENABLE +#define BATCH_SIGN_DECOMP_ADAPTIVE_ENABLE 0 +#endif + +static const char *keygen_ind_sample_mode_name(void) { +#if BATCH_KEYGEN_MATRIX_A_COOP +#if BATCH_KEYGEN_SECRET_ETA_COOP + return "sample-coop-full"; +#elif BATCH_KEYGEN_MATRIX_A_COOP_SUBWARP + return "matrixA-coop-subwarp"; +#else + return "matrixA-coop-warp"; +#endif +#elif BATCH_KEYGEN_MATRIX_A_LANEOPT + return "matrixA-laneopt"; +#elif BATCH_KEYGEN_SECRET_ETA_COOP +#if BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT + return "eta2-aigis5-coop"; +#else + return "eta-coop"; +#endif +#elif BATCH_KEYGEN_SAMPLE_SPLIT_FAST || BATCH_KEYGEN_MATRIX_A_FAST || BATCH_KEYGEN_SECRET_ETA_FAST + return "split-baseline"; +#else + return "old-fused"; +#endif +} + +static const char *keygen_paper_sample_mode_name(void) { +#if BATCH_KEYGEN_SECRET_ETA_COOP +#if BATCH_KEYGEN_MATRIX_A_COOP + return "sample-coop-full"; +#elif BATCH_KEYGEN_SECRET_ETA_AIGIS5_SPLIT + return "eta2-aigis5-coop"; +#else + return "eta-coop"; +#endif +#elif BATCH_KEYGEN_SAMPLE_SPLIT_FAST || BATCH_KEYGEN_SECRET_ETA_FAST + return "split-baseline"; +#else + return "old-fused"; +#endif +} + +static const char *internal_material_mode_name(void) { +#if BATCH_KEYGEN_INTERNAL_MATERIAL + return "internal-material"; +#else + return "pk-sk-precompute"; +#endif +} + +static const char *sign_precomp_mode_name(void) { +#if BATCH_SIGN_DECOMP_ENABLE && !BATCH_SIGN_MONO_ENABLE && 
!BATCH_SIGN_PRECOMP_REUSE && !BATCH_SIGN_WARP_ENABLE + return "sign-decomp-resource-aware"; +#else +#if BATCH_SIGN_LARGE_STRATEGY_ENABLE +#if BATCH_SIGN_NONCE_DIVERSIFY + return "large-batch-warp-strategy"; +#else + return "large-batch-mono-strategy"; +#endif +#else +#if BATCH_SIGN_PRECOMP_REUSE +#if BATCH_SIGN_WARP_ENABLE + return "sign-cache+warp-enabled"; +#else + return "sign-cache-enabled"; +#endif +#else +#if BATCH_SIGN_MONO_ENABLE +#if BATCH_SIGN_WARP_ENABLE + return "sign-mono+warp-enabled"; +#else + return "sign-mono-only"; +#endif +#else +#if BATCH_SIGN_DECOMP_ENABLE + return "sign-decomp-fallback"; +#else + return "sign-disabled"; +#endif +#endif +#endif +#endif +#endif +} + +static const char *policy_onoff(int enabled) { + return enabled ? "on" : "off"; +} + +static BatchSignRuntimeOptions select_decomp_runtime_options( + int batch, + int independent_mode, + const char **label) +{ + BatchSignRuntimeOptions opt = batch_sign_default_runtime_options(); + const char *name = "base"; + +#if BATCH_SIGN_DECOMP_ADAPTIVE_ENABLE + opt.cp_fuse_enable = 0; + opt.check_interval = 4; + opt.hash_tpb = 32; + opt.check_tpb = 32; + +#if ALGORITHM == ALGO_MLDSA && PARAM_MODE == 2 + if (!independent_mode && batch >= 4096) { + opt.check_interval = 16; + name = "check16"; + } +#elif ALGORITHM == ALGO_MLDSA && PARAM_MODE == 5 + if (!independent_mode && batch <= 2048) { + opt.check_interval = 8; + name = "check8"; + } +#elif ALGORITHM == ALGO_AIGIS && PARAM_MODE == 2 + if (!independent_mode && batch <= 2048) { + opt.check_interval = 16; + name = "check16"; + } else if (!independent_mode && batch >= 4096) { + name = "base"; + } +#endif +#endif + + if (label) *label = name; + return opt; +} + +static void print_rocm_sign_policy(int active_batch) { + printf("ROCm sign policy: resource-aware hybrid candidates\n"); + printf(" decomp-pipeline=%s monolithic-precomp=%s cached-precomp=%s\n", + policy_onoff(BATCH_SIGN_DECOMP_ENABLE), + policy_onoff(BATCH_SIGN_MONO_ENABLE), + 
policy_onoff(BATCH_SIGN_PRECOMP_REUSE)); + printf(" warp-path=%s large-strategy=%s threshold=%d active_batch=%d\n", + policy_onoff(BATCH_SIGN_MONO_ENABLE && BATCH_SIGN_WARP_ENABLE), + policy_onoff(BATCH_SIGN_MONO_ENABLE && BATCH_SIGN_LARGE_STRATEGY_ENABLE), + BATCH_SIGN_LARGE_BATCH_THRESHOLD, + active_batch); + printf(" decomp-cp-fuse=%s decomp-tail=%s yhat-copy-fuse=%s\n", + policy_onoff(BATCH_SIGN_CP_FUSE_ENABLE), + policy_onoff(BATCH_SIGN_DECOMP_TAIL_ENABLE), + policy_onoff(BATCH_SIGN_SAMPLE_DUP_YHAT)); + printf(" decomp-adaptive=%s\n", + policy_onoff(BATCH_SIGN_DECOMP_ADAPTIVE_ENABLE)); + printf(" decomp-check-interval=%d ctrl-tpb(sample/hash/check)=%d/%d/%d\n", + BATCH_SIGN_DECOMP_CHECK_INTERVAL, + BATCH_SIGN_SAMPLE_TPB, + BATCH_SIGN_HASH_TPB, + BATCH_SIGN_CHECK_TPB); +#if defined(__HIP_PLATFORM_AMD__) + printf(" backend=HIP/ROCm AMD; sign row records selected path label\n"); +#else + printf(" backend=HIP-compatible; sign row records selected path label\n"); +#endif +#if BATCH_SIGN_DECOMP_ENABLE + printf(" rationale=use decomp-pipeline to reduce monolithic private segment/scratch pressure\n"); +#endif +} + +static void print_usage(const char *prog) { + printf("Usage: %s [--batch N] [--sweep] [--throughput] [--sample-only] [--keygen-compare] [--bench-paper] [--bench-independent] [--profile] [--skip-keygen-oracle] [--quiet]\n", prog); + printf(" --batch N batch size (auto default: conservative %d, RTX4090 %d)\n", DEFAULT_BATCH, DEFAULT_BATCH_4090); + printf(" --sweep sweep batch sizes: 64,128,256,512,1024,2048,4096,8192,16384,32768\n"); + printf(" --throughput throughput scan: 256..32768, 10 runs avg, CSV output\n"); + printf(" --sample-only sample-only microbench; skip NTT/matvec/pack/sign/verify\n"); + printf(" --keygen-compare compare old vs active keygen path and exit; with --sample-only only compare sampling buffers\n"); + printf(" --bench-paper paper-4090-style shared key/message/precompute benchmark (default)\n"); + printf(" --bench-independent 
independent-real-batch mode; keygen seeds are independent\n"); + printf(" --profile print lightweight pipeline/profile annotations\n"); + printf(" --skip-keygen-oracle skip batch-vs-single keygen oracle check before profiling\n"); + printf(" --quiet suppress Phase 1 hex dump\n"); +} + +static int parse_options(int argc, char **argv, Options *opt) { + opt->batch_size = 0; + opt->batch_auto = 1; + opt->sweep = 0; + opt->quiet = 0; + opt->throughput = 0; + opt->sample_only = 0; + opt->keygen_compare = 0; + opt->bench_paper = 1; + opt->bench_independent = 0; + opt->profile = 0; + opt->skip_keygen_oracle = 0; + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--batch") == 0 && i + 1 < argc) { + opt->batch_size = atoi(argv[++i]); + if (opt->batch_size <= 0) { printf("Invalid batch size\n"); return -1; } + opt->batch_auto = 0; + } else if (strcmp(argv[i], "--sweep") == 0) { + opt->sweep = 1; + } else if (strcmp(argv[i], "--throughput") == 0) { + opt->throughput = 1; + } else if (strcmp(argv[i], "--sample-only") == 0) { + opt->sample_only = 1; + } else if (strcmp(argv[i], "--keygen-compare") == 0) { + opt->keygen_compare = 1; + } else if (strcmp(argv[i], "--bench-paper") == 0) { + opt->bench_paper = 1; + opt->bench_independent = 0; + } else if (strcmp(argv[i], "--bench-independent") == 0) { + opt->bench_paper = 0; + opt->bench_independent = 1; + } else if (strcmp(argv[i], "--profile") == 0) { + opt->profile = 1; + } else if (strcmp(argv[i], "--skip-keygen-oracle") == 0) { + opt->skip_keygen_oracle = 1; + } else if (strcmp(argv[i], "--quiet") == 0) { + opt->quiet = 1; + } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + print_usage(argv[0]); return 1; + } else { + printf("Unknown option: %s\n", argv[i]); return -1; + } + } + g_profile = opt->profile; + g_bench_independent = opt->bench_independent; + return 0; +} + +/* ================================================================ + * CUDA 错误检查宏 + * 
================================================================ */ +#define CUDA_CHECK(call) do { \ + hipError_t _e = (call); \ + if (_e != hipSuccess) { \ + printf("CUDA error: %s (%s:%d)\n", \ + hipGetErrorString(_e), __FILE__, __LINE__); \ + rc = -1; goto cleanup; \ + } \ +} while(0) + +/* ================================================================ + * SoA / AoS conversion helpers + * + * SoA layout: soa_base[byte * N + idx] + * AoS layout: aos_base[idx * item_bytes + byte] (contiguous per-thread local buffer) + * + * soa_load: copy item idx out of the SoA array into a contiguous per-thread buffer + * soa_store: copy a contiguous per-thread buffer back into item idx of the SoA array + * + * Both helpers move item_bytes bytes one at a time; consecutive bytes of one + * item are N elements apart in the SoA array. + * ================================================================ */ +__device__ static void soa_load(uint8_t *local_buf, const uint8_t *soa_base, + int idx, int N, int item_bytes) { + for (int b = 0; b < item_bytes; ++b) + local_buf[b] = soa_base[(size_t)b * N + idx]; +} + +__device__ static void soa_store(uint8_t *soa_base, const uint8_t *local_buf, + int idx, int N, int item_bytes) { + for (int b = 0; b < item_bytes; ++b) + soa_base[(size_t)b * N + idx] = local_buf[b]; +} + +/* ================================================================ + * GPU kernels + * ================================================================ */ + +/* Single-instance correctness test: keygen + sign + verify + tamper detection. + * *result receives 0 when every step behaves as expected, or -1 (keygen failed), + * -2 (sign failed), -3 (verify of the fresh signature failed), -4 (the tampered + * signature was still accepted). The ML-DSA build passes the pre/prelen prefix + * to sign and verify; the other build does not use it. 
*/ +__global__ void kernel_single_test( + uint8_t *pk, uint8_t *sk, uint8_t *sig, size_t *siglen, + const uint8_t *seed, const uint8_t *rnd, + const uint8_t *msg, size_t mlen, + const uint8_t *pre, size_t prelen, + int *result) +{ + int r; + r = crypto_sign_keypair(pk, sk, seed); + if (r) { *result = -1; return; } + +#if ALGORITHM == ALGO_MLDSA + r = crypto_sign_signature(sig, siglen, msg, mlen, pre, prelen, rnd, sk); +#else + r = crypto_sign_signature(sig, siglen, msg, mlen, rnd, sk); +#endif + if (r) { *result = -2; return; } + +#if ALGORITHM == ALGO_MLDSA + r = crypto_sign_verify(sig, *siglen, msg, mlen, pre, prelen, pk); +#else + r = crypto_sign_verify(sig, *siglen, msg, mlen, pk); +#endif + if (r) { *result = -3; return; } + + /* Flip 1 
bit, 签名验证应失败 */ + sig[0] ^= 1; +#if ALGORITHM == ALGO_MLDSA + r = crypto_sign_verify(sig, *siglen, msg, mlen, pre, prelen, pk); +#else + r = crypto_sign_verify(sig, *siglen, msg, mlen, pk); +#endif + sig[0] ^= 1; + if (r == 0) { *result = -4; return; } + + *result = 0; +} + +__global__ void kernel_keygen_only(uint8_t *pk, uint8_t *sk, + const uint8_t *seed, int *result) +{ + int r = crypto_sign_keypair(pk, sk, seed); + *result = r; +} + +/* 设备端将 1 份 AoS 签名广播成 batch 份 AoS,避免 Host↔Device 往返 */ +__global__ void kernel_cli_sign(uint8_t *sig, size_t *siglen, int *result, + const uint8_t *msg, size_t mlen, + const uint8_t *sk, const uint8_t *rnd) +{ +#if ALGORITHM == ALGO_MLDSA + const uint8_t *pre = msg; + int r = crypto_sign_signature(sig, siglen, msg, mlen, pre, 0, rnd, sk); +#else + int r = crypto_sign_signature(sig, siglen, msg, mlen, rnd, sk); +#endif + *result = r; +} + +__global__ void kernel_cli_verify(int *result, + const uint8_t *sig, size_t siglen, + const uint8_t *msg, size_t mlen, + const uint8_t *pk) +{ +#if ALGORITHM == ALGO_MLDSA + const uint8_t *pre = msg; + int r = crypto_sign_verify(sig, siglen, msg, mlen, pre, 0, pk); +#else + int r = crypto_sign_verify(sig, siglen, msg, mlen, pk); +#endif + *result = r; +} + +__global__ void kernel_broadcast_sig_aos(uint8_t *dst, const uint8_t *src, + int batch_count, int sig_bytes) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * sig_bytes; + if (idx < total) dst[idx] = src[idx % sig_bytes]; +} + +/* ================================================================ + * 预计算内核 — 同一密钥批量签名/验证 + * ================================================================ */ + +/* 单线程: 从 pk/sk 创建预计算数据 */ +__global__ void kernel_create_precomp(precomp_t *pc, + const uint8_t *pk, const uint8_t *sk) { + create_precomp(pc, pk, sk); +} + +/* 批量签名 (预计算): 每线程用共享预计算密钥签署消息 */ +__global__ void __launch_bounds__(BLOCK_SIZE, 2) +kernel_batch_sign_precomp( + uint8_t *sig_soa, size_t *siglen_arr, + const uint8_t 
*msg, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd, + const precomp_t *pc, + int *results, int N, int base_idx) +{ + int i = base_idx + blockIdx.x * blockDim.x + threadIdx.x; + if (i >= N) return; + uint8_t sig_local[CRYPTO_BYTES]; +#if BATCH_SIGN_NONCE_DIVERSIFY + uint16_t nonce_start = +#if ALGORITHM == ALGO_AIGIS + (uint16_t)(((unsigned int)i * PARAM_L) & 0xffffu); +#else + (uint16_t)i; +#endif +#else + uint16_t nonce_start = 0; +#endif +#if ALGORITHM == ALGO_MLDSA + results[i] = crypto_sign_signature_precomp( + sig_local, siglen_arr + i, + msg, mlen, pre, prelen, rnd, pc, nonce_start); +#else + results[i] = crypto_sign_signature_precomp( + sig_local, siglen_arr + i, msg, mlen, rnd, pc, nonce_start); +#endif + soa_store(sig_soa, sig_local, i, N, CRYPTO_BYTES); +} + +typedef struct { +#if ALGORITHM == ALGO_MLDSA + uint8_t mu[CRHBYTES]; + uint8_t rhoprime[CRHBYTES]; +#else + uint8_t mu[CRHBYTES]; + uint8_t key_mu[SEEDBYTES + CRHBYTES]; +#endif +} sign_cache_t; + +/* Single thread: in paper mode the message and randomness are shared by the + * whole batch, so the signing hash seeds are derived only once. + * ML-DSA build: mu = SHAKE256(tr || pre || msg), then + * rhoprime = SHAKE256(key || rnd || mu), where rnd is absorbed only when RNDBYTES > 0. + * Other build: mu = SHAKE256(tr || msg) and key_mu = key || mu (plain concatenation). 
*/ +__global__ void kernel_create_sign_cache( + sign_cache_t *cache, + const precomp_t *pc, + const uint8_t *msg, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd) +{ + keccak_state state; + +#if ALGORITHM == ALGO_MLDSA + shake256_init(&state); + shake256_absorb(&state, pc->tr, TRBYTES); + shake256_absorb(&state, pre, prelen); + shake256_absorb(&state, msg, mlen); + shake256_finalize(&state); + shake256_squeeze(cache->mu, CRHBYTES, &state); + + shake256_init(&state); + shake256_absorb(&state, pc->key, SEEDBYTES); +#if RNDBYTES > 0 + shake256_absorb(&state, rnd, RNDBYTES); +#endif + shake256_absorb(&state, cache->mu, CRHBYTES); + shake256_finalize(&state); + shake256_squeeze(cache->rhoprime, CRHBYTES, &state); +#else + shake256_init(&state); + shake256_absorb(&state, pc->tr, TRBYTES); + shake256_absorb(&state, msg, mlen); + shake256_finalize(&state); + shake256_squeeze(cache->mu, CRHBYTES, &state); + +
memcpy(cache->key_mu, pc->key, SEEDBYTES); + memcpy(cache->key_mu + SEEDBYTES, cache->mu, CRHBYTES); +#endif +} + +/* Batch signing (paper cached): each thread reuses the shared mu/rhoprime/key_mu + * from the sign cache but still runs its own rejection loop. Thread index + * i = base_idx + global thread id; threads with i >= N exit immediately. The + * signature is produced in a per-thread local buffer and then written to + * sig_soa with soa_store. */ +__global__ void __launch_bounds__(BLOCK_SIZE, 2) +kernel_batch_sign_precomp_cached( + uint8_t *sig_soa, size_t *siglen_arr, + const sign_cache_t *cache, + const precomp_t *pc, + int *results, int N, int base_idx) +{ + int i = base_idx + blockIdx.x * blockDim.x + threadIdx.x; + if (i >= N) return; + uint8_t sig_local[CRYPTO_BYTES]; +#if BATCH_SIGN_NONCE_DIVERSIFY + uint16_t nonce_start = +#if ALGORITHM == ALGO_AIGIS + (uint16_t)(((unsigned int)i * PARAM_L) & 0xffffu); +#else + (uint16_t)i; +#endif +#else + uint16_t nonce_start = 0; +#endif +#if ALGORITHM == ALGO_MLDSA + results[i] = crypto_sign_signature_precomp_cached( + sig_local, siglen_arr + i, cache->mu, cache->rhoprime, pc, nonce_start); +#else + results[i] = crypto_sign_signature_precomp_cached( + sig_local, siglen_arr + i, cache->mu, cache->key_mu, pc, nonce_start); +#endif + soa_store(sig_soa, sig_local, i, N, CRYPTO_BYTES); +} + +/* Batch verification (precomputed): each thread loads its signature from + * sig_soa into a local buffer and verifies it against the shared precomputed + * matrix pc->mat; the return code is stored in results[i]. 
*/ +__global__ void __launch_bounds__(BLOCK_SIZE, 2) +kernel_batch_verify_precomp( + const uint8_t *sig_soa, const size_t *siglen_arr, + const uint8_t *msg, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *pk, + const precomp_t *pc, + int *results, int N, int base_idx) +{ + int i = base_idx + blockIdx.x * blockDim.x + threadIdx.x; + if (i >= N) return; + uint8_t sig_local[CRYPTO_BYTES]; + soa_load(sig_local, sig_soa, i, N, CRYPTO_BYTES); +#if ALGORITHM == ALGO_MLDSA + results[i] = crypto_sign_verify_precomp( + sig_local, siglen_arr[i], + msg, mlen, pre, prelen, pk, pc->mat); +#else + results[i] = crypto_sign_verify_precomp( + sig_local, siglen_arr[i], + msg, mlen, pk, pc->mat); +#endif +} + +/* ================================================================ + * Host helper functions + * ================================================================ */ 
+/* Print label with the byte count, then data as lowercase hex with a line break after every 32 bytes. */ +static void print_hex(const char *label, const uint8_t *data, size_t len) { + printf("%s (%zu bytes):\n", label, len); + for (size_t i = 0; i < len; i++) { + printf("%02x", data[i]); + if ((i + 1) % 32 == 0) printf("\n"); + } + if (len % 32) printf("\n"); +} + +/* Choose the default batch size for the current device: DEFAULT_BATCH_4090 when major*10+minor >= 89 and total memory >= 16 GiB, otherwise DEFAULT_BATCH (also returned when the device queries fail). NOTE(review): on HIP devices prop.major/minor may not follow CUDA compute-capability numbering - confirm the >= 89 threshold is intended there. */ +static int select_default_batch_for_device(void) { + int devId = 0; + hipDeviceProp_t prop; + size_t free_mem = 0, total_mem = 0; + if (hipGetDevice(&devId) != hipSuccess) return DEFAULT_BATCH; + if (hipGetDeviceProperties(&prop, devId) != hipSuccess) return DEFAULT_BATCH; + if (hipMemGetInfo(&free_mem, &total_mem) != hipSuccess) total_mem = 0; + + int runtime_sm = prop.major * 10 + prop.minor; + if (runtime_sm >= 89 && total_mem >= (16ull * 1024ull * 1024ull * 1024ull)) { + return DEFAULT_BATCH_4090; + } + return DEFAULT_BATCH; +} + +/* Print the benchmark banner: algorithm/mode/batch, GPU properties, build target, parameter set, key/signature sizes and the ROCm sign policy. NOTE(review): the return codes of hipGetDevice/hipGetDeviceProperties/hipMemGetInfo are ignored here, so prop is read uninitialized if those calls fail. */ +static void print_info(int active_batch, int batch_auto) { + int devId = 0; + hipDeviceProp_t prop; + hipGetDevice(&devId); + hipGetDeviceProperties(&prop, devId); + size_t free_mem = 0, total_mem = 0; + hipMemGetInfo(&free_mem, &total_mem); + printf("=== %s (Mode=%d) | Batch=%d%s ===\n", + CRYPTO_ALGNAME, PARAM_MODE, active_batch, + batch_auto ? 
" (auto)" : ""); + printf("GPU: %s CC=%d.%d SMs=%d VRAM=%zuMB L2=%dKB\n", + prop.name, prop.major, prop.minor, prop.multiProcessorCount, + total_mem / (1024*1024), prop.l2CacheSize / 1024); +#if CUDA_TARGET_ARCH + printf("Build: CUDA target=sm_%d BLOCK_SIZE=%d\n", CUDA_TARGET_ARCH, BLOCK_SIZE); + int runtime_sm = prop.major * 10 + prop.minor; + if (runtime_sm != CUDA_TARGET_ARCH) { + printf("Warning: binary was compiled for sm_%d but current GPU is sm_%d\n", + CUDA_TARGET_ARCH, runtime_sm); + } +#else + printf("Build: CUDA target not recorded BLOCK_SIZE=%d\n", BLOCK_SIZE); +#endif + printf("Params: K=%d L=%d N=%d Q=%d ETA=%d/%d TAU=%d GAMMA1=%d OMEGA=%d\n", + PARAM_K, PARAM_L, PARAM_N, PARAM_Q, + PARAM_ETA_S1, PARAM_ETA_S2, PARAM_TAU, PARAM_GAMMA1, PARAM_OMEGA); + printf("Sizes: PK=%d SK=%d SIG=%d bytes\n\n", + CRYPTO_PUBLICKEYBYTES, CRYPTO_SECRETKEYBYTES, CRYPTO_BYTES); + print_rocm_sign_policy(active_batch); + printf("\n"); +} + +/* Scan the result array and return the number of failures (non-zero entries) */ +static int count_failures(const int *h, int n) { + int fails = 0; + for (int i = 0; i < n; i++) + if (h[i] != 0) fails++; + return fails; +} + +/* Like count_failures, but also prints a PASS line, or a FAIL line with the first failing index and its code, tagged with stage. Returns the failure count. */ +static int check_results(const int *h, int n, const char *stage) { + int fails = 0, first_idx = -1, first_code = 0; + for (int i = 0; i < n; i++) { + if (h[i] != 0) { + if (first_idx < 0) { first_idx = i; first_code = h[i]; } + fails++; + } + } + if (fails == 0) { + printf(" [%s] correctness: all %d PASS\n", stage, n); + return 0; + } + printf(" [%s] FAIL: %d/%d (first: idx=%d code=%d)\n", + stage, fails, n, first_idx, first_code); + return fails; +} + +/* Convert count operations in ms milliseconds to operations per second; returns 0.0 when ms < 0.001 instead of dividing by a near-zero time. */ +static double ops_from_ms(double count, float ms) { + if (ms < 0.001f) return 0.0; + return count * 1000.0 / (double)ms; +} + +/* Return 1 when all len bytes of data are zero (including len == 0), otherwise 0. */ +static int buffer_all_zero(const uint8_t *data, size_t len) { + for (size_t i = 0; i < len; i++) { + if (data[i] != 0) return 0; + } + return 1; +} + +/* Return -1 (after printing which key is affected) when the host copy of pk or sk is entirely zero; return 0 otherwise. */ +static int check_host_key_material(const uint8_t *pk, const uint8_t *sk, + const char *stage, int instance) { + int pk_zero = buffer_all_zero(pk, 
CRYPTO_PUBLICKEYBYTES); + int sk_zero = buffer_all_zero(sk, CRYPTO_SECRETKEYBYTES); + if (pk_zero || sk_zero) { + printf("[%s] FAIL: instance %d produced %s%s%s\n", + stage, instance, + pk_zero ? "all-zero PK" : "", + (pk_zero && sk_zero) ? " and " : "", + sk_zero ? "all-zero SK" : ""); + return -1; + } + return 0; +} + +/* Copy the first min(check_count, batch_count) keys from d_pks/d_sks to the host and run check_host_key_material on each; the buffers are indexed as contiguous arrays with strides CRYPTO_PUBLICKEYBYTES and CRYPTO_SECRETKEYBYTES. Returns 0 on success or when nothing is checked, -1 on host allocation, copy or key-check failure. */ +static int check_device_key_material_prefix(const uint8_t *d_pks, + const uint8_t *d_sks, + int batch_count, + int check_count, + const char *stage) { + int n = check_count; + if (n > batch_count) n = batch_count; + if (n <= 0) return 0; + + size_t pk_bytes = (size_t)n * CRYPTO_PUBLICKEYBYTES; + size_t sk_bytes = (size_t)n * CRYPTO_SECRETKEYBYTES; + uint8_t *h_pk = (uint8_t *)malloc(pk_bytes); + uint8_t *h_sk = (uint8_t *)malloc(sk_bytes); + if (!h_pk || !h_sk) { + printf("[%s] FAIL: host malloc failed during key material check\n", stage); + free(h_pk); + free(h_sk); + return -1; + } + + hipError_t err = hipMemcpy(h_pk, d_pks, pk_bytes, hipMemcpyDeviceToHost); + if (err == hipSuccess) { + err = hipMemcpy(h_sk, d_sks, sk_bytes, hipMemcpyDeviceToHost); + } + if (err != hipSuccess) { + printf("[%s] FAIL: key material copy failed: %s\n", + stage, hipGetErrorString(err)); + free(h_pk); + free(h_sk); + return -1; + } + + for (int i = 0; i < n; i++) { + if (check_host_key_material(h_pk + (size_t)i * CRYPTO_PUBLICKEYBYTES, + h_sk + (size_t)i * CRYPTO_SECRETKEYBYTES, + stage, i) != 0) { + free(h_pk); + free(h_sk); + return -1; + } + } + + free(h_pk); + free(h_sk); + return 0; +} + +/* ================================================================ + * Phase 1: single-instance correctness check - prints all input/output values + * (in this build the function below is a stub: every parameter is unused, it prints a "skipped" notice and returns 0) + * ================================================================ */ +static int run_single_correctness( + const uint8_t *h_seed, const uint8_t *h_rnd, + const uint8_t *h_msg, size_t mlen, + const uint8_t *h_ctx, size_t ctxlen, + const uint8_t *h_pre, size_t prelen, + int quiet) +{ + (void)h_seed; + (void)h_rnd; + (void)h_msg; + (void)mlen; + (void)h_ctx; + (void)ctxlen; + 
(void)h_pre; + (void)prelen; + (void)quiet; + printf("=== Phase 1: Single-instance correctness skipped on AMD/HIP first-pass build ===\n\n"); + return 0; +} + + +/* Stub: ignores its arguments and returns 0. */ +static int run_keygen_oracle_check(const uint8_t *h_seed, int check_n, int quiet) +{ + (void)h_seed; + (void)check_n; + (void)quiet; + return 0; +} + + +/* Upload h_seed to the device and run batch_keygen_compare_active_path for N instances. Prints PASS when it returns 0, the first mismatch when it returns > 0 (rc is then normalized to 1), or a failure notice when it returns < 0. The device seed buffer is freed before returning. NOTE(review): the cleanup label is presumably the jump target of CUDA_CHECK - confirm against its definition. */ +static int run_keygen_compare_batch( + int N, + const uint8_t *h_seed, + int quiet, + int sample_only) +{ + int rc = 0; + unsigned char *d_base_seed = nullptr; + KeygenCompareResult result; + + keygen_compare_result_clear(&result); + + if (!quiet) { + printf("--- [Batch=%d] Keygen compare ---\n", N); + printf(" mode=%s compare=%s\n", + g_bench_independent ? "independent-real-batch" : "paper-4090-style", + sample_only ? "sample-only" : "full-keygen"); + printf(" build: tr_hash_fixed=%d material=%s sign=%s sample_ind=%s sample_paper=%s\n", + BATCH_KEYGEN_TR_HASH_FIXED, + internal_material_mode_name(), + sign_precomp_mode_name(), + keygen_ind_sample_mode_name(), + keygen_paper_sample_mode_name()); + fflush(stdout); + } + + CUDA_CHECK(hipMalloc(&d_base_seed, SEEDBYTES)); + CUDA_CHECK(hipMemcpy(d_base_seed, h_seed, SEEDBYTES, hipMemcpyHostToDevice)); + + rc = batch_keygen_compare_active_path( + d_base_seed, + N, + g_bench_independent ? 0 : 1, + sample_only, + &result); + + if (rc == 0) { + printf("[Keygen-compare] PASS: old vs active %s path matched for batch=%d\n\n", + sample_only ? "sample" : "full", + N); + } else if (rc > 0) { + printf("[Keygen-compare] first mismatch: stage=%s instance=%d byte_off=%zu elem_off=%zu ref=%lld cand=%lld\n\n", + keygen_compare_stage_name(result.stage), + result.instance, + result.byte_offset, + result.element_offset, + (long long)result.ref_value, + (long long)result.cand_value); + rc = 1; + } else { + printf("[Keygen-compare] FAILED to run compare\n\n"); + } + +cleanup: + hipFree(d_base_seed); + return rc; +} + +/* ================================================================ + * Phase 2: decomposed batch performance benchmark + * + * Optimization rationale: + * 1. 
Pipeline decomposition: keygen/verify split into 7-11 dedicated kernels + * 2. Shared-memory NTT: 128 threads/poly, shared-memory butterflies + * 3. 2D-grid matrix-vector multiply: dim3(batch, K), one thread per coefficient + * 4. Stack reduction: sampling 48KB, arithmetic 4KB → GPU utilization >50% + * 5. Shared matrix A (verify): one copy shared by all instances + * 6. Averaging over multiple iterations: WARMUP + BENCH_ITERS + * ================================================================ */ +#define WARMUP_ITERS 3 +#define BENCH_ITERS 5 +#define SAMPLE_ONLY_ITERS 3 +#define THROUGHPUT_RUNS 10 + +/* Return the median of three floats using three compare-and-swap steps. */ +static float median3f(float a, float b, float c) +{ + if (a > b) { float t = a; a = b; b = t; } + if (b > c) { float t = b; b = c; c = t; } + if (a > b) { float t = a; a = b; b = t; } + return b; +} + +/* Benchmark keygen, sign and verify for a batch of N instances. Several keygen and sign variants are timed with hipEvents (averaged over bench_iters), the fastest variant that passed is selected, the resulting signatures are verified, and a report is printed. The selected timings are returned through out_kg_ms/out_sg_ms/out_vf_ms when those pointers are non-null. Returns 0 on success and -1 on failure. NOTE(review): the cleanup label releases only the events, d_pk_one/d_sk_one, d_sigs_for_verify and kbuf; buffers allocated in the inner scopes and d_base_seed/d_shared_rho are not released on goto-cleanup paths - confirm this is acceptable. */ +static int run_batch( + int N, + const uint8_t *h_seed, const uint8_t *h_rnd, + const uint8_t *h_msg, size_t mlen, + const uint8_t *h_pre, size_t prelen, + int quiet, + int bench_iters, + float *out_kg_ms, + float *out_sg_ms, + float *out_vf_ms) +{ + int rc = 0; + float ms = 0, ms_keygen = 0, ms_sign = 0, ms_verify = 0; + float ms_keygen_old = 0.0f, ms_keygen_ind = 0.0f, ms_keygen_paper = -1.0f; + int verify_fails = 0; + double kg_ops = 0, sg_ops = 0, vf_ops = 0; + hipEvent_t ev0 = nullptr, ev1 = nullptr; + + /* Shared device buffers: single key pair pk_one / sk_one (used by sign and verify) */ + uint8_t *d_pk_one = nullptr, *d_sk_one = nullptr; + uint8_t *d_sigs_for_verify = nullptr; + int verify_uses_batch_sigs = 0; + const char *chosen_sign_label = "precomp-monolithic"; + const char *chosen_keygen_label = "independent-old"; + uint8_t *d_base_seed = nullptr, *d_shared_rho = nullptr; + BatchKeygenBuffers kbuf; + KeygenProfile ind_profile; + KeygenProfile paper_profile; + memset(&kbuf, 0, sizeof(kbuf)); + keygen_profile_clear(&ind_profile); + keygen_profile_clear(&paper_profile); +#if BATCH_KEYGEN_INTERNAL_MATERIAL + int keygen_mat_shared = 0; +#endif + + /* Prints one keygen result row followed by the per-stage timings stored in profile */ + auto print_keygen_profile_line = [&](const char *label, + float total_ms, + const char *sample_mode, + const KeygenProfile &profile, + int include_shared_a) { + double ops = ops_from_ms((double)N, total_ms); + printf(" %-14s %8d 
%10.3f ms %12.0f ops/s [", + label, N, total_ms, ops); + if (include_shared_a) { + printf("sharedA %.3f ", profile.shared_a_ms); + } + printf("sample %s sample_total %.3f sample_launch_gap %.3f copy %.3f ntt %.3f matvec %.3f post %.3f p2r %.3f pack_outer %.3f material %.3f]", + sample_mode, + profile.sample_ms, + profile.sample_launch_gap_ms, + profile.copy_ms, + profile.ntt_ms, + profile.matvec_ms, + profile.post_ms, + profile.p2r_ms, + profile.pack_ms, + profile.material_ms); + printf(" sample_active[seed_expand %.3f matrixA_active %.3f eta_active %.3f]", + profile.seed_expand_ms, + profile.matrix_a_sample_ms, + profile.secret_eta_sample_ms); + { + float pack_gap = profile.pack_inner_ms - profile.pack_fused_ms - + profile.pack_body_ms - profile.tr_hash_ms; + if (pack_gap < 0.0f) pack_gap = 0.0f; + printf(" pack[inner %.3f fused %.3f body %.3f tr %.3f gap %.3f]", + profile.pack_inner_ms, + profile.pack_fused_ms, + profile.pack_body_ms, + profile.tr_hash_ms, + pack_gap); + } + if (profile.matrix_a_coop_lanes > 0 || profile.secret_eta_coop_lanes > 0) { + printf(" coop_lanes[matA %d eta %d] coop_ms[matA %.3f eta %.3f]", + profile.matrix_a_coop_lanes, + profile.secret_eta_coop_lanes, + profile.matrix_a_coop_ms, + profile.secret_eta_coop_ms); + } + if (profile.pack_header_ms > 0.0f || profile.pack_t1_ms > 0.0f || + profile.pack_eta_ms > 0.0f || profile.pack_t0_ms > 0.0f) { + printf(" split[hdr %.3f t1 %.3f eta %.3f t0 %.3f]", + profile.pack_header_ms, + profile.pack_t1_ms, + profile.pack_eta_ms, + profile.pack_t0_ms); + } + printf("\n"); + }; + + if (!quiet) { + printf("--- [Batch=%d] Warp-parallel-SoA pipeline ---\n", N); + printf(" mode=%s%s\n", + g_bench_independent ? "independent-real-batch" : "paper-4090-style", + g_profile ? 
" profile=on" : ""); + printf(" build: tr_hash_fixed=%d material=%s sign=%s sample_ind=%s sample_paper=%s\n", + BATCH_KEYGEN_TR_HASH_FIXED, + internal_material_mode_name(), + sign_precomp_mode_name(), + keygen_ind_sample_mode_name(), + keygen_paper_sample_mode_name()); + fflush(stdout); + } + + CUDA_CHECK(hipEventCreate(&ev0)); + CUDA_CHECK(hipEventCreate(&ev1)); + + CUDA_CHECK(hipMalloc(&d_pk_one, CRYPTO_PUBLICKEYBYTES)); + CUDA_CHECK(hipMalloc(&d_sk_one, CRYPTO_SECRETKEYBYTES)); + + /* ================================================================ + * [2a] Decomposed keygen + * + * Pipeline: sample → copy → NTT → matvec → reduce → INVNTT → add → pack + * Each step uses its best kernel configuration: + * sample: 2 threads/instance sub-warp (SHAKE-heavy, low parallelism) + * pack: 32 threads/block, fused power2round + * NTT: 128 threads/block (shared-memory butterflies) + * matvec: dim3(B,K) × N threads (2D grid, one thread per coefficient) + * element-wise ops: 256 threads/block + * ================================================================ */ + { + /* The sampling kernel needs a larger stack (SHAKE expansion of matrix A) */ + size_t kg_stack = 48u * 1024u; + if (hipDeviceSetLimit(hipLimitStackSize, kg_stack) != hipSuccess) { + hipGetLastError(); + kg_stack = 64u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, kg_stack); + hipGetLastError(); + } + + if (batch_keygen_alloc(&kbuf, N) != 0) { + printf(" [Keygen] batch_keygen_alloc FAILED\n"); + rc = -1; goto cleanup; + } + + CUDA_CHECK(hipMalloc(&d_base_seed, SEEDBYTES)); + CUDA_CHECK(hipMalloc(&d_shared_rho, SEEDBYTES)); + CUDA_CHECK(hipMemcpy(d_base_seed, h_seed, SEEDBYTES, hipMemcpyHostToDevice)); + + printf(" Operation Batch Time(ms) Throughput\n"); + printf(" --------- ----- -------- ----------\n"); + + for (int w = 0; w < WARMUP_ITERS; w++) { + if (batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipDeviceSynchronize()); + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + if 
(batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_keygen_old = ms / bench_iters; + printf(" %-14s %8d %10.3f ms %12.0f ops/s [baseline]\n", + "Keygen-old", N, ms_keygen_old, ops_from_ms((double)N, ms_keygen_old)); + + /* NOTE(review): the "Keygen-ind-x" measurement below issues exactly the same batch_keygen_pipeline_warp_opt call as the "Keygen-old" baseline above - confirm whether a different pipeline was intended */ + for (int w = 0; w < WARMUP_ITERS; w++) { + if (batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipDeviceSynchronize()); + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + if (batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_keygen_ind = ms / bench_iters; + if (g_profile) { + if (batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, &ind_profile, 0, 1) != 0) { + rc = -1; goto cleanup; + } + CUDA_CHECK(hipDeviceSynchronize()); + print_keygen_profile_line("Keygen-ind-x", ms_keygen_ind, + keygen_ind_sample_mode_name(), ind_profile, 0); + } else { + printf(" %-14s %8d %10.3f ms %12.0f ops/s [sample %s]\n", + "Keygen-ind-x", N, ms_keygen_ind, ops_from_ms((double)N, ms_keygen_ind), + keygen_ind_sample_mode_name()); + } + + for (int w = 0; w < WARMUP_ITERS; w++) { + if (batch_keygen_create_shared_rho_a(&kbuf, d_shared_rho, d_base_seed) != 0 || + batch_keygen_pipeline_paper_shared_rho_a( + kbuf.d_pks, kbuf.d_sks, d_base_seed, d_shared_rho, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipDeviceSynchronize()); + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + if (batch_keygen_create_shared_rho_a(&kbuf, d_shared_rho, d_base_seed) != 0 || 
+ batch_keygen_pipeline_paper_shared_rho_a( + kbuf.d_pks, kbuf.d_sks, d_base_seed, d_shared_rho, &kbuf, N, NULL, 0, 1) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_keygen_paper = ms / bench_iters; + if (g_profile) { + if (batch_keygen_create_shared_rho_a(&kbuf, d_shared_rho, d_base_seed, &paper_profile) != 0 || + batch_keygen_pipeline_paper_shared_rho_a( + kbuf.d_pks, kbuf.d_sks, d_base_seed, d_shared_rho, &kbuf, N, &paper_profile, 0, 1) != 0) { + rc = -1; goto cleanup; + } + CUDA_CHECK(hipDeviceSynchronize()); + print_keygen_profile_line("Keygen-paper", ms_keygen_paper, + keygen_paper_sample_mode_name(), paper_profile, 1); + } else { + /* NOTE(review): paper_profile is only passed to the pipeline in the g_profile branch, so shared_a_ms printed here is whatever keygen_profile_clear left in it - confirm */ + printf(" %-14s %8d %10.3f ms %12.0f ops/s [sharedA %.3f sample %s]\n", + "Keygen-paper", N, ms_keygen_paper, ops_from_ms((double)N, ms_keygen_paper), + paper_profile.shared_a_ms, keygen_paper_sample_mode_name()); + } + + /* Select the fastest keygen timing; the paper path is only eligible when g_bench_independent is off */ + ms_keygen = ms_keygen_old; + chosen_keygen_label = "independent-old"; + if (ms_keygen_ind > 0.0f && ms_keygen_ind < ms_keygen) { + ms_keygen = ms_keygen_ind; + chosen_keygen_label = "independent-opt"; + } + if (!g_bench_independent && ms_keygen_paper > 0.0f && ms_keygen_paper < ms_keygen) { + ms_keygen = ms_keygen_paper; + chosen_keygen_label = "paper-shared-rhoA"; + } + + /* Re-run the selected pipeline once so kbuf holds its output for the sign/verify stages */ + if (strcmp(chosen_keygen_label, "paper-shared-rhoA") == 0) { +#if BATCH_KEYGEN_INTERNAL_MATERIAL + keygen_mat_shared = 1; +#endif + batch_keygen_pipeline_paper_shared_rho_a( + kbuf.d_pks, kbuf.d_sks, d_base_seed, d_shared_rho, &kbuf, N, NULL, 0, 1); + } else if (strcmp(chosen_keygen_label, "independent-opt") == 0) { +#if BATCH_KEYGEN_INTERNAL_MATERIAL + keygen_mat_shared = 0; +#endif + batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, N, NULL, 0, 1); + } else { +#if BATCH_KEYGEN_INTERNAL_MATERIAL + keygen_mat_shared = 0; +#endif + batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, 
&kbuf, N, NULL, 0, 1); + } + CUDA_CHECK(hipDeviceSynchronize()); + if (check_device_key_material_prefix(kbuf.d_pks, kbuf.d_sks, + N, 8, "Keygen-selected") != 0) { + rc = -1; goto cleanup; + } + + /* Save instance[0]'s pk/sk for the later sign+verify stages */ + CUDA_CHECK(hipMemcpy(d_pk_one, kbuf.d_pks, + CRYPTO_PUBLICKEYBYTES, hipMemcpyDeviceToDevice)); + CUDA_CHECK(hipMemcpy(d_sk_one, kbuf.d_sks, + CRYPTO_SECRETKEYBYTES, hipMemcpyDeviceToDevice)); + + hipFree(d_shared_rho); + hipFree(d_base_seed); + } + + /* ================================================================ + * [2b] Precomputed signing (monolithic, independent per thread) + * + * Sign uses a rejection loop and is hard to decompose into a pipeline. + * Uses a shared-key precomp_t + independent signing per thread. + * ================================================================ */ + { + /* Building the precomputation in a single thread needs a large stack */ + size_t sign_precomp_stack = 128u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, sign_precomp_stack); + hipGetLastError(); + + precomp_t *d_pc = nullptr; + CUDA_CHECK(hipMalloc(&d_pc, sizeof(precomp_t))); + #if BATCH_KEYGEN_INTERNAL_MATERIAL + batch_keygen_material_to_precomp_kernel<<<1, 1>>>( + d_pc, + kbuf.d_mat, kbuf.d_s1hat, kbuf.d_s2_ntt, kbuf.d_t0_ntt, + kbuf.d_buf, kbuf.d_tr, 0, keygen_mat_shared); + #else + kernel_create_precomp<<<1, 1>>>(d_pc, d_pk_one, d_sk_one); + #endif + CUDA_CHECK(hipDeviceSynchronize()); + + /* Signing uses a smaller stack */ + size_t sign_stack = 64u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, sign_stack); + hipGetLastError(); + + /* Allocate the signature SoA buffers */ + uint8_t *d_sig_soa = nullptr; + size_t *d_siglen = nullptr; + int *d_results = nullptr; + int *h_results = nullptr; + uint8_t *d_msg = nullptr, *d_rnd = nullptr, *d_pre_d = nullptr; + size_t mem_sig = (size_t)N * CRYPTO_BYTES; + + h_results = (int *)calloc(N, sizeof(int)); + CUDA_CHECK(hipMalloc(&d_sig_soa, mem_sig)); + CUDA_CHECK(hipMalloc(&d_siglen, (size_t)N * sizeof(size_t))); + CUDA_CHECK(hipMalloc(&d_results, (size_t)N * sizeof(int))); + CUDA_CHECK(hipMalloc(&d_msg, mlen)); + CUDA_CHECK(hipMalloc(&d_rnd, RNDBYTES > 0 ? 
RNDBYTES : 1)); + CUDA_CHECK(hipMalloc(&d_pre_d, prelen > 0 ? prelen : 1)); + CUDA_CHECK(hipMemcpy(d_msg, h_msg, mlen, hipMemcpyHostToDevice)); +#if RNDBYTES > 0 + CUDA_CHECK(hipMemcpy(d_rnd, h_rnd, RNDBYTES, hipMemcpyHostToDevice)); +#endif + if (prelen > 0) + CUDA_CHECK(hipMemcpy(d_pre_d, h_pre, prelen, hipMemcpyHostToDevice)); + + int grid = (N + BLOCK_SIZE - 1) / BLOCK_SIZE; + sign_cache_t *d_sign_cache = nullptr; + float ms_sign_cached = -1.0f; + float ms_sign_mono = -1.0f; + float ms_sign_warp = -1.0f; + int cached_ok = 0; + int mono_ok = 0; + int warp_ok = 0; + int sign_path = -1; /* 0=mono, 1=cached, 2=decomp, 3=warp, 4=warp-cached */ + unsigned long long *d_warp_stats = nullptr; + unsigned long long h_warp_stats[WP_SIGN_STAT_COUNT]; + int warp_available = 0; + size_t warp_smem = 0; + +#if BATCH_SIGN_MONO_ENABLE && BATCH_SIGN_WARP_ENABLE + warp_smem = batch_sign_warp_smem_bytes(); + { + hipError_t we = batch_sign_warp_set_smem_attributes(); + if (we == hipSuccess) { + warp_available = 1; + } else { + hipGetLastError(); + if (g_profile) + printf(" [Sign-warp] disabled: dynamic smem request %zu bytes/block rejected (%s)\n", + warp_smem, hipGetErrorString(we)); + } + } + if (warp_available && (g_profile || BATCH_SIGN_WARP_PROFILE)) { + CUDA_CHECK(hipMalloc(&d_warp_stats, + (size_t)WP_SIGN_STAT_COUNT * sizeof(unsigned long long))); + } +#endif + const int sign_large_batch = + (BATCH_SIGN_LARGE_STRATEGY_ENABLE && N >= BATCH_SIGN_LARGE_BATCH_THRESHOLD); + const int sign_real_nonce_batch = +#if BATCH_SIGN_NONCE_DIVERSIFY + 1; +#else + 0; +#endif + /* When set, the cached and monolithic timing runs below are skipped in favour of the warp path */ + const int prefer_warp_large = + sign_large_batch && sign_real_nonce_batch && warp_available && !g_profile; + +#if BATCH_SIGN_MONO_ENABLE && BATCH_SIGN_PRECOMP_REUSE + if (!g_bench_independent) { + CUDA_CHECK(hipMalloc(&d_sign_cache, sizeof(sign_cache_t))); + kernel_create_sign_cache<<<1, 1>>>( + d_sign_cache, d_pc, d_msg, mlen, d_pre_d, prelen, d_rnd); + CUDA_CHECK(hipDeviceSynchronize()); + + if (!prefer_warp_large) { + 
for (int w = 0; w < WARMUP_ITERS; w++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + /* NOTE(review): the launch configuration between the chevrons appears to have been stripped in this copy ("<<>>", here and at the other batch launches in this function) - restore it from the original source before building */ + kernel_batch_sign_precomp_cached<<>>( + d_sig_soa, d_siglen, d_sign_cache, d_pc, d_results, N, 0); + CUDA_CHECK(hipDeviceSynchronize()); + } + + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + kernel_batch_sign_precomp_cached<<>>( + d_sig_soa, d_siglen, d_sign_cache, d_pc, d_results, N, 0); + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_sign_cached = ms / bench_iters; + + CUDA_CHECK(hipMemcpy(h_results, d_results, N * sizeof(int), hipMemcpyDeviceToHost)); + cached_ok = (count_failures(h_results, N) == 0); + if (g_profile) { + double cached_ops = ops_from_ms((double)N, ms_sign_cached); + printf(" %-14s %8d %10.3f ms %12.0f ops/s [%s]\n", + "Sign-cached", N, ms_sign_cached, cached_ops, + cached_ok ? "PASS" : "FAIL"); + } + if (cached_ok) { + ms_sign = ms_sign_cached; + sign_path = 1; + chosen_sign_label = "precomp-cached"; + } + } + } +#endif + + if (BATCH_SIGN_MONO_ENABLE && !prefer_warp_large) { + for (int w = 0; w < WARMUP_ITERS; w++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + kernel_batch_sign_precomp<<>>( + d_sig_soa, d_siglen, d_msg, mlen, d_pre_d, prelen, + d_rnd, d_pc, d_results, N, 0); + CUDA_CHECK(hipDeviceSynchronize()); + } + + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + kernel_batch_sign_precomp<<>>( + d_sig_soa, d_siglen, d_msg, mlen, d_pre_d, prelen, + d_rnd, d_pc, d_results, N, 0); + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_sign_mono = ms / bench_iters; + if (g_profile) { + double mono_ops = ops_from_ms((double)N, ms_sign_mono); + printf(" %-14s %8d %10.3f ms %12.0f ops/s 
[precomp-monolithic]\n", + "Sign-mono-old", N, ms_sign_mono, mono_ops); + } + CUDA_CHECK(hipMemcpy(h_results, d_results, N * sizeof(int), hipMemcpyDeviceToHost)); + mono_ok = (count_failures(h_results, N) == 0); + if (g_profile) + check_results(h_results, N, "Sign-mono-old"); + /* Adopt the monolithic path when it passed and is faster than (or the first) accepted path */ + if (mono_ok && (sign_path < 0 || ms_sign_mono < ms_sign)) { + ms_sign = ms_sign_mono; + sign_path = 0; + chosen_sign_label = "precomp-monolithic"; + } + } + +#if BATCH_SIGN_MONO_ENABLE && BATCH_SIGN_WARP_ENABLE + { + const int skip_warp_large_paper = + sign_large_batch && !sign_real_nonce_batch && sign_path >= 0 && !g_profile; + if (warp_available && !skip_warp_large_paper) { + const int warp_cached = +#if BATCH_SIGN_PRECOMP_REUSE + (!g_bench_independent && d_sign_cache != nullptr); +#else + 0; +#endif + int grid_warp = (N + WP_SIGN_WARPS_BLOCK - 1) / WP_SIGN_WARPS_BLOCK; + const char *warp_stage = warp_cached ? "Sign-warp-cached" : "Sign-warp"; + + for (int w = 0; w < WARMUP_ITERS; w++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + if (warp_cached) { + kernel_batch_sign_warp_precomp_cached<<>>( + d_sig_soa, d_siglen, (const uint8_t *)d_sign_cache, + d_pc, d_results, N, 0, nullptr); + } else { + kernel_batch_sign_warp_precomp<<>>( + d_sig_soa, d_siglen, d_msg, mlen, d_pre_d, prelen, + d_rnd, d_pc, d_results, N, 0, nullptr); + } + CUDA_CHECK(hipGetLastError()); + CUDA_CHECK(hipDeviceSynchronize()); + } + + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + if (d_warp_stats) + CUDA_CHECK(hipMemset(d_warp_stats, 0, + (size_t)WP_SIGN_STAT_COUNT * sizeof(unsigned long long))); + + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) { + CUDA_CHECK(hipMemset(d_results, 0, N * sizeof(int))); + /* Statistics are only collected on the first timed iteration (it == 0) */ + if (it == 0 && d_warp_stats) + CUDA_CHECK(hipMemset(d_warp_stats, 0, + (size_t)WP_SIGN_STAT_COUNT * sizeof(unsigned long long))); + if (warp_cached) { + kernel_batch_sign_warp_precomp_cached<<>>( + d_sig_soa, d_siglen, (const uint8_t *)d_sign_cache, + d_pc, d_results, 
N, 0, (it == 0) ? d_warp_stats : nullptr); + } else { + kernel_batch_sign_warp_precomp<<>>( + d_sig_soa, d_siglen, d_msg, mlen, d_pre_d, prelen, + d_rnd, d_pc, d_results, N, 0, (it == 0) ? d_warp_stats : nullptr); + } + CUDA_CHECK(hipGetLastError()); + } + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_sign_warp = ms / bench_iters; + + CUDA_CHECK(hipMemcpy(h_results, d_results, N * sizeof(int), hipMemcpyDeviceToHost)); + warp_ok = (count_failures(h_results, N) == 0); + if (g_profile) { + double warp_ops = ops_from_ms((double)N, ms_sign_warp); + printf(" %-14s %8d %10.3f ms %12.0f ops/s [%s smem=%zu]\n", + warp_stage, N, ms_sign_warp, warp_ops, + warp_ok ? "PASS" : "FAIL", warp_smem); + if (d_warp_stats) { + CUDA_CHECK(hipMemcpy(h_warp_stats, d_warp_stats, + (size_t)WP_SIGN_STAT_COUNT * sizeof(unsigned long long), + hipMemcpyDeviceToHost)); + double avg_attempts = (N > 0) + ? (double)h_warp_stats[WP_SIGN_STAT_ATTEMPTS] / (double)N + : 0.0; + printf(" [%-14s] attempts=%.3f reject{s2=%llu z=%llu t0=%llu hint=%llu} ok=%llu\n", + warp_stage, avg_attempts, + h_warp_stats[WP_SIGN_STAT_REJ_S2], + h_warp_stats[WP_SIGN_STAT_REJ_Z], + h_warp_stats[WP_SIGN_STAT_REJ_T0], + h_warp_stats[WP_SIGN_STAT_REJ_HINT], + h_warp_stats[WP_SIGN_STAT_OK]); + } + } + + if (warp_ok && (sign_path < 0 || ms_sign_warp < ms_sign)) { + if (d_sigs_for_verify) { + hipFree(d_sigs_for_verify); + d_sigs_for_verify = nullptr; + } + CUDA_CHECK(hipMalloc(&d_sigs_for_verify, (size_t)N * CRYPTO_BYTES)); + { + /* NOTE(review): N * CRYPTO_BYTES is computed in int and may overflow for very large N - confirm the supported batch range */ + int total = N * CRYPTO_BYTES; + int tpb = 256; + int nblk = (total + tpb - 1) / tpb; + kernel_wp_sign_sig_soa_to_aos<<>>( + d_sigs_for_verify, d_sig_soa, N); + CUDA_CHECK(hipGetLastError()); + CUDA_CHECK(hipDeviceSynchronize()); + } + verify_uses_batch_sigs = 1; + ms_sign = ms_sign_warp; + sign_path = warp_cached ? 4 : 3; + chosen_sign_label = warp_cached ? 
"precomp-warp-cached" : "precomp-warp"; + } + } + } +#endif + + /* ---------------------------------------------------------------- + * [2b-decomp] Decomposed batch-signing pipeline (operator-level parallelism) + * + * Optimization rationale: + * · y sampling: parallel across the whole batch (independent per-instance SHAKE), replacing the serial loop + * · NTT(y): shared-memory batch kernel (128 threads/poly, reuses batch_ntt) + * · w = A·y: shared-matrix 2D-grid matvec (reuses batch_verify_matvec_kernel) + * · z/cs2/ct0: cp·shared_vec batched pointwise + INVNTT + * · check/hint/pack: one thread per instance (small stack, high concurrency) + * · finished instances are skipped via the d_done flag, reducing tail waste in rejection rounds + * ---------------------------------------------------------------- */ + { + const int skip_decomp_large_best = + sign_large_batch && sign_path >= 0 && !g_profile; + if ((g_profile || BATCH_SIGN_DECOMP_ENABLE) && !skip_decomp_large_best) { + /* The decomposed pipeline needs a moderate stack (the check_pack kernel is the heaviest, ~30KB/thread) */ + size_t decomp_stack = 64u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, decomp_stack); + hipGetLastError(); + + BatchSignPipeline bsp; + memset(&bsp, 0, sizeof(bsp)); + if (batch_sign_alloc(&bsp, N) == 0) { + const char *decomp_policy_label = nullptr; + BatchSignRuntimeOptions decomp_runtime = + select_decomp_runtime_options(N, g_bench_independent, &decomp_policy_label); + /* Warmup - 1 run (includes the rejection loop, not timed) */ + int warm_rounds = 0, warm_done = 0; + batch_sign_pipeline_ex(&bsp, N, d_pc, d_msg, mlen, d_pre_d, prelen, + d_rnd, &decomp_runtime, &warm_rounds, &warm_done); + + /* Timed - averaged over bench_iters runs */ + float ms_sdp = 0; + int last_rounds = 0, last_done = 0; + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) + batch_sign_pipeline_ex(&bsp, N, d_pc, d_msg, mlen, d_pre_d, prelen, + d_rnd, &decomp_runtime, &last_rounds, &last_done); + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms_sdp, ev0, ev1)); + ms_sdp /= bench_iters; + if (g_profile) { + double decomp_ops = ops_from_ms((double)N, ms_sdp); + printf(" %-14s %8d %10.3f ms %12.0f ops/s [policy=%s cp_fuse=%d check=%d ctrl=%d/%d 
rounds=%d done=%d]\n", + "Sign-decomp", N, ms_sdp, decomp_ops, + decomp_policy_label, + decomp_runtime.cp_fuse_enable, + decomp_runtime.check_interval, + decomp_runtime.hash_tpb, + decomp_runtime.check_tpb, + last_rounds, last_done); + } + + /* Validation: check d_done (NOTE(review): the malloc and hipMemcpy results below are not checked) */ + int *h_dp_done = (int *)malloc(N * sizeof(int)); + hipMemcpy(h_dp_done, bsp.d_done, N * sizeof(int), hipMemcpyDeviceToHost); + int dp_pass = 0; + for (int i = 0; i < N; i++) if (h_dp_done[i]) dp_pass++; + free(h_dp_done); + + if (g_profile) { + if (dp_pass == N) + printf(" [Sign-decomp] correctness: all %d PASS (last rounds=%d)\n", + dp_pass, last_rounds); + else + printf(" [Sign-decomp] WARN: only %d/%d completed\n", dp_pass, N); + } + + if (dp_pass == N) { + if (sign_path < 0 || ms_sdp < ms_sign) { + ms_sign = ms_sdp; + if (d_sigs_for_verify) { + hipFree(d_sigs_for_verify); + d_sigs_for_verify = nullptr; + } + CUDA_CHECK(hipMalloc(&d_sigs_for_verify, (size_t)N * CRYPTO_BYTES)); + CUDA_CHECK(hipMemcpy(d_sigs_for_verify, bsp.d_sigs, + (size_t)N * CRYPTO_BYTES, + hipMemcpyDeviceToDevice)); + verify_uses_batch_sigs = 1; + sign_path = 2; + chosen_sign_label = (decomp_policy_label && strcmp(decomp_policy_label, "base") != 0) + ? 
"decomp-adaptive" + : "decomp-pipeline"; + } + } + + batch_sign_free(&bsp); + } else { + printf(" [Sign-decomp] alloc FAILED (out of VRAM)\n"); + } + } + } + + if (sign_path < 0) { + printf(" [Sign] FAIL: no enabled signing path completed\n"); + rc = -1; goto cleanup; + } + if (!g_profile || quiet) + printf(" [Sign] correctness: all %d PASS [%s]\n", N, chosen_sign_label); + + /* Prepare for verify: use the whole batch of signatures when a batch path set verify_uses_batch_sigs (warp or decomp), otherwise generate 1 valid signature and broadcast it */ + uint8_t *d_sig_one = nullptr; + if (!verify_uses_batch_sigs) { + /* Generate 1 valid signature with a single-thread signing launch */ + size_t *d_siglen_one = nullptr; + CUDA_CHECK(hipMalloc(&d_sig_one, CRYPTO_BYTES)); + CUDA_CHECK(hipMalloc(&d_siglen_one, sizeof(size_t))); + CUDA_CHECK(hipMemset(d_sig_one, 0, CRYPTO_BYTES)); + + size_t big_stack = 128u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, big_stack); + hipGetLastError(); + + /* Produce one signature with the precomputed signing kernel */ + int *d_vr = nullptr; + CUDA_CHECK(hipMalloc(&d_vr, sizeof(int))); + CUDA_CHECK(hipMemset(d_vr, 0, sizeof(int))); + if (sign_path == 1) { + kernel_batch_sign_precomp_cached<<<1, 1>>>( + d_sig_one, d_siglen_one, d_sign_cache, d_pc, d_vr, 1, 0); + } else { + kernel_batch_sign_precomp<<<1, 1>>>( + d_sig_one, d_siglen_one, d_msg, mlen, d_pre_d, prelen, + d_rnd, d_pc, d_vr, 1, 0); + } + CUDA_CHECK(hipDeviceSynchronize()); + hipFree(d_siglen_one); hipFree(d_vr); + } + + free(h_results); + hipFree(d_sig_soa); hipFree(d_siglen); hipFree(d_results); + hipFree(d_msg); hipFree(d_rnd); hipFree(d_pre_d); + hipFree(d_warp_stats); + hipFree(d_sign_cache); + hipFree(d_pc); + + /* ================================================================ + * [2c] Decomposed verify + * + * Pipeline: precompute → unpack → chknorm → NTT(z) → matvec → + * challenge → NTT(cp) → sub_cp_t1 → reduce → INVNTT → + * normalize → use_hint → compare + * Matrix A and t1_hat are shared by all instances (stored once) + * ================================================================ */ + { + /* Precomputation needs a large stack (single thread) */ + size_t vc_stack = 128u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, vc_stack); + 
hipGetLastError(); + + BatchVerifyBuffers vbuf; + memset(&vbuf, 0, sizeof(vbuf)); + if (batch_verify_alloc(&vbuf, N) != 0) { + printf(" [Verify] batch_verify_alloc FAILED\n"); + hipFree(d_sig_one); + hipFree(d_sigs_for_verify); + rc = -1; goto cleanup; + } + + if (verify_uses_batch_sigs) { + CUDA_CHECK(hipMemcpy(vbuf.d_raw_sigs, d_sigs_for_verify, + (size_t)N * CRYPTO_BYTES, + hipMemcpyDeviceToDevice)); + hipFree(d_sigs_for_verify); + d_sigs_for_verify = nullptr; + } else { + /* NOTE(review): N * CRYPTO_BYTES is computed in int and may overflow for very large N - confirm the supported batch range */ + int total = N * CRYPTO_BYTES; + int tpb = 256; + int nblk = (total + tpb - 1) / tpb; + kernel_broadcast_sig_aos<<>>( + vbuf.d_raw_sigs, d_sig_one, N, CRYPTO_BYTES); + CUDA_CHECK(hipDeviceSynchronize()); + hipFree(d_sig_one); + d_sig_one = nullptr; + } + + /* Precompute: with BATCH_KEYGEN_INTERNAL_MATERIAL reuse keygen's internal material directly, skipping unpack_pk/matrix_expand/t1 NTT */ + { +#if BATCH_KEYGEN_INTERNAL_MATERIAL + int total = PARAM_K * PARAM_L * PARAM_N; + if (PARAM_K * PARAM_N > total) total = PARAM_K * PARAM_N; + if (TRBYTES > total) total = TRBYTES; + int tpb = 256; + int nblk = (total + tpb - 1) / tpb; + batch_keygen_material_to_verify_kernel<<>>( + vbuf.d_mat, vbuf.d_t1_hat, vbuf.d_tr, + kbuf.d_mat, kbuf.d_t1_hat, kbuf.d_tr, + 0, keygen_mat_shared); +#else + batch_verify_precompute_kernel<<<1, 1>>>( + vbuf.d_mat, vbuf.d_t1_hat, vbuf.d_tr, d_pk_one); +#endif + } + CUDA_CHECK(hipDeviceSynchronize()); + + /* The decomposed verify kernels use a small stack */ + size_t verify_stack = 4u * 1024u; + if (hipDeviceSetLimit(hipLimitStackSize, verify_stack) != hipSuccess) { + hipGetLastError(); + verify_stack = 8u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, verify_stack); + hipGetLastError(); + } + + /* Prepare per-instance messages and pre (all instances use the same message) */ + uint8_t *d_msgs_v = nullptr, *d_pre_v = nullptr; + CUDA_CHECK(hipMalloc(&d_msgs_v, (size_t)N * mlen)); + CUDA_CHECK(hipMalloc(&d_pre_v, prelen > 0 ? 
prelen : 1)); + { + uint8_t *h_msgs_v = (uint8_t *)malloc((size_t)N * mlen); + for (int i = 0; i < N; i++) + memcpy(h_msgs_v + (size_t)i * mlen, h_msg, mlen); + CUDA_CHECK(hipMemcpy(d_msgs_v, h_msgs_v, + (size_t)N * mlen, hipMemcpyHostToDevice)); + free(h_msgs_v); + } + if (prelen > 0) + CUDA_CHECK(hipMemcpy(d_pre_v, h_pre, prelen, hipMemcpyHostToDevice)); + + int *h_vresults = (int *)calloc(N, sizeof(int)); + + /* Warmup */ + for (int w = 0; w < WARMUP_ITERS; w++) { + batch_verify_pipeline_device_sigs(&vbuf, vbuf.d_raw_sigs, d_msgs_v, mlen, + d_pre_v, prelen, N, h_vresults); + hipDeviceSynchronize(); + } + + /* Timed (averaged over multiple iterations) */ + CUDA_CHECK(hipEventRecord(ev0)); + for (int it = 0; it < bench_iters; it++) + batch_verify_pipeline_device_sigs(&vbuf, vbuf.d_raw_sigs, d_msgs_v, mlen, + d_pre_v, prelen, N, h_vresults); + CUDA_CHECK(hipEventRecord(ev1)); + CUDA_CHECK(hipEventSynchronize(ev1)); + CUDA_CHECK(hipEventElapsedTime(&ms, ev0, ev1)); + ms_verify = ms / bench_iters; + { + /* An instance passes when its result code is 0 */ + int vpass = 0; + for (int i = 0; i < N; i++) if (h_vresults[i] == 0) vpass++; + if (vpass < N) verify_fails = N - vpass; + } + + free(h_vresults); + hipFree(d_msgs_v); hipFree(d_pre_v); + batch_verify_free(&vbuf); + } + } /* end sign scope (delayed close) */ + + /* ---- Performance report (always printed; --quiet only suppresses the Phase 1 hex dump) ---- */ + kg_ops = ops_from_ms((double)N, ms_keygen); + sg_ops = ops_from_ms((double)N, ms_sign); + vf_ops = ops_from_ms((double)N, ms_verify); + printf(" %-14s %8d %10.3f ms %12.0f ops/s [%s]\n", + "Keygen", N, ms_keygen, kg_ops, chosen_keygen_label); + printf(" %-14s %8d %10.3f ms %12.0f ops/s [%s]\n", + "Sign", N, ms_sign, sg_ops, chosen_sign_label); + printf(" %-14s %8d %10.3f ms %12.0f ops/s\n", "Verify", N, ms_verify, vf_ops); + if (verify_fails > 0) { + printf(" [Verify] FAIL: %d/%d mismatched\n", verify_fails, N); + printf(" [WARN] %d verify failures!\n", verify_fails); + rc = -1; + } else { + printf(" [Verify] correctness: all %d PASS\n", N); + } + printf("\n"); + + /* 
Return the timing data through the output parameters (used by the external sweep mode) */ + if (out_kg_ms) *out_kg_ms = ms_keygen; + if (out_sg_ms) *out_sg_ms = ms_sign; + if (out_vf_ms) *out_vf_ms = ms_verify; + +cleanup: + if (ev0) hipEventDestroy(ev0); + if (ev1) hipEventDestroy(ev1); + hipFree(d_pk_one); hipFree(d_sk_one); + hipFree(d_sigs_for_verify); + batch_keygen_free(&kbuf); + return rc; +} + +static int run_sample_only_batch( + int N, + const uint8_t *h_seed, + int quiet, + int bench_iters) +{ + int rc = 0; + BatchKeygenBuffers kbuf; + unsigned char *d_base_seed = nullptr; + unsigned char *d_shared_rho = nullptr; + KeygenSampleOnlyProfile profile_sum; + KeygenSampleOnlyProfile samples[SAMPLE_ONLY_ITERS]; + const int timed_iters = SAMPLE_ONLY_ITERS; + const char *active_mode = g_bench_independent + ? keygen_ind_sample_mode_name() + : keygen_paper_sample_mode_name(); + const int print_active = strcmp(active_mode, "old-fused") != 0; + + memset(&kbuf, 0, sizeof(kbuf)); + keygen_sample_only_profile_clear(&profile_sum); + + if (!quiet) { + printf("--- [Batch=%d] Sample-only microbench ---\n", N); + printf(" mode=%s%s\n", + g_bench_independent ? "independent-real-batch" : "paper-4090-style", + g_profile ? 
" profile=on" : ""); + printf(" build: tr_hash_fixed=%d material=%s sign=%s sample_ind=%s sample_paper=%s\n", + BATCH_KEYGEN_TR_HASH_FIXED, + internal_material_mode_name(), + sign_precomp_mode_name(), + keygen_ind_sample_mode_name(), + keygen_paper_sample_mode_name()); + fflush(stdout); + } + + { + size_t kg_stack = 48u * 1024u; + if (hipDeviceSetLimit(hipLimitStackSize, kg_stack) != hipSuccess) { + hipGetLastError(); + kg_stack = 64u * 1024u; + hipDeviceSetLimit(hipLimitStackSize, kg_stack); + hipGetLastError(); + } + + if (batch_keygen_alloc(&kbuf, N) != 0) { + printf(" [Sample-only] batch_keygen_alloc FAILED\n"); + rc = -1; goto cleanup; + } + + CUDA_CHECK(hipMalloc(&d_base_seed, SEEDBYTES)); + CUDA_CHECK(hipMemcpy(d_base_seed, h_seed, SEEDBYTES, hipMemcpyHostToDevice)); + if (!g_bench_independent) + CUDA_CHECK(hipMalloc(&d_shared_rho, SEEDBYTES)); + + for (int w = 0; w < WARMUP_ITERS; w++) { + KeygenSampleOnlyProfile warmup; + if (g_bench_independent) { + if (batch_keygen_sample_only_independent(&kbuf, d_base_seed, N, &warmup) != 0) { + rc = -1; goto cleanup; + } + } else { + if (batch_keygen_sample_only_paper(&kbuf, d_base_seed, d_shared_rho, N, &warmup) != 0) { + rc = -1; goto cleanup; + } + } + CUDA_CHECK(hipDeviceSynchronize()); + } + + for (int it = 0; it < timed_iters; it++) { + KeygenSampleOnlyProfile cur; + if (g_bench_independent) { + if (batch_keygen_sample_only_independent(&kbuf, d_base_seed, N, &cur) != 0) { + rc = -1; goto cleanup; + } + } else { + if (batch_keygen_sample_only_paper(&kbuf, d_base_seed, d_shared_rho, N, &cur) != 0) { + rc = -1; goto cleanup; + } + } + samples[it] = cur; + } + } + + (void)bench_iters; + profile_sum.old_fused_ms = median3f(samples[0].old_fused_ms, + samples[1].old_fused_ms, + samples[2].old_fused_ms); + profile_sum.shared_a_ms = median3f(samples[0].shared_a_ms, + samples[1].shared_a_ms, + samples[2].shared_a_ms); + profile_sum.split_seed_ms = median3f(samples[0].split_seed_ms, + samples[1].split_seed_ms, + 
samples[2].split_seed_ms); + profile_sum.split_matrix_a_ms = median3f(samples[0].split_matrix_a_ms, + samples[1].split_matrix_a_ms, + samples[2].split_matrix_a_ms); + profile_sum.split_eta_ms = median3f(samples[0].split_eta_ms, + samples[1].split_eta_ms, + samples[2].split_eta_ms); + profile_sum.split_total_ms = median3f(samples[0].split_total_ms, + samples[1].split_total_ms, + samples[2].split_total_ms); + profile_sum.split_launch_gap_ms = median3f(samples[0].split_launch_gap_ms, + samples[1].split_launch_gap_ms, + samples[2].split_launch_gap_ms); + profile_sum.split_matrix_a_coop_ms = median3f(samples[0].split_matrix_a_coop_ms, + samples[1].split_matrix_a_coop_ms, + samples[2].split_matrix_a_coop_ms); + profile_sum.split_eta_coop_ms = median3f(samples[0].split_eta_coop_ms, + samples[1].split_eta_coop_ms, + samples[2].split_eta_coop_ms); + profile_sum.split_matrix_a_coop_lanes = samples[0].split_matrix_a_coop_lanes; + profile_sum.split_eta_coop_lanes = samples[0].split_eta_coop_lanes; + + printf(" %-12s %-22s %8d %10.3f ms %12.0f ops/s\n", + "Sample", "old-fused", N, profile_sum.old_fused_ms, + ops_from_ms((double)N, profile_sum.old_fused_ms)); + if (profile_sum.shared_a_ms > 0.0f) { + printf(" %-12s %-22s %8d %10.3f ms %12.0f ops/s\n", + "MatrixA", "sharedA", N, profile_sum.shared_a_ms, + ops_from_ms(1.0, profile_sum.shared_a_ms)); + } + if (print_active) { + printf(" %-12s %-22s %8d %10.3f ms %12.0f ops/s [seed %.3f matA %.3f eta %.3f total %.3f gap %.3f]", + "Sample", active_mode, N, profile_sum.split_total_ms, + ops_from_ms((double)N, profile_sum.split_total_ms), + profile_sum.split_seed_ms, + profile_sum.split_matrix_a_ms, + profile_sum.split_eta_ms, + profile_sum.split_total_ms, + profile_sum.split_launch_gap_ms); + if (profile_sum.split_matrix_a_coop_lanes > 0 || + profile_sum.split_eta_coop_lanes > 0) { + printf(" coop_lanes[matA %d eta %d] coop_ms[matA %.3f eta %.3f]", + profile_sum.split_matrix_a_coop_lanes, + profile_sum.split_eta_coop_lanes, + 
profile_sum.split_matrix_a_coop_ms, + profile_sum.split_eta_coop_ms); + } + printf("\n"); + } + printf("\n"); + +cleanup: + hipFree(d_shared_rho); + hipFree(d_base_seed); + batch_keygen_free(&kbuf); + return rc; +} + +/* ================================================================ + * Phase 3: 批量吞吐量扫描 (--throughput) + * + * 自动循环 batch_size: 256,512,1024,2048,4096,8192,16384,32768 + * 每个 batch size 运行 THROUGHPUT_RUNS 次取平均, 输出 CSV 格式 + * 显存不足时跳过并记录 OOM + * ================================================================ */ +static void run_throughput_scan( + const uint8_t *h_seed, const uint8_t *h_rnd, + const uint8_t *h_msg, size_t mlen, + const uint8_t *h_pre, size_t prelen) +{ + int batch_sizes[] = {256, 512, 1024, 2048, 4096, 8192, 16384, 32768}; + int n_sizes = sizeof(batch_sizes) / sizeof(batch_sizes[0]); + + /* 创建 figure/ 目录 (保存架构图和 CSV) */ +#ifdef _WIN32 + _mkdir("figure"); +#else + mkdir("figure", 0755); +#endif + + FILE *csv = fopen("figure/throughput.csv", "w"); + if (!csv) { + printf("ERROR: cannot create figure/throughput.csv\n"); + return; + } + + /* CSV 表头 */ + fprintf(csv, "batch_size,keygen_ms,keygen_ops_s,sign_ms,sign_ops_s,verify_ms,verify_ops_s,notes\n"); + + printf("\n"); + printf("=== Batch Throughput Scan (avg of %d runs, CSV → figure/throughput.csv) ===\n", + THROUGHPUT_RUNS); + printf("%-10s %12s %14s %12s %14s %12s %14s\n", + "Batch", "Kg(ms)", "Kg(ops/s)", "Sg(ms)", "Sg(ops/s)", "Vf(ms)", "Vf(ops/s)"); + printf("%-10s %12s %14s %12s %14s %12s %14s\n", + "-----", "------", "----------", "------", "----------", "------", "----------"); + + for (int i = 0; i < n_sizes; i++) { + int N = batch_sizes[i]; + + float kg = 0, sg = 0, vf = 0; + int r = run_batch(N, h_seed, h_rnd, h_msg, mlen, h_pre, prelen, + 1 /* quiet */, THROUGHPUT_RUNS /* 10 timed iters */, + &kg, &sg, &vf); + + if (r != 0) { + printf("%-10d %12s %14s %12s %14s %12s %14s\n", + N, "FAIL", "FAIL", "FAIL", "FAIL", "FAIL", "FAIL"); + fprintf(csv, 
"%d,FAIL,FAIL,FAIL,FAIL,FAIL,FAIL,FAIL\n", N); + } else { + double kg_ops = ops_from_ms((double)N, kg); + double sg_ops = ops_from_ms((double)N, sg); + double vf_ops = ops_from_ms((double)N, vf); + printf("%-10d %12.3f %14.0f %12.3f %14.0f %12.3f %14.0f\n", + N, kg, kg_ops, sg, sg_ops, vf, vf_ops); + fprintf(csv, "%d,%.3f,%.0f,%.3f,%.0f,%.3f,%.0f,\n", + N, kg, kg_ops, sg, sg_ops, vf, vf_ops); + } + fflush(csv); + fflush(stdout); + } + + fclose(csv); + printf("\n[throughput] CSV saved to figure/throughput.csv\n"); +} + +static const char *arg_value(int argc, char **argv, const char *name) { + for (int i = 1; i + 1 < argc; i++) { + if (strcmp(argv[i], name) == 0) return argv[i + 1]; + } + return NULL; +} + +static int has_arg(int argc, char **argv, const char *name) { + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], name) == 0) return 1; + } + return 0; +} + +static int run_cli_mode(int argc, char **argv) { + int rc = 0; + int h_result = 0; + size_t h_siglen = 0; + uint8_t *h_msg = NULL; + size_t h_mlen = 0; + uint8_t *h_pk = NULL, *h_sk = NULL, *h_sig = NULL, *h_seed = NULL, *h_rnd = NULL; + uint8_t *d_pk = NULL, *d_sk = NULL, *d_sig = NULL, *d_msg = NULL, *d_seed = NULL, *d_rnd = NULL; + size_t *d_siglen = NULL; + int *d_result = NULL; + + const int do_keygen = has_arg(argc, argv, "--cli-keygen"); + const int do_sign = has_arg(argc, argv, "--cli-sign"); + const int do_verify = has_arg(argc, argv, "--cli-verify"); + if (!do_keygen && !do_sign && !do_verify) return 0; + + h_pk = (uint8_t *)calloc(1, CRYPTO_PUBLICKEYBYTES); + h_sk = (uint8_t *)calloc(1, CRYPTO_SECRETKEYBYTES); + h_sig = (uint8_t *)calloc(1, CRYPTO_BYTES); + h_seed = (uint8_t *)calloc(1, SEEDBYTES); +#if RNDBYTES > 0 + h_rnd = (uint8_t *)calloc(1, RNDBYTES); +#else + h_rnd = (uint8_t *)calloc(1, 1); +#endif + if (!h_pk || !h_sk || !h_sig || !h_seed || !h_rnd) { + fprintf(stderr, "CLI malloc failed\n"); + return 2; + } + + CUDA_CHECK(hipMalloc(&d_pk, CRYPTO_PUBLICKEYBYTES)); + 
CUDA_CHECK(hipMalloc(&d_sk, CRYPTO_SECRETKEYBYTES));
+    CUDA_CHECK(hipMalloc(&d_sig, CRYPTO_BYTES));
+    CUDA_CHECK(hipMalloc(&d_seed, SEEDBYTES));
+#if RNDBYTES > 0
+    CUDA_CHECK(hipMalloc(&d_rnd, RNDBYTES));
+#else
+    CUDA_CHECK(hipMalloc(&d_rnd, 1));
+#endif
+    CUDA_CHECK(hipMalloc(&d_siglen, sizeof(size_t)));
+    CUDA_CHECK(hipMalloc(&d_result, sizeof(int)));
+ /* Exactly one branch runs; precedence is keygen, then sign, then verify. */
+    if (do_keygen) {
+        const char *pk_out = arg_value(argc, argv, "--pk-out");
+        const char *sk_out = arg_value(argc, argv, "--sk-out");
+        const char *seed_in = arg_value(argc, argv, "--seed-in");
+        if (!pk_out || !sk_out) {
+            fprintf(stderr, "--cli-keygen requires --pk-out and --sk-out\n");
+            rc = 2; goto cleanup;
+        }
+        if (seed_in) { /* optional: seed read from file, otherwise random */
+            if (read_file_exact_host(seed_in, h_seed, SEEDBYTES) != 0) { rc = 2; goto cleanup; }
+        } else {
+            fill_random_host(h_seed, SEEDBYTES);
+        }
+        CUDA_CHECK(hipMemcpy(d_seed, h_seed, SEEDBYTES, hipMemcpyHostToDevice));
+        kernel_keygen_only<<<1, 1>>>(d_pk, d_sk, d_seed, d_result);
+        CUDA_CHECK(hipGetLastError());
+        CUDA_CHECK(hipDeviceSynchronize());
+        CUDA_CHECK(hipMemcpy(&h_result, d_result, sizeof(int), hipMemcpyDeviceToHost));
+        if (h_result != 0) { fprintf(stderr, "CLI keygen failed: %d\n", h_result); rc = 3; goto cleanup; }
+        CUDA_CHECK(hipMemcpy(h_pk, d_pk, CRYPTO_PUBLICKEYBYTES, hipMemcpyDeviceToHost));
+        CUDA_CHECK(hipMemcpy(h_sk, d_sk, CRYPTO_SECRETKEYBYTES, hipMemcpyDeviceToHost));
+        if (write_file_all(pk_out, h_pk, CRYPTO_PUBLICKEYBYTES) != 0 ||
+            write_file_all(sk_out, h_sk, CRYPTO_SECRETKEYBYTES) != 0) { rc = 2; goto cleanup; }
+        printf("CLI SIG keygen PASS pk=%d sk=%d\n", CRYPTO_PUBLICKEYBYTES, CRYPTO_SECRETKEYBYTES);
+    } else if (do_sign) {
+        const char *sk_in = arg_value(argc, argv, "--sk-in");
+        const char *msg_in = arg_value(argc, argv, "--msg-in");
+        const char *sig_out = arg_value(argc, argv, "--sig-out");
+        const char *rnd_in = arg_value(argc, argv, "--rnd-in"); /* only consulted when RNDBYTES > 0 */
+        if (!sk_in || !msg_in || !sig_out) {
+            fprintf(stderr, "--cli-sign requires --sk-in, --msg-in, and --sig-out\n");
+            rc = 2; goto cleanup;
+        }
+        if (read_file_exact_host(sk_in, h_sk, CRYPTO_SECRETKEYBYTES) != 0 ||
+            read_file_all(msg_in, &h_msg, &h_mlen) != 0) { rc = 2; goto cleanup; }
+#if RNDBYTES > 0
+        if (rnd_in) {
+            if (read_file_exact_host(rnd_in, h_rnd, RNDBYTES) != 0) { rc = 2; goto cleanup; }
+        } else {
+            fill_random_host(h_rnd, RNDBYTES);
+        }
+#endif
+        CUDA_CHECK(hipMalloc(&d_msg, h_mlen > 0 ? h_mlen : 1)); /* never request a 0-byte allocation */
+        CUDA_CHECK(hipMemcpy(d_sk, h_sk, CRYPTO_SECRETKEYBYTES, hipMemcpyHostToDevice));
+        if (h_mlen > 0) CUDA_CHECK(hipMemcpy(d_msg, h_msg, h_mlen, hipMemcpyHostToDevice));
+#if RNDBYTES > 0
+        CUDA_CHECK(hipMemcpy(d_rnd, h_rnd, RNDBYTES, hipMemcpyHostToDevice));
+#endif
+        kernel_cli_sign<<<1, 1>>>(d_sig, d_siglen, d_result, d_msg, h_mlen, d_sk, d_rnd);
+        CUDA_CHECK(hipGetLastError());
+        CUDA_CHECK(hipDeviceSynchronize());
+        CUDA_CHECK(hipMemcpy(&h_result, d_result, sizeof(int), hipMemcpyDeviceToHost));
+        CUDA_CHECK(hipMemcpy(&h_siglen, d_siglen, sizeof(size_t), hipMemcpyDeviceToHost));
+        /* Success requires both a zero status and a full-length signature. */
+        if (h_result != 0 || h_siglen != CRYPTO_BYTES) {
+            fprintf(stderr, "CLI sign failed: result=%d siglen=%zu\n", h_result, h_siglen);
+            rc = 3; goto cleanup;
+        }
+        CUDA_CHECK(hipMemcpy(h_sig, d_sig, CRYPTO_BYTES, hipMemcpyDeviceToHost));
+        if (write_file_all(sig_out, h_sig, CRYPTO_BYTES) != 0) { rc = 2; goto cleanup; }
+        printf("CLI SIG sign PASS sig=%d msg=%zu\n", CRYPTO_BYTES, h_mlen);
+    } else if (do_verify) {
+        const char *pk_in = arg_value(argc, argv, "--pk-in");
+        const char *msg_in = arg_value(argc, argv, "--msg-in");
+        const char *sig_in = arg_value(argc, argv, "--sig-in");
+        if (!pk_in || !msg_in || !sig_in) {
+            fprintf(stderr, "--cli-verify requires --pk-in, --msg-in, and --sig-in\n");
+            rc = 2; goto cleanup;
+        }
+        if (read_file_exact_host(pk_in, h_pk, CRYPTO_PUBLICKEYBYTES) != 0 ||
+            read_file_exact_host(sig_in, h_sig, CRYPTO_BYTES) != 0 ||
+            read_file_all(msg_in, &h_msg, &h_mlen) != 0) { rc = 2; goto cleanup; }
+        CUDA_CHECK(hipMalloc(&d_msg, h_mlen > 0 ? h_mlen : 1));
+        CUDA_CHECK(hipMemcpy(d_pk, h_pk, CRYPTO_PUBLICKEYBYTES, hipMemcpyHostToDevice));
+        CUDA_CHECK(hipMemcpy(d_sig, h_sig, CRYPTO_BYTES, hipMemcpyHostToDevice));
+        if (h_mlen > 0) CUDA_CHECK(hipMemcpy(d_msg, h_msg, h_mlen, hipMemcpyHostToDevice));
+        kernel_cli_verify<<<1, 1>>>(d_result, d_sig, CRYPTO_BYTES, d_msg, h_mlen, d_pk);
+        CUDA_CHECK(hipGetLastError());
+        CUDA_CHECK(hipDeviceSynchronize());
+        CUDA_CHECK(hipMemcpy(&h_result, d_result, sizeof(int), hipMemcpyDeviceToHost));
+        if (h_result == 0) {
+            printf("CLI SIG verify PASS msg=%zu\n", h_mlen);
+        } else {
+            /* A rejected signature is reported on stdout and as exit code 4. */
+            printf("CLI SIG verify FAIL code=%d msg=%zu\n", h_result, h_mlen);
+            rc = 4; goto cleanup;
+        }
+    }
+
+cleanup:
+    hipFree(d_pk); hipFree(d_sk); hipFree(d_sig); hipFree(d_msg);
+    hipFree(d_seed); hipFree(d_rnd); hipFree(d_siglen); hipFree(d_result);
+    free(h_pk); free(h_sk); free(h_sig); free(h_seed); free(h_rnd); free(h_msg);
+    return rc == 0 ? 1 : rc; /* 0 means "no CLI flag given", so success is reported as 1 */
+}
+/*
+ * build_api_pre — allocate the message prefix used by the batch API paths.
+ * ML-DSA builds get two zero bytes (*prelen = 2); other builds get a 1-byte
+ * placeholder with *prelen = 0. *h_pre is NULL if calloc fails; the caller
+ * owns and frees the buffer.
+ */
+static void build_api_pre(uint8_t **h_pre, size_t *prelen) {
+#if ALGORITHM == ALGO_MLDSA
+    *prelen = 2;
+    *h_pre = (uint8_t *)calloc(1, *prelen);
+    if (*h_pre) {
+        (*h_pre)[0] = 0;
+        (*h_pre)[1] = 0;
+    }
+#else
+    *prelen = 0;
+    *h_pre = (uint8_t *)calloc(1, 1);
+#endif
+}
+/* Write batch_count back-to-back copies of the item_len-byte record `src` into `dst`. */
+static void repeat_record(uint8_t *dst, const uint8_t *src, size_t item_len, int batch_count) {
+    for (int i = 0; i < batch_count; i++) {
+        memcpy(dst + (size_t)i * item_len, src, item_len);
+    }
+}
+/*
+ * run_api_sig_sign — batch keygen + batch sign of one message file; writes a
+ * public key, secret key and signature to the given output paths.
+ * Returns 1 on success and a different non-zero code on failure.
+ */
+static int run_api_sig_sign(
+    int batch_count,
+    const char *msg_in,
+    const char *pk_out,
+    const char *sk_out,
+    const char *sig_out)
+{
+    int rc = 0;
+    uint8_t *h_msg = NULL, *h_pre = NULL, *h_seed = NULL, *h_rnd = NULL;
+    uint8_t *h_pk = NULL, *h_sk = NULL, *h_sig = NULL;
+    size_t h_mlen = 0, prelen = 0;
+    uint8_t *d_pk_one = NULL, *d_sk_one = NULL;
+    uint8_t *d_msg = NULL, *d_pre = NULL, *d_rnd = NULL, *d_base_seed = NULL;
+    precomp_t *d_pc = NULL;
+    BatchKeygenBuffers kbuf;
+    BatchSignPipeline bsp;
+    memset(&kbuf, 0, sizeof(kbuf)); /* zeroed so the *_free calls in cleanup are safe on early exit — presumably; confirm against batch_keygen_free */
+    memset(&bsp, 0, sizeof(bsp));
+
+    if (batch_count < 1) batch_count = 1; /* clamp non-positive batch sizes */
+    if (read_file_all(msg_in, &h_msg, &h_mlen) != 0) return 2; /* nothing allocated yet, so a direct return is safe */
+    build_api_pre(&h_pre, &prelen);
+    h_seed = (uint8_t *)malloc(SEEDBYTES);
+    h_rnd = (uint8_t *)malloc(RNDBYTES > 0 ? RNDBYTES : 1);
+    h_pk = (uint8_t *)malloc(CRYPTO_PUBLICKEYBYTES);
+    h_sk = (uint8_t *)malloc(CRYPTO_SECRETKEYBYTES);
+    h_sig = (uint8_t *)malloc(CRYPTO_BYTES);
+    if (!h_pre || !h_seed || !h_rnd || !h_pk || !h_sk || !h_sig) {
+        fprintf(stderr, "API SIG malloc failed\n");
+        rc = 2; goto cleanup;
+    }
+    fill_random_host(h_seed, SEEDBYTES);
+#if RNDBYTES > 0
+    fill_random_host(h_rnd, RNDBYTES);
+#endif
+ /* Step 1: batch key generation from the random base seed. */
+    if (batch_keygen_alloc(&kbuf, batch_count) != 0) {
+        fprintf(stderr, "API SIG batch_keygen_alloc failed\n");
+        rc = 3; goto cleanup;
+    }
+    CUDA_CHECK(hipMalloc(&d_base_seed, SEEDBYTES));
+    CUDA_CHECK(hipMemcpy(d_base_seed, h_seed, SEEDBYTES, hipMemcpyHostToDevice));
+    if (batch_keygen_pipeline_warp_opt(kbuf.d_pks, kbuf.d_sks, d_base_seed, &kbuf, batch_count, NULL, 0, 1) != 0) {
+        fprintf(stderr, "API SIG batch keygen failed\n");
+        rc = 3; goto cleanup;
+    }
+    CUDA_CHECK(hipDeviceSynchronize());
+    /* Only the leading CRYPTO_PUBLICKEYBYTES / CRYPTO_SECRETKEYBYTES bytes of
+     * the batch buffers are copied back — presumably key pair 0; confirm the
+     * buffer layout against batch_keygen. */
+    CUDA_CHECK(hipMemcpy(h_pk, kbuf.d_pks, CRYPTO_PUBLICKEYBYTES, hipMemcpyDeviceToHost));
+    CUDA_CHECK(hipMemcpy(h_sk, kbuf.d_sks, CRYPTO_SECRETKEYBYTES, hipMemcpyDeviceToHost));
+ /* Step 2: device-side copy of that key pair, used to build the sign precomputation. */
+    CUDA_CHECK(hipMalloc(&d_pk_one, CRYPTO_PUBLICKEYBYTES));
+    CUDA_CHECK(hipMalloc(&d_sk_one, CRYPTO_SECRETKEYBYTES));
+    CUDA_CHECK(hipMemcpy(d_pk_one, kbuf.d_pks, CRYPTO_PUBLICKEYBYTES, hipMemcpyDeviceToDevice));
+    CUDA_CHECK(hipMemcpy(d_sk_one, kbuf.d_sks, CRYPTO_SECRETKEYBYTES, hipMemcpyDeviceToDevice));
+ /* 128 KiB stack for the single-thread precompute kernel; a failure to set the limit is ignored. */
+    hipDeviceSetLimit(hipLimitStackSize, 128u * 1024u);
+    hipGetLastError();
+    CUDA_CHECK(hipMalloc(&d_pc, sizeof(precomp_t)));
+    kernel_create_precomp<<<1, 1>>>(d_pc, d_pk_one, d_sk_one);
+    CUDA_CHECK(hipGetLastError());
+    CUDA_CHECK(hipDeviceSynchronize());
+ /* Step 3: upload message, prefix and randomness (1-byte placeholders when a length is 0). */
+    CUDA_CHECK(hipMalloc(&d_msg, h_mlen > 0 ? h_mlen : 1));
+    CUDA_CHECK(hipMalloc(&d_pre, prelen > 0 ? prelen : 1));
+    CUDA_CHECK(hipMalloc(&d_rnd, RNDBYTES > 0 ? RNDBYTES : 1));
+    if (h_mlen > 0) CUDA_CHECK(hipMemcpy(d_msg, h_msg, h_mlen, hipMemcpyHostToDevice));
+    if (prelen > 0) CUDA_CHECK(hipMemcpy(d_pre, h_pre, prelen, hipMemcpyHostToDevice));
+#if RNDBYTES > 0
+    CUDA_CHECK(hipMemcpy(d_rnd, h_rnd, RNDBYTES, hipMemcpyHostToDevice));
+#endif
+ /* Step 4: batch signing with a 64 KiB stack limit. */
+    hipDeviceSetLimit(hipLimitStackSize, 64u * 1024u);
+    hipGetLastError();
+    if (batch_sign_alloc(&bsp, batch_count) != 0) {
+        fprintf(stderr, "API SIG batch_sign_alloc failed\n");
+        rc = 3; goto cleanup;
+    }
+    {
+        const char *policy_label = NULL;
+        int rounds = 0, done = 0;
+        BatchSignRuntimeOptions runtime = select_decomp_runtime_options(batch_count, g_bench_independent, &policy_label);
+        /* Treated as failure unless every instance in the batch completed. */
+        if (batch_sign_pipeline_ex(&bsp, batch_count, d_pc, d_msg, h_mlen, d_pre, prelen,
+                                   d_rnd, &runtime, &rounds, &done) != 0 || done != batch_count) {
+            fprintf(stderr, "API SIG decomp sign failed: done=%d/%d rounds=%d\n", done, batch_count, rounds);
+            rc = 4; goto cleanup;
+        }
+        CUDA_CHECK(hipDeviceSynchronize());
+        /* Only the leading CRYPTO_BYTES of the signature buffer are exported. */
+        CUDA_CHECK(hipMemcpy(h_sig, bsp.d_sigs, CRYPTO_BYTES, hipMemcpyDeviceToHost));
+        printf("API SIG sign PASS batch=%d sig=%d policy=%s rounds=%d\n",
+               batch_count, CRYPTO_BYTES, policy_label ? policy_label : "base", rounds);
+    }
+ /* Step 5: persist the exported key pair and signature. */
+    if (write_file_all(pk_out, h_pk, CRYPTO_PUBLICKEYBYTES) != 0 ||
+        write_file_all(sk_out, h_sk, CRYPTO_SECRETKEYBYTES) != 0 ||
+        write_file_all(sig_out, h_sig, CRYPTO_BYTES) != 0) {
+        rc = 2; goto cleanup;
+    }
+
+cleanup:
+    hipFree(d_pk_one); hipFree(d_sk_one); hipFree(d_msg); hipFree(d_pre);
+    hipFree(d_rnd); hipFree(d_base_seed); hipFree(d_pc);
+    batch_keygen_free(&kbuf);
+    batch_sign_free(&bsp);
+    free(h_msg); free(h_pre); free(h_seed); free(h_rnd);
+    free(h_pk); free(h_sk); free(h_sig);
+    return rc == 0 ?
1 : rc;
+}
+/*
+ * run_api_sig_verify — verify one signature file against one public key and
+ * message, replicated batch_count times through the batch verify pipeline.
+ *
+ * Returns 1 when every instance verifies; otherwise a non-zero code:
+ * 2 for file or host-allocation errors, 3 when batch_verify_alloc fails,
+ * 4 when the pipeline call fails, 5 when at least one instance is rejected.
+ */
+static int run_api_sig_verify(
+    int batch_count,
+    const char *msg_in,
+    const char *pk_in,
+    const char *sig_in)
+{
+    int rc = 0;
+    uint8_t *h_msg = NULL, *h_pre = NULL, *h_pk = NULL, *h_sig = NULL, *h_sigs = NULL, *h_msgs = NULL;
+    size_t h_mlen = 0, prelen = 0;
+    uint8_t *d_pk = NULL, *d_msgs = NULL, *d_pre = NULL;
+    int *h_results = NULL;
+    BatchVerifyBuffers vbuf;
+    memset(&vbuf, 0, sizeof(vbuf));
+
+    if (batch_count < 1) batch_count = 1; /* clamp non-positive batch sizes */
+    if (read_file_all(msg_in, &h_msg, &h_mlen) != 0) return 2; /* nothing allocated yet */
+    build_api_pre(&h_pre, &prelen);
+    h_pk = (uint8_t *)malloc(CRYPTO_PUBLICKEYBYTES);
+    h_sig = (uint8_t *)malloc(CRYPTO_BYTES);
+    h_sigs = (uint8_t *)malloc((size_t)batch_count * CRYPTO_BYTES);
+    h_msgs = (uint8_t *)malloc((size_t)batch_count * (h_mlen > 0 ? h_mlen : 1));
+    h_results = (int *)calloc((size_t)batch_count, sizeof(int));
+    if (!h_pre || !h_pk || !h_sig || !h_sigs || !h_msgs || !h_results) {
+        fprintf(stderr, "API SIG verify malloc failed\n");
+        rc = 2; goto cleanup;
+    }
+    if (read_file_exact_host(pk_in, h_pk, CRYPTO_PUBLICKEYBYTES) != 0 ||
+        read_file_exact_host(sig_in, h_sig, CRYPTO_BYTES) != 0) {
+        rc = 2; goto cleanup;
+    }
+    /* Every batch slot gets the same signature and the same message. */
+    repeat_record(h_sigs, h_sig, CRYPTO_BYTES, batch_count);
+    if (h_mlen > 0) repeat_record(h_msgs, h_msg, h_mlen, batch_count);
+ /* 128 KiB stack for the single-thread precompute kernel; a failure to set the limit is ignored. */
+    hipDeviceSetLimit(hipLimitStackSize, 128u * 1024u);
+    hipGetLastError();
+    CUDA_CHECK(hipMalloc(&d_pk, CRYPTO_PUBLICKEYBYTES));
+    CUDA_CHECK(hipMemcpy(d_pk, h_pk, CRYPTO_PUBLICKEYBYTES, hipMemcpyHostToDevice));
+    if (batch_verify_alloc(&vbuf, batch_count) != 0) {
+        fprintf(stderr, "API SIG batch_verify_alloc failed\n");
+        rc = 3; goto cleanup;
+    }
+    batch_verify_precompute_kernel<<<1, 1>>>(vbuf.d_mat, vbuf.d_t1_hat, vbuf.d_tr, d_pk);
+    CUDA_CHECK(hipGetLastError());
+    CUDA_CHECK(hipDeviceSynchronize());
+ /* The batch verify pipeline runs with a reduced 8 KiB stack limit. */
+    hipDeviceSetLimit(hipLimitStackSize, 8u * 1024u);
+    hipGetLastError();
+    CUDA_CHECK(hipMalloc(&d_msgs, (size_t)batch_count * (h_mlen > 0 ? h_mlen : 1)));
+    CUDA_CHECK(hipMalloc(&d_pre, prelen > 0 ? prelen : 1));
+    if (h_mlen > 0) CUDA_CHECK(hipMemcpy(d_msgs, h_msgs, (size_t)batch_count * h_mlen, hipMemcpyHostToDevice));
+    if (prelen > 0) CUDA_CHECK(hipMemcpy(d_pre, h_pre, prelen, hipMemcpyHostToDevice));
+    /* Signatures are handed over as a host buffer (h_sigs); messages and prefix as device buffers. */
+    if (batch_verify_pipeline(&vbuf, h_sigs, d_msgs, h_mlen, d_pre, prelen, batch_count, h_results) != 0) {
+        fprintf(stderr, "API SIG verify pipeline failed\n");
+        rc = 4; goto cleanup;
+    }
+    CUDA_CHECK(hipDeviceSynchronize());
+    {
+        int fails = count_failures(h_results, batch_count);
+        if (fails == 0) {
+            printf("API SIG verify PASS batch=%d sig=%d\n", batch_count, CRYPTO_BYTES);
+        } else {
+            printf("API SIG verify FAIL batch=%d fails=%d\n", batch_count, fails);
+            rc = 5;
+        }
+    }
+
+cleanup:
+    hipFree(d_pk); hipFree(d_msgs); hipFree(d_pre);
+    batch_verify_free(&vbuf);
+    free(h_msg); free(h_pre); free(h_pk); free(h_sig); free(h_sigs); free(h_msgs); free(h_results);
+    return rc == 0 ? 1 : rc; /* success is reported as 1; 0 is reserved for "mode not requested" in the dispatcher */
+}
+/*
+ * run_sig_api_mode — dispatcher for --api-sig-sign / --api-sig-verify.
+ * Returns 0 when neither flag is present, 2 on a usage error (both flags, or
+ * a missing required path), otherwise the result of the selected runner.
+ * --batch defaults to 128.
+ */
+static int run_sig_api_mode(int argc, char **argv) {
+    const int do_sign = has_arg(argc, argv, "--api-sig-sign");
+    const int do_verify = has_arg(argc, argv, "--api-sig-verify");
+    if (!do_sign && !do_verify) return 0;
+    if (do_sign && do_verify) {
+        fprintf(stderr, "select exactly one SIG API mode\n");
+        return 2;
+    }
+    int batch_count = 128;
+    const char *batch_s = arg_value(argc, argv, "--batch");
+    /* NOTE(review): atoi() does no validation; a non-numeric value yields 0,
+     * which the runners then clamp to 1. */
+    if (batch_s) batch_count = atoi(batch_s);
+
+    if (do_sign) {
+        const char *msg_in = arg_value(argc, argv, "--msg-in");
+        const char *pk_out = arg_value(argc, argv, "--pk-out");
+        const char *sk_out = arg_value(argc, argv, "--sk-out");
+        const char *sig_out = arg_value(argc, argv, "--sig-out");
+        if (!msg_in || !pk_out || !sk_out || !sig_out) {
+            fprintf(stderr, "--api-sig-sign requires --msg-in, --pk-out, --sk-out, and --sig-out\n");
+            return 2;
+        }
+        return run_api_sig_sign(batch_count, msg_in, pk_out, sk_out, sig_out);
+    }
+
+    if (do_verify) {
+        const char *msg_in = arg_value(argc, argv, "--msg-in");
+        const
char *pk_in = arg_value(argc, argv, "--pk-in");
+        const char *sig_in = arg_value(argc, argv, "--sig-in");
+        if (!msg_in || !pk_in || !sig_in) {
+            fprintf(stderr, "--api-sig-verify requires --msg-in, --pk-in, and --sig-in\n");
+            return 2;
+        }
+        return run_api_sig_verify(batch_count, msg_in, pk_in, sig_in);
+    }
+    return 0;
+}
+
+/* ================================================================
+ * main
+ *
+ * Dispatch order: the file-based CLI modes (run_cli_mode), then the batch
+ * API modes (run_sig_api_mode), then the benchmark driver selected by
+ * parse_options(). Both mode runners return 0 for "not requested", 1 for
+ * success (mapped to exit code 0 here) and any other value as the exit code.
+ * ================================================================ */
+int main(int argc, char **argv) {
+    int cli_rc = run_cli_mode(argc, argv);
+    if (cli_rc != 0) return cli_rc == 1 ? 0 : cli_rc;
+    int api_rc = run_sig_api_mode(argc, argv);
+    if (api_rc != 0) return api_rc == 1 ? 0 : api_rc;
+    Options opt;
+    int r = parse_options(argc, argv, &opt);
+    if (r > 0) return 0;
+    if (r < 0) { print_usage(argv[0]); return 1; }
+
+    if (opt.batch_auto) opt.batch_size = select_default_batch_for_device();
+    print_info(opt.batch_size, opt.batch_auto);
+
+    /* Device stack size — the single-thread correctness test uses a larger stack */
+    {
+        size_t phase1_stack = 128u * 1024u;
+        if (hipDeviceSetLimit(hipLimitStackSize, phase1_stack) != hipSuccess) {
+            hipGetLastError();
+            printf("Warning: could not set CUDA stack size\n\n");
+        }
+    }
+
+    /* Generate random test vectors (rand() seeded from the clock: benchmark
+     * inputs only, not cryptographic randomness) */
+    srand((unsigned)time(NULL));
+
+    uint8_t h_seed[SEEDBYTES];
+    for (int i = 0; i < SEEDBYTES; i++) h_seed[i] = (uint8_t)(rand() & 0xFF);
+
+#if RNDBYTES > 0
+    uint8_t h_rnd[RNDBYTES];
+    for (int i = 0; i < RNDBYTES; i++) h_rnd[i] = (uint8_t)(rand() & 0xFF);
+#else
+    uint8_t h_rnd[1] = {0};
+#endif
+
+    size_t mlen = 32;
+    uint8_t h_msg[32];
+    for (size_t i = 0; i < mlen; i++) h_msg[i] = (uint8_t)(rand() & 0xFF);
+
+#if ALGORITHM == ALGO_MLDSA
+    size_t ctxlen = 32;
+    uint8_t h_ctx[32];
+    for (size_t i = 0; i < ctxlen; i++) h_ctx[i] = (uint8_t)(rand() & 0xFF);
+#else
+    size_t ctxlen = 0;
+    uint8_t h_ctx[1] = {0};
+#endif
+
+    /* Build pre = (0, ctxlen, ctx); 34 bytes covers the largest case, 2 + 32 */
+    size_t prelen = 2 + ctxlen;
+    uint8_t h_pre[34];
+    h_pre[0] = 0;
+    h_pre[1] = (uint8_t)ctxlen;
+    if (ctxlen > 0) memcpy(h_pre + 2, h_ctx, ctxlen);
+
+    if (opt.sample_only) {
+        if (opt.keygen_compare) {
+            r = run_keygen_compare_batch(opt.batch_size, h_seed, opt.quiet, 1);
+            return r != 0 ? 1 : 0;
+        }
+
+        printf(" %-12s %-22s %8s %10s %12s\n", "Operation", "Mode", "Batch", "Time(ms)", "Throughput");
+        printf(" %-12s %-22s %8s %10s %12s\n", "---------", "----", "-----", "--------", "----------");
+        fflush(stdout);
+
+        if (opt.sweep) {
+            int sizes[] = {64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768};
+            int nsizes = (int)(sizeof(sizes) / sizeof(sizes[0]));
+            for (int i = 0; i < nsizes; i++) {
+                r = run_sample_only_batch(sizes[i], h_seed, opt.quiet, SAMPLE_ONLY_ITERS);
+                if (r != 0) {
+                    printf("Batch=%d FAILED, stopping sample-only sweep.\n", sizes[i]);
+                    break;
+                }
+            }
+        } else {
+            r = run_sample_only_batch(opt.batch_size, h_seed, opt.quiet, SAMPLE_ONLY_ITERS);
+        }
+        return r != 0 ? 1 : 0;
+    }
+
+    if (opt.keygen_compare) {
+        r = run_keygen_compare_batch(opt.batch_size, h_seed, opt.quiet, 0);
+        return r != 0 ? 1 : 0;
+    }
+
+    /* Correctness check (single instance) */
+    r = run_single_correctness(h_seed, h_rnd, h_msg, mlen,
+                               h_ctx, ctxlen,
+                               h_pre, prelen, opt.quiet);
+    if (r != 0) {
+        printf("Correctness FAILED.\n");
+        return 1;
+    }
+
+    if (!opt.skip_keygen_oracle) {
+        r = run_keygen_oracle_check(h_seed, 8, opt.quiet);
+        if (r != 0) {
+            printf("Keygen oracle check FAILED.\n");
+            return 1;
+        }
+    } else if (!opt.quiet) {
+        printf("[Keygen-oracle] skipped by --skip-keygen-oracle\n");
+    }
+
+    /* Batch performance benchmark */
+    if (opt.throughput) {
+        /* run_throughput_scan() returns void, so this mode exits 0 even if individual batch sizes failed. */
+        run_throughput_scan(h_seed, h_rnd, h_msg, mlen, h_pre, prelen);
+        return 0;
+    }
+
+    printf(" %-12s %8s %10s %12s\n", "Operation", "Batch", "Time(ms)", "Throughput");
+    printf(" %-12s %8s %10s %12s\n", "---------", "-----", "--------", "----------");
+    fflush(stdout);
+
+    if (opt.sweep) {
+        int sizes[] = {64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768};
+        int nsizes = (int)(sizeof(sizes) / sizeof(sizes[0]));
+        for (int i = 0; i < nsizes; i++) {
+            r = run_batch(sizes[i], h_seed, h_rnd, h_msg, mlen,
+                          h_pre, prelen, opt.quiet, BENCH_ITERS,
+                          NULL, NULL, NULL);
+            if (r != 0) {
+                printf("Batch=%d FAILED, stopping sweep.\n", sizes[i]);
+                break;
+            }
+        }
+    } else {
+        r = run_batch(opt.batch_size, h_seed, h_rnd, h_msg, mlen,
+                      h_pre, prelen, opt.quiet, BENCH_ITERS,
+                      NULL, NULL, NULL);
+    }
+
+    return r != 0 ?
1 : 0; +} diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/ntt.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/ntt.cuh new file mode 100644 index 000000000..de65c0abe --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/ntt.cuh @@ -0,0 +1,299 @@ +/* + * ntt.cuh — 统一 NTT / INVNTT + * + * 数学结构: + * 两种算法使用完全相同的 Cooley-Tukey 蝶形 NTT 结构 (N=256, signed int32_t)。 + * 差异仅在于 zeta 表数值 (由 Q 决定) 和 INTT 归一化常数 INTT_F。 + * + * 验证的关键性质: + * 对 Aigis 的两个 Q 值, 均满足 zetas_inv[k] = (Q - zetas[255-k]) % Q, + * 因此 INVNTT 可以与 ML-DSA 使用完全相同的 -zetas[--k] 迭代方式。 + * (人工验证了 Aigis Q=3870721: zetas_inv[0]=1451689=Q-zetas[255]=Q-2419032 ✓) + * + * 函数体: + * ntt() — 所有算法共享同一份代码 + * invntt_tomont() — 所有算法共享同一份代码 + */ + +#ifndef NTT_CUH +#define NTT_CUH + +#include +#include "params.h" +#include "reduce.cuh" + +/* ================================================================ + * Zeta 表 (仅数据不同, 函数体完全一致) + * ================================================================ */ + +#if ALGORITHM == ALGO_MLDSA + +/* INTT 归一化常数: 已知正确值 */ +#define INTT_F 41978 + +__constant__ int32_t ntt_zetas[PARAM_N] = { + 0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468, + 1826347, 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, + 2725464, 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, + -2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, + 2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, + -3861115, -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, + -1699267, -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, + 811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, + -3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221, + -1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922, + 3412210, -983419, 2147896, 2715295, 
-2967645, -3693493, -411027, -2477047, + -671102, -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, + -3343383, 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618, + -3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, + 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, + 1285669, -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, + 2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, + 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, + 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, + -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, + 342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, + 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, + -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, + -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, + -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, + -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, + -2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, + -3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, + -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, + -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, + -2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, + -554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782 +}; + +#elif ALGORITHM == ALGO_AIGIS + +/* INTT_F: 在设备端计算 (编译器常量表达式, 由 NVCC 编译期求值) */ +/* 等价于 N^{-1} * R^2 mod Q, 与 Aigis 原始 mont_invn 保持一致 */ +static __device__ const int32_t INTT_F = + (int32_t)((uint32_t)((uint64_t)MONT_VAL * MONT_VAL % PARAM_Q + * ((uint64_t)PARAM_Q - 1) % PARAM_Q + * (((uint64_t)PARAM_Q - 1) >> 8) % PARAM_Q)); + +#if PARAM_Q == 2021377 + 
+__constant__ int32_t ntt_zetas[PARAM_N] = { + 1562548, 518470, 697898, 862629, 1367459, 1539276, 1513857, 1662806, + 929015, 1757045, 1879015, 449873, 75689, 1125711, 1680345, 620849, + 769419, 486664, 1389778, 658915, 1319993, 73499, 1391732, 1199964, + 291970, 655587, 966181, 128755, 288564, 10420, 1980158, 1011904, + 1937906, 838813, 854780, 1453936, 1704819, 1740984, 86645, 1360044, + 115556, 1570480, 1655800, 272433, 1245520, 1190005, 238406, 1726139, + 1013693, 1948648, 1020399, 1544116, 1120075, 656153, 591869, 1620799, + 275832, 517427, 1601944, 1555925, 1293833, 1705829, 1357642, 142050, + 739420, 1568070, 1535360, 740638, 57925, 1038012, 65439, 1844105, + 673379, 1768997, 924638, 1986117, 1394208, 1277276, 129269, 1760277, + 1173604, 1161770, 1897168, 807697, 965038, 1876057, 1963820, 1794916, + 924093, 251419, 168030, 1073286, 1902394, 347156, 1488477, 511116, + 572755, 1686880, 268077, 53223, 1268228, 579769, 1043786, 272581, + 1574784, 1729984, 568576, 276296, 1095755, 282107, 158374, 915466, + 1569380, 908136, 972609, 923797, 466409, 1762448, 798650, 436051, + 1275685, 1122838, 1862, 1854194, 1432015, 1507507, 1452715, 1170924, + 137295, 531590, 556763, 1442250, 896280, 320184, 333460, 1993546, + 622613, 1352919, 881664, 1176558, 1936677, 2011958, 1357750, 534023, + 142791, 40293, 638104, 1519860, 1189220, 1763667, 792470, 1813814, + 830483, 1256948, 1537350, 64760, 561409, 823180, 786453, 1106713, + 1491299, 1582163, 822179, 1663832, 1269819, 84100, 780824, 310495, + 1043416, 763923, 1440072, 1308437, 1369984, 1027053, 641681, 932722, + 1248044, 318540, 1777818, 702544, 1566714, 1301662, 265980, 696370, + 1576958, 449193, 1228202, 1635455, 1143957, 1349609, 120737, 1115065, + 1815624, 573533, 10820, 1911846, 533321, 1147868, 1126927, 145151, + 641139, 275750, 276830, 1257214, 988074, 1857331, 105366, 1608247, + 1752751, 817865, 294374, 1145376, 1447053, 647982, 1517128, 301974, + 233775, 1669708, 1146108, 1913137, 707228, 1147423, 349817, 1972001, + 
777351, 1874015, 964313, 161863, 1142539, 1331457, 1604014, 1320129, + 1103939, 1236477, 447210, 1613614, 1666811, 51306, 383284, 1573619, + 677023, 994549, 23785, 210391, 461525, 1779756, 430663, 84620, + 1731642, 1784991, 147098, 942182, 1953450, 1853187, 1567373, 1541031 +}; + +__constant__ int32_t ntt_zetas_inv[PARAM_N] = { + 480346, 454004, 168190, 67927, 1079195, 1874279, 236386, 289735, + 1936757, 1590714, 241621, 1559852, 1810986, 1997592, 1026828, 1344354, + 447758, 1638093, 1970071, 354566, 407763, 1574167, 784900, 917438, + 701248, 417363, 689920, 878838, 1859514, 1057064, 147362, 1244026, + 49376, 1671560, 873954, 1314149, 108240, 875269, 351669, 1787602, + 1719403, 504249, 1373395, 574324, 876001, 1727003, 1203512, 268626, + 413130, 1916011, 164046, 1033303, 764163, 1744547, 1745627, 1380238, + 1876226, 894450, 873509, 1488056, 109531, 2010557, 1447844, 205753, + 906312, 1900640, 671768, 877420, 385922, 793175, 1572184, 444419, + 1325007, 1755397, 719715, 454663, 1318833, 243559, 1702837, 773333, + 1088655, 1379696, 994324, 651393, 712940, 581305, 1257454, 977961, + 1710882, 1240553, 1937277, 751558, 357545, 1199198, 439214, 530078, + 914664, 1234924, 1198197, 1459968, 1956617, 484027, 764429, 1190894, + 207563, 1228907, 257710, 832157, 501517, 1383273, 1981084, 1878586, + 1487354, 663627, 9419, 84700, 844819, 1139713, 668458, 1398764, + 27831, 1687917, 1701193, 1125097, 579127, 1464614, 1489787, 1884082, + 850453, 568662, 513870, 589362, 167183, 2019515, 898539, 745692, + 1585326, 1222727, 258929, 1554968, 1097580, 1048768, 1113241, 451997, + 1105911, 1863003, 1739270, 925622, 1745081, 1452801, 291393, 446593, + 1748796, 977591, 1441608, 753149, 1968154, 1753300, 334497, 1448622, + 1510261, 532900, 1674221, 118983, 948091, 1853347, 1769958, 1097284, + 226461, 57557, 145320, 1056339, 1213680, 124209, 859607, 847773, + 261100, 1892108, 744101, 627169, 35260, 1096739, 252380, 1347998, + 177272, 1955938, 983365, 1963452, 1280739, 486017, 453307, 1281957, 
+ 1879327, 663735, 315548, 727544, 465452, 419433, 1503950, 1745545, + 400578, 1429508, 1365224, 901302, 477261, 1000978, 72729, 1007684, + 295238, 1782971, 831372, 775857, 1748944, 365577, 450897, 1905821, + 661333, 1934732, 280393, 316558, 567441, 1166597, 1182564, 83471, + 1009473, 41219, 2010957, 1732813, 1892622, 1055196, 1365790, 1729407, + 821413, 629645, 1947878, 701384, 1362462, 631599, 1534713, 1251958, + 1400528, 341032, 895666, 1945688, 1571504, 142362, 264332, 1092362, + 358571, 507520, 482101, 653918, 1158748, 1323479, 1331599 +}; + +#elif PARAM_Q == 3870721 + +__constant__ int32_t ntt_zetas[PARAM_N] = { + 2337707, 2505409, 267692, 529914, 420735, 181988, 2608440, 3865338, + 3665767, 288746, 2524026, 3008396, 901579, 70491, 1821213, 1437514, + 3375394, 502705, 3475623, 3513653, 1833017, 3651222, 947790, 1966036, + 2704588, 2850143, 3030905, 1622520, 3210245, 3127826, 292206, 3096784, + 3201921, 3867412, 1705316, 2917474, 2975359, 2004421, 2812268, 890313, + 2511631, 3623292, 2803099, 2903766, 1596209, 2040136, 3468632, 2156661, + 2913824, 2560388, 1214035, 3468039, 575792, 2926910, 3407464, 2292204, + 2285761, 2338667, 63216, 3835938, 3204529, 1818443, 3786633, 3241498, + 944328, 616348, 2927622, 64038, 1171534, 1361903, 2827360, 3144828, + 2738981, 1714811, 3625146, 89505, 2787809, 2363190, 2513795, 3306399, + 1418851, 1206903, 926563, 211044, 466372, 3410093, 1353383, 3610570, + 934100, 2471859, 2037600, 2996463, 1698492, 525418, 1662944, 1981925, + 1210222, 1813802, 314420, 2466015, 3516872, 3320431, 1355971, 1500137, + 493991, 36365, 3235243, 214827, 2544017, 1739057, 945221, 1038283, + 2889903, 3364214, 1674857, 1434035, 1665177, 2651227, 1575769, 1155464, + 467835, 1713031, 2041544, 408424, 137443, 2029527, 2115209, 2293884, + 2137416, 3189891, 2471629, 2229785, 2611740, 2394735, 2287191, 2862622, + 300090, 1004990, 401830, 143957, 2910193, 3787906, 3628164, 3171269, + 2239135, 3038465, 601725, 2887353, 2766912, 1622354, 2989501, 1339396, + 
1939160, 2386893, 103181, 2793304, 911193, 3295333, 3025653, 2513246, + 314427, 939239, 57676, 2293294, 2833811, 2842292, 3139575, 2705158, + 1290463, 3780876, 1462003, 668827, 1850975, 2327221, 2910099, 2724881, + 418972, 957090, 321362, 2898276, 3523069, 1463158, 3818473, 453440, + 1891547, 1601731, 529312, 3301251, 1117070, 3520718, 634170, 1958581, + 929634, 1133255, 3807619, 1159272, 3292496, 3530590, 927442, 3686531, + 2605292, 384058, 1415774, 1040397, 3663661, 2332173, 1131260, 680774, + 1186917, 3736575, 1064994, 2954460, 3051663, 1162037, 2962553, 2130376, + 1717870, 3565361, 2935922, 2347272, 1768863, 3125776, 1686747, 3137894, + 2993356, 1574419, 1073008, 1262182, 183934, 914847, 3373156, 3688758, + 2538361, 614066, 3211143, 3565127, 1322591, 3426188, 2951336, 172348, + 3747492, 3719872, 2962113, 778168, 2880082, 1051508, 3741079, 1816757, + 763621, 328987, 2831790, 1276220, 135870, 3388537, 3034187, 2419032 +}; + +__constant__ int32_t ntt_zetas_inv[PARAM_N] = { + 1451689, 836534, 482184, 3734851, 2594501, 1038931, 3541734, 3107100, + 2053964, 129642, 2819213, 990639, 3092553, 908608, 150849, 123229, + 3698373, 919385, 444533, 2548130, 305594, 659578, 3256655, 1332360, + 181963, 497565, 2955874, 3686787, 2608539, 2797713, 2296302, 877365, + 732827, 2183974, 744945, 2101858, 1523449, 934799, 305360, 2152851, + 1740345, 908168, 2708684, 819058, 916261, 2805727, 134146, 2683804, + 3189947, 2739461, 1538548, 207060, 2830324, 2454947, 3486663, 1265429, + 184190, 2943279, 340131, 578225, 2711449, 63102, 2737466, 2941087, + 1912140, 3236551, 350003, 2753651, 569470, 3341409, 2268990, 1979174, + 3417281, 52248, 2407563, 347652, 972445, 3549359, 2913631, 3451749, + 1145840, 960622, 1543500, 2019746, 3201894, 2408718, 89845, 2580258, + 1165563, 731146, 1028429, 1036910, 1577427, 3813045, 2931482, 3556294, + 1357475, 845068, 575388, 2959528, 1077417, 3767540, 1483828, 1931561, + 2531325, 881220, 2248367, 1103809, 983368, 3268996, 832256, 1631586, + 699452, 242557, 
82815, 960528, 3726764, 3468891, 2865731, 3570631,
+    1008099, 1583530, 1475986, 1258981, 1640936, 1399092, 680830, 1733305,
+    1576837, 1755512, 1841194, 3733278, 3462297, 1829177, 2157690, 3402886,
+    2715257, 2294952, 1219494, 2205544, 2436686, 2195864, 506507, 980818,
+    2832438, 2925500, 2131664, 1326704, 3655894, 635478, 3834356, 3376730,
+    2370584, 2514750, 550290, 353849, 1404706, 3556301, 2056919, 2660499,
+    1888796, 2207777, 3345303, 2172229, 874258, 1833121, 1398862, 2936621,
+    260151, 2517338, 460628, 3404349, 3659677, 2944158, 2663818, 2451870,
+    564322, 1356926, 1507531, 1082912, 3781216, 245575, 2155910, 1131740,
+    725893, 1043361, 2508818, 2699187, 3806683, 943099, 3254373, 2926393,
+    629223, 84088, 2052278, 666192, 34783, 3807505, 1532054, 1584960,
+    1578517, 463257, 943811, 3294929, 402682, 2656686, 1310333, 956897,
+    1714060, 402089, 1830585, 2274512, 966955, 1067622, 247429, 1359090,
+    2980408, 1058453, 1866300, 895362, 953247, 2165405, 3309, 668800,
+    773937, 3578515, 742895, 660476, 2248201, 839816, 1020578, 1166133,
+    1904685, 2922931, 219499, 2037704, 357068, 395098, 3368016, 495327,
+    2433207, 2049508, 3800230, 2969142, 862325, 1346695, 3581975, 204954,
+    5383, 1262281, 3688733, 3449986, 3340807, 3603029, 951197
+};
+
+#endif /* PARAM_Q */
+
+#endif /* ALGORITHM */
+
+/* ================================================================
+ * NTT functions (both algorithms share the same code)
+ * ================================================================ */
+
+/*
+ * ntt(a): in-place forward NTT of the PARAM_N coefficients in a.
+ *   Uses the signed Montgomery butterfly:
+ *     t        = montgomery_reduce(zeta * a[j+len])
+ *     a[j+len] = a[j] - t
+ *     a[j]     = a[j] + t
+ *
+ *   k is pre-incremented before every table read, so ntt_zetas[0] is never
+ *   used and the stages consume ntt_zetas[1], ntt_zetas[2], ... in order.
+ *   No reduction other than the one inside montgomery_reduce is applied here.
+ */
+static __device__ __noinline__ void ntt(int32_t a[PARAM_N]) {
+    unsigned int len, start, j, k;
+    int32_t zeta, t;
+
+    k = 0;
+    for (len = 128; len > 0; len >>= 1) {
+        /* j ends the inner loop at start + len, so "start = j + len"
+         * advances to the next butterfly group (start += 2 * len). */
+        for (start = 0; start < PARAM_N; start = j + len) {
+            zeta = ntt_zetas[++k];
+            for (j = start; j < start + len; ++j) {
+                t = montgomery_reduce((int64_t)zeta * a[j + len]);
+                a[j + len] = a[j] - t;
+                a[j] = a[j] + t;
+            }
+        }
+    }
+}
+
+/*
+ * invntt_tomont(a): in-place inverse NTT of the PARAM_N coefficients in a.
+ *
+ *   ML-DSA: the inverse butterfly uses -ntt_zetas[--k] (the forward table
+ *           walked backwards and negated); afterwards all N coefficients
+ *           are multiplied by INTT_F.
+ *   Aigis:  the inverse butterfly uses the separate table ntt_zetas_inv[k++];
+ *           afterwards only the first N/2 coefficients are multiplied by
+ *           INTT_F (the Aigis zetas_inv table is not simply the negated,
+ *           reversed zetas table, so a dedicated table is required).
+ *           NOTE(review): presumably the scaling of the upper N/2
+ *           coefficients is folded into the last-stage ntt_zetas_inv
+ *           entry - confirm against the Aigis reference implementation.
+ */
+static __device__ __noinline__ void invntt_tomont(int32_t a[PARAM_N]) {
+    unsigned int start, len, j, k;
+    int32_t t, zeta;
+    const int32_t f = INTT_F;
+
+#if ALGORITHM == ALGO_MLDSA
+    k = PARAM_N;
+    for (len = 1; len < PARAM_N; len <<= 1) {
+        for (start = 0; start < PARAM_N; start = j + len) {
+            zeta = -ntt_zetas[--k];
+            for (j = start; j < start + len; ++j) {
+                t = a[j];
+                a[j] = t + a[j + len];
+                a[j + len] = t - a[j + len];
+                a[j + len] = montgomery_reduce((int64_t)zeta * a[j + len]);
+            }
+        }
+    }
+    /* final scaling of every coefficient by INTT_F */
+    for (j = 0; j < PARAM_N; ++j)
+        a[j] = montgomery_reduce((int64_t)f * a[j]);
+
+#elif ALGORITHM == ALGO_AIGIS
+    k = 0;
+    for (len = 1; len < PARAM_N; len <<= 1) {
+        for (start = 0; start < PARAM_N; start = j + len) {
+            zeta = ntt_zetas_inv[k++];
+            for (j = start; j < start + len; ++j) {
+                t = a[j];
+                a[j] = t + a[j + len];
+                a[j + len] = t - a[j + len];
+                a[j + len] = montgomery_reduce((int64_t)zeta * a[j + len]);
+            }
+        }
+    }
+    /* only the lower half is scaled by INTT_F here (see header comment) */
+    for (j = 0; j < PARAM_N / 2; ++j)
+        a[j] = montgomery_reduce((int64_t)f * a[j]);
+#endif
+}
+
+#endif /* NTT_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/packing.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/packing.cuh
new file mode 100644
index 000000000..6a49ce8e0
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/packing.cuh
@@ -0,0 +1,224 @@
+#ifndef PACKING_CUH
+#define PACKING_CUH
+
+#include /* NOTE(review): header name missing here - presumably <stdint.h>; confirm */
+#include "params.h"
+#include "poly.cuh"
+#include "polyvec.cuh"
+
+/*
+ * Signature format:
+ *   ML-DSA: c_tilde (CTILDEBYTES) || z_packed (L*POLYZ_PACKEDBYTES) || hint_bitmap (OMEGA+K)
+ *   Aigis:  z_packed (L*POLYZ_PACKEDBYTES) || hint_bitmap (OMEGA+K) || challenge_poly (N/8+8)
+ *
+ * Hint bitmap layout:
first OMEGA bytes are sorted coeff indices with hints=1,
+ *   last K bytes are end offsets for each poly (same for both).
+ */
+
+/* ================================================================
+ * Public key: rho (SEEDBYTES=32) || t1 packed (K * POLYT1_PACKEDBYTES)
+ *
+ * pack_pk writes rho followed by the K packed t1 polynomials into pk.
+ * unpack_pk is the exact inverse: it reads rho and the K t1 polynomials
+ * back out of pk. Neither function validates its input.
+ * ================================================================ */
+static __device__ void pack_pk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
+                               const uint8_t rho[SEEDBYTES], const polyveck *t1) {
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) pk[i] = rho[i];
+    for (unsigned int i = 0; i < PARAM_K; ++i)
+        polyt1_pack(pk + SEEDBYTES + i * POLYT1_PACKEDBYTES, &t1->vec[i]);
+}
+
+static __device__ void unpack_pk(uint8_t rho[SEEDBYTES], polyveck *t1,
+                                 const uint8_t pk[CRYPTO_PUBLICKEYBYTES]) {
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) rho[i] = pk[i];
+    for (unsigned int i = 0; i < PARAM_K; ++i)
+        polyt1_unpack(&t1->vec[i], pk + SEEDBYTES + i * POLYT1_PACKEDBYTES);
+}
+
+/* ================================================================
+ * Secret key: rho (SEEDBYTES) || key (SEEDBYTES) || tr (TRBYTES)
+ *             || s1 (L*POLYETA1_PACKEDBYTES) || s2 (K*POLYETA2_PACKEDBYTES)
+ *             || t0 (K*POLYT0_PACKEDBYTES)
+ *
+ * pack_sk serialises the fields in exactly this order; `offset` tracks the
+ * write position and is advanced by each field's packed size.
+ * ================================================================ */
+static __device__ void pack_sk(uint8_t sk[CRYPTO_SECRETKEYBYTES],
+                               const uint8_t rho[SEEDBYTES],
+                               const uint8_t key[SEEDBYTES],
+                               const uint8_t tr[TRBYTES],
+                               const polyvecl *s1, const polyveck *s2,
+                               const polyveck *t0) {
+    unsigned int offset = 0;
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) sk[offset++] = rho[i];
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) sk[offset++] = key[i];
+    for (unsigned int i = 0; i < TRBYTES; ++i) sk[offset++] = tr[i];
+    /* s1: L polynomials, packed with the s1 eta width */
+    for (unsigned int i = 0; i < PARAM_L; ++i) {
+        polyeta_s1_pack(sk + offset, &s1->vec[i]);
+        offset += POLYETA1_PACKEDBYTES;
+    }
+    /* s2: K polynomials, packed with the s2 eta width */
+    for (unsigned int i = 0; i < PARAM_K; ++i) {
+        polyeta_s2_pack(sk + offset, &s2->vec[i]);
+        offset += POLYETA2_PACKEDBYTES;
+    }
+    for (unsigned int i = 0;
i < PARAM_K; ++i) {
+        polyt0_pack(sk + offset, &t0->vec[i]);
+        offset += POLYT0_PACKEDBYTES;
+    }
+}
+
+/*
+ * unpack_sk: inverse of pack_sk. Reads rho, key, tr, then the L s1
+ * polynomials, the K s2 polynomials and the K t0 polynomials from sk,
+ * in that order. No validation of the input bytes is performed here.
+ */
+static __device__ void unpack_sk(uint8_t rho[SEEDBYTES], uint8_t key[SEEDBYTES],
+                                 uint8_t tr[TRBYTES],
+                                 polyvecl *s1, polyveck *s2, polyveck *t0,
+                                 const uint8_t sk[CRYPTO_SECRETKEYBYTES]) {
+    unsigned int offset = 0;
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) rho[i] = sk[offset++];
+    for (unsigned int i = 0; i < SEEDBYTES; ++i) key[i] = sk[offset++];
+    for (unsigned int i = 0; i < TRBYTES; ++i) tr[i] = sk[offset++];
+    for (unsigned int i = 0; i < PARAM_L; ++i) {
+        polyeta_s1_unpack(&s1->vec[i], sk + offset);
+        offset += POLYETA1_PACKEDBYTES;
+    }
+    for (unsigned int i = 0; i < PARAM_K; ++i) {
+        polyeta_s2_unpack(&s2->vec[i], sk + offset);
+        offset += POLYETA2_PACKEDBYTES;
+    }
+    for (unsigned int i = 0; i < PARAM_K; ++i) {
+        polyt0_unpack(&t0->vec[i], sk + offset);
+        offset += POLYT0_PACKEDBYTES;
+    }
+}
+
+/* ================================================================
+ * Signature packing/unpacking — algorithm-specific format
+ * ================================================================ */
+#if ALGORITHM == ALGO_MLDSA
+/* ML-DSA: c_tilde || z || hint_bitmap
+ *
+ * pack_sig writes c_tilde, the L packed z polynomials, then the hint
+ * section: the first OMEGA bytes list the indices j of non-zero hint
+ * coefficients (polynomial by polynomial, ascending j), and byte
+ * OMEGA + i holds the running count of indices after polynomial i.
+ * NOTE(review): k is not bounded against PARAM_OMEGA here - presumably
+ * the caller guarantees at most OMEGA non-zero hints; confirm. */
+static __device__ void pack_sig(uint8_t sig[CRYPTO_BYTES],
+                                const uint8_t c_tilde[CTILDEBYTES],
+                                const polyvecl *z, const polyveck *h) {
+    unsigned int offset = 0, k = 0;
+
+    for (unsigned int i = 0; i < CTILDEBYTES; ++i) sig[offset++] = c_tilde[i];
+    for (unsigned int i = 0; i < PARAM_L; ++i) {
+        polyz_pack(sig + offset, &z->vec[i]);
+        offset += POLYZ_PACKEDBYTES;
+    }
+
+    /* clear the whole hint section so unused index slots are zero */
+    for (unsigned int i = 0; i < PARAM_OMEGA + PARAM_K; ++i) sig[offset + i] = 0;
+    for (unsigned int i = 0; i < PARAM_K; ++i) {
+        for (unsigned int j = 0; j < PARAM_N; ++j)
+            if (h->vec[i].coeffs[j] != 0) sig[offset + k++] = (uint8_t)j;
+        sig[offset + PARAM_OMEGA + i] = (uint8_t)k;
+    }
+}
+
+/*
+ * unpack_sig (ML-DSA): parses c_tilde, z and the hint section from sig.
+ * Returns 0 on success and 1 if the hint section is malformed.
+ */
+static __device__ __noinline__ int unpack_sig(uint8_t c_tilde[CTILDEBYTES],
+                                              polyvecl *z, polyveck *h,
+                                              const uint8_t
sig[CRYPTO_BYTES]) {
+    unsigned int offset = 0, k = 0;
+
+    for (unsigned int i = 0; i < CTILDEBYTES; ++i) c_tilde[i] = sig[offset++];
+    for (unsigned int i = 0; i < PARAM_L; ++i) {
+        polyz_unpack(&z->vec[i], sig + offset);
+        offset += POLYZ_PACKEDBYTES;
+    }
+
+    for (unsigned int i = 0; i < PARAM_K; ++i) {
+        unsigned int prev_k = k;
+        for (unsigned int j = 0; j < PARAM_N; ++j) h->vec[i].coeffs[j] = 0;
+        /* end offset of polynomial i: must not decrease and must not exceed OMEGA */
+        unsigned int end = sig[offset + PARAM_OMEGA + i];
+        if (end < k || end > PARAM_OMEGA) return 1;
+        for (unsigned int j = k; j < end; ++j) {
+            /* indices within one polynomial must be strictly increasing */
+            if (j > prev_k && sig[offset + j] <= sig[offset + j - 1])
+                return 1;
+            h->vec[i].coeffs[sig[offset + j]] = 1;
+        }
+        k = end;
+    }
+    /* unused index slots must be zero */
+    for (; k < PARAM_OMEGA; ++k) if (sig[offset + k] != 0) return 1;
+    return 0;
+}
+
+#elif ALGORITHM == ALGO_AIGIS
+/* Aigis: z || hint_bitmap || challenge_poly(N/8 + 8 bytes)
+ *
+ * pack_sig writes the L packed z polynomials, the hint section (same
+ * layout as ML-DSA, but only coefficients equal to 1 are recorded), and
+ * the challenge polynomial c as an N/8-byte bitmap of non-zero positions
+ * followed by 8 bytes of sign bits (one bit per non-zero coefficient, in
+ * position order; the bit is set when the coefficient equals PARAM_Q - 1). */
+static __device__ void pack_sig(uint8_t sig[CRYPTO_BYTES],
+                                const polyvecl *z, const polyveck *h,
+                                const poly *c) {
+    unsigned int i, j, k;
+    uint64_t signs, mask;
+    unsigned int offset = 0;
+
+    /* z_packed */
+    for (i = 0; i < PARAM_L; ++i) {
+        polyz_pack(sig + offset, &z->vec[i]);
+        offset += POLYZ_PACKEDBYTES;
+    }
+
+    /* hint bitmap */
+    k = 0;
+    for (i = 0; i < PARAM_OMEGA + PARAM_K; ++i) sig[offset + i] = 0;
+    for (i = 0; i < PARAM_K; ++i) {
+        for (j = 0; j < PARAM_N; ++j)
+            if (h->vec[i].coeffs[j] == 1) sig[offset + k++] = (uint8_t)j;
+        sig[offset + PARAM_OMEGA + i] = (uint8_t)k;
+    }
+    offset += PARAM_OMEGA + PARAM_K;
+
+    /* challenge poly: N/8 bytes bitmap + 8 bytes signs */
+    signs = 0;
+    mask = 1;
+    for (i = 0; i < PARAM_N / 8; ++i) {
+        sig[offset + i] = 0;
+        for (j = 0; j < 8; ++j) {
+            if (c->coeffs[8 * i + j] != 0) {
+                sig[offset + i] |= (1u << j);
+                if (c->coeffs[8 * i + j] == (PARAM_Q - 1)) signs |= mask;
+                /* mask advances once per non-zero coefficient */
+                mask <<= 1;
+            }
+        }
+    }
+    offset += PARAM_N / 8;
+    /* sign word is stored little-endian */
+    for (i = 0; i < 8; ++i) sig[offset + i] = (uint8_t)(signs >> (8 * i));
+}
+
+/*
+ * unpack_sig (Aigis): parses z, the hint section and the challenge
+ * polynomial from sig. Returns 0 on success and 1 if the hint section is
+ * malformed.
+ */
+static __device__ __noinline__ int unpack_sig(polyvecl *z, polyveck *h, poly
*c,
+                                              const uint8_t sig[CRYPTO_BYTES]) {
+    unsigned int i, j, k;
+    uint64_t signs, mask;
+    unsigned int offset = 0;
+
+    /* z_packed */
+    for (i = 0; i < PARAM_L; ++i) {
+        polyz_unpack(&z->vec[i], sig + offset);
+        offset += POLYZ_PACKEDBYTES;
+    }
+
+    /* hint bitmap */
+    k = 0;
+    for (i = 0; i < PARAM_K; ++i) {
+        for (j = 0; j < PARAM_N; ++j) h->vec[i].coeffs[j] = 0;
+        /* end offset of polynomial i: must not decrease and must not exceed OMEGA */
+        unsigned int end = sig[offset + PARAM_OMEGA + i];
+        if (end < k || end > PARAM_OMEGA) return 1;
+        for (j = k; j < end; ++j) {
+            /* k is still the start offset of this polynomial here, so this
+             * enforces strictly increasing indices within the polynomial */
+            if (j > k && sig[offset + j] <= sig[offset + j - 1]) return 1;
+            h->vec[i].coeffs[sig[offset + j]] = 1;
+        }
+        k = end;
+    }
+    /* unused index slots must be zero */
+    for (j = k; j < PARAM_OMEGA; ++j) if (sig[offset + j]) return 1;
+    offset += PARAM_OMEGA + PARAM_K;
+
+    /* challenge poly: N/8 bitmap + 8 signs */
+    /* NOTE(review): neither the number of set bitmap bits nor the unused
+     * high bits of `signs` is validated in this function - confirm that
+     * malformed challenge encodings are rejected elsewhere. */
+    for (i = 0; i < PARAM_N; ++i) c->coeffs[i] = 0;
+    signs = 0;
+    for (i = 0; i < 8; ++i)
+        signs |= (uint64_t)sig[offset + PARAM_N / 8 + i] << (8 * i);
+    mask = 1;
+    for (i = 0; i < PARAM_N / 8; ++i) {
+        for (j = 0; j < 8; ++j) {
+            if ((sig[offset + i] >> j) & 0x01) {
+                c->coeffs[8 * i + j] = (signs & mask) ?
PARAM_Q - 1 : 1;
+                mask <<= 1;
+            }
+        }
+    }
+    return 0;
+}
+#endif
+
+#endif /* PACKING_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/params.h b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/params.h
new file mode 100644
index 000000000..e86f72d3e
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/params.h
@@ -0,0 +1,307 @@
+#include "hip/hip_runtime.h"
+/*
+ * params.h — unified parameter header
+ *
+ * Both algorithms describe all of their parameters through one shared set
+ * of macro names. Semantic differences between the algorithms (coefficient
+ * domain, sampling, challenge, hints, ...) are handled by
+ * #if ALGORITHM branches in the individual feature files.
+ *
+ * Key unification decisions:
+ *   coeff_t = int32_t          (ML-DSA is already signed; Aigis coefficients < Q < 4M < 2^22, so they fit)
+ *   PARAM_ETA_S1/S2            = eta of the s1/s2 polynomials (in Aigis ETA1 != ETA2)
+ *   PARAM_BETA1/BETA2          = TAU * ETA_S1/S2 (norm rejection threshold)
+ *   TRBYTES = CRHBYTES         (the tr length used for signing equals the hash output length)
+ *   RNDBYTES                   = entropy for randomized signing (ML-DSA=32, Aigis=0)
+ *   SETA1BITS/SETA2BITS        = bits to pack PARAM_ETA_S1/S2 coefficients
+ *   POLYT1_PACKED_BITS         = QBITS - D (bits per t1 coeff)
+ *
+ * Source of the Aigis parameters: PQMagic CPU implementation (PARAMS=1/2/3)
+ *   Mode 1: Q=2021377, K=4, L=3, ETA1=2, ETA2=3, D=13, GAMMA1=2^17, GAMMA2=(Q-1)/12
+ *   Mode 2: Q=3870721, K=5, L=4, ETA1=2, ETA2=5, D=14, GAMMA1=2^17, GAMMA2=(Q-1)/12
+ *   Mode 3: Q=3870721, K=6, L=5, ETA1=1, ETA2=5, D=14, GAMMA1=2^17, GAMMA2=(Q-1)/12
+ *
+ * Signature format:
+ *   ML-DSA: c_tilde (CTILDEBYTES) || z_packed || hint_bitmap
+ *   Aigis:  z_packed || hint_bitmap || challenge_poly (N/8 + 8 bytes)
+ */
+
+#ifndef PARAMS_H
+#define PARAMS_H
+
+#include "config.h"
+#include /* NOTE(review): header name missing here - presumably <stdint.h>; confirm */
+
+/* ================================================================
+ * Common types — int32_t for both algorithms
+ * ================================================================ */
+typedef int32_t coeff_t;
+typedef int64_t coeff2_t;
+
+/* ================================================================
+ * Common constants
+ * ================================================================ */
+#define PARAM_N 256
+#define SEEDBYTES 32
+
+/*
================================================================ + * ML-DSA (CRYSTALS-Dilithium) 参数 + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +#define PARAM_Q 8380417 +#define PARAM_QBITS 23 +#define CRHBYTES 64 +#define TRBYTES 64 +#define RNDBYTES 32 + +/* Mont constants for Q=8380417: + * MONT_VAL = 2^32 mod Q = 4193792 + * MONT_QINV: Q^{-1} mod 2^32 = 58728449 (fits in uint32) */ +#define MONT_VAL 4193792 +#define MONT_QINV 58728449u + +#if PARAM_MODE == 2 /* ML-DSA-44 */ + #define PARAM_K 4 + #define PARAM_L 4 + #define PARAM_D 13 + #define PARAM_ETA_S1 2 + #define PARAM_ETA_S2 2 + #define PARAM_TAU 39 + #define PARAM_BETA1 78 /* TAU * ETA_S1 */ + #define PARAM_BETA2 78 /* TAU * ETA_S2 */ + #define PARAM_GAMMA1 (1 << 17) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 88) + #define PARAM_OMEGA 80 + #define CTILDEBYTES 32 + #define SETA1BITS 3 /* ceil(log2(2*2+1)) = ceil(log2(5)) = 3 */ + #define SETA2BITS 3 + #define INTT_F 41978 /* N^{-1} * 2^32 mod Q */ + +#elif PARAM_MODE == 3 /* ML-DSA-65 */ + #define PARAM_K 6 + #define PARAM_L 5 + #define PARAM_D 13 + #define PARAM_ETA_S1 4 + #define PARAM_ETA_S2 4 + #define PARAM_TAU 49 + #define PARAM_BETA1 196 + #define PARAM_BETA2 196 + #define PARAM_GAMMA1 (1 << 19) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 32) + #define PARAM_OMEGA 55 + #define CTILDEBYTES 48 + #define SETA1BITS 4 /* ceil(log2(2*4+1)) = ceil(log2(9)) = 4 */ + #define SETA2BITS 4 + #define INTT_F 41978 + +#elif PARAM_MODE == 5 /* ML-DSA-87 */ + #define PARAM_K 8 + #define PARAM_L 7 + #define PARAM_D 13 + #define PARAM_ETA_S1 2 + #define PARAM_ETA_S2 2 + #define PARAM_TAU 60 + #define PARAM_BETA1 120 + #define PARAM_BETA2 120 + #define PARAM_GAMMA1 (1 << 19) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 32) + #define PARAM_OMEGA 75 + #define CTILDEBYTES 64 + #define SETA1BITS 3 + #define SETA2BITS 3 + #define INTT_F 41978 +#endif + +#define CRYPTO_ALGNAME "ML-DSA" + +/* 
================================================================ + * Aigis-sig (PQMagic) 参数 + * 来源: PQMagic GPU实现 params.h (PARAMS=1/2/3) + * ================================================================ */ +#elif ALGORITHM == ALGO_AIGIS + +#define CRHBYTES 48 +#define TRBYTES 48 +#define RNDBYTES 0 + +/* Aigis ALPHA = 2*GAMMA2 — used in decompose/use_hint */ +#define PARAM_ALPHA_VAL (2 * ((PARAM_Q - 1) / 12)) + +#if PARAM_MODE == 1 /* Aigis-sig1 */ + #define PARAM_Q 2021377 + #define PARAM_QBITS 21 + #define PARAM_K 4 + #define PARAM_L 3 + #define PARAM_D 13 + #define PARAM_ETA_S1 2 + #define PARAM_ETA_S2 3 + #define PARAM_TAU 60 + #define PARAM_BETA1 120 /* TAU * ETA_S1 = 60*2 */ + #define PARAM_BETA2 175 /* from PQMagic params: 175 (~TAU*ETA_S2 slightly adjusted) */ + #define PARAM_GAMMA1 (1 << 17) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 12) /* = 168448 */ + #define PARAM_OMEGA 80 + #define SETA1BITS 3 /* ceil(log2(2*2+1))=3 */ + #define SETA2BITS 3 /* ceil(log2(2*3+1))=3 */ + /* Mont: 2^32 mod Q=2021377 = 1562548; Q^{-1} mod 2^32 */ + #define MONT_VAL 1562548 + #define MONT_QINV 1445013505u + +#elif PARAM_MODE == 2 /* Aigis-sig2 */ + #define PARAM_Q 3870721 + #define PARAM_QBITS 22 + #define PARAM_K 5 + #define PARAM_L 4 + #define PARAM_D 14 + #define PARAM_ETA_S1 2 + #define PARAM_ETA_S2 5 + #define PARAM_TAU 60 + #define PARAM_BETA1 120 /* TAU * ETA_S1 = 60*2 */ + #define PARAM_BETA2 275 /* from PQMagic params */ + #define PARAM_GAMMA1 (1 << 17) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 12) /* = 322560 */ + #define PARAM_OMEGA 96 + #define SETA1BITS 3 /* ceil(log2(5))=3 */ + #define SETA2BITS 4 /* ceil(log2(11))=4 */ + /* Mont: 2^32 mod Q=3870721 = 2337707; Q^{-1} mod 2^32 */ + #define MONT_VAL 2337707 + #define MONT_QINV 1623519233u + +#elif PARAM_MODE == 3 /* Aigis-sig3 */ + #define PARAM_Q 3870721 + #define PARAM_QBITS 22 + #define PARAM_K 6 + #define PARAM_L 5 + #define PARAM_D 14 + #define PARAM_ETA_S1 1 + #define PARAM_ETA_S2 5 + #define PARAM_TAU 60 + 
#define PARAM_BETA1 60 /* TAU * ETA_S1 = 60*1 */ + #define PARAM_BETA2 275 /* from PQMagic params */ + #define PARAM_GAMMA1 (1 << 17) + #define PARAM_GAMMA2 ((PARAM_Q - 1) / 12) /* = 322560 */ + #define PARAM_OMEGA 120 + #define SETA1BITS 2 /* ceil(log2(3))=2: values {-1,0,1}→{2,1,0} */ + #define SETA2BITS 4 /* ceil(log2(11))=4 */ + /* Mont: same as mode 2; Q^{-1} mod 2^32 */ + #define MONT_VAL 2337707 + #define MONT_QINV 1623519233u +#endif + +#define CRYPTO_ALGNAME "Aigis-sig" + +#endif /* ALGORITHM */ + +/* ================================================================ + * 算法钩子宏 — 消除 poly/polyvec 层的大量 #if ALGORITHM 分叉 + * ================================================================ */ + +/* + * COEFF_BIAS: 系数偏置常量 + * ML-DSA 使用中心化 (-Q/2, Q/2], 偏置 = 0 + * Aigis 使用无符号 [0, Q), 偏置 = Q + * eta/t0 pack/unpack 统一为: COEFF_BIAS + ETA - coeff + */ +#if ALGORITHM == ALGO_MLDSA +#define COEFF_BIAS 0 +#elif ALGORITHM == ALGO_AIGIS +#define COEFF_BIAS PARAM_Q +#endif + +/* + * MATRIX_NONCE(i,j): matrix A expansion 的 nonce 编码 + * ML-DSA: 2-byte LE, nonce = i*256 + j + * Aigis: 1-byte, nonce = i + (j<<4) + */ +#if ALGORITHM == ALGO_MLDSA +#define MATRIX_NONCE(i, j) ((uint16_t)((i) * 256 + (j))) +#elif ALGORITHM == ALGO_AIGIS +#define MATRIX_NONCE(i, j) ((uint16_t)((i) + ((j) << 4))) +#endif + +/* + * GAMMA1_NONCE(base, i): gamma1 采样的 nonce 计算 + * ML-DSA: nonce = L * base + i (每次 rejection 只递增 base 一次) + * Aigis: nonce = base + i (每个 poly 一个 nonce) + */ +#if ALGORITHM == ALGO_MLDSA +#define GAMMA1_NONCE(base, i) ((uint16_t)(PARAM_L * (base) + (i))) +#elif ALGORITHM == ALGO_AIGIS +#define GAMMA1_NONCE(base, i) ((uint16_t)((base) + (i))) +#endif + +/* + * Z_BIAS / Z_FIXUP(t): polyz pack/unpack 的偏置 + * ML-DSA: t = GAMMA1 - coeff, Z_FIXUP 为空 + * Aigis: t = GAMMA1-1 - coeff; 负值 +Q, Z_FIXUP 修正负值 + */ +#if ALGORITHM == ALGO_MLDSA +#define Z_BIAS PARAM_GAMMA1 +#define Z_FIXUP(t) /* nothing */ +#elif ALGORITHM == ALGO_AIGIS +#define Z_BIAS (PARAM_GAMMA1 - 1) +#define Z_FIXUP(t) (t) 
+= (((int32_t)(t)) >> 31) & PARAM_Q +#endif + +/* ================================================================ + * 导出的打包尺寸 (基于参数计算, 两种算法通用公式) + * ================================================================ */ + +/* bits per t1 coeff: POLYT1_PACKED_BITS = QBITS - D */ +#define POLYT1_PACKED_BITS (PARAM_QBITS - PARAM_D) +/* bytes per poly t1: N * bits / 8 */ +#define POLYT1_PACKEDBYTES (PARAM_N * POLYT1_PACKED_BITS / 8) + +/* bytes per poly t0: N * D / 8 */ +#define POLYT0_PACKEDBYTES (PARAM_N * PARAM_D / 8) + +/* bytes per eta poly (s1): N * SETA1BITS / 8 */ +#define POLYETA1_PACKEDBYTES (PARAM_N * SETA1BITS / 8) + +/* bytes per eta poly (s2): N * SETA2BITS / 8 */ +#define POLYETA2_PACKEDBYTES (PARAM_N * SETA2BITS / 8) + +/* bytes per z poly: depends on GAMMA1 (18-bit or 20-bit coeffs) */ +#if PARAM_GAMMA1 == (1 << 17) +#define POLYZ_PACKEDBYTES 576 /* 9 bytes per 4 coeffs (18 bits) */ +#elif PARAM_GAMMA1 == (1 << 19) +#define POLYZ_PACKEDBYTES 640 /* 5 bytes per 2 coeffs (20 bits) */ +#endif + +/* bytes per w1 poly: depends on bits per coeff = ceil(log2(N_W1+1)) + * GAMMA2=(Q-1)/88 → N_W1=44 → 6 bits/coeff → 4 per 3 bytes → 192 + * GAMMA2=(Q-1)/32 → N_W1=16 → 4 bits/coeff → 2 per 1 byte → 128 + * GAMMA2=(Q-1)/12 → N_W1=6 → 3 bits/coeff → 8 per 3 bytes → 96 + */ +#if PARAM_GAMMA2 == (PARAM_Q - 1) / 88 +#define POLYW1_PACKEDBYTES 192 +#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 32 +#define POLYW1_PACKEDBYTES 128 +#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 12 +#define POLYW1_PACKEDBYTES 96 +#endif + +/* Number of distinct high-bit parts: N_W1 = (Q-1) / (2 * GAMMA2) */ +#define N_W1 ((PARAM_Q - 1) / (2 * PARAM_GAMMA2)) + +/* Public/Secret key and Signature sizes */ +#define CRYPTO_PUBLICKEYBYTES (SEEDBYTES + PARAM_K * POLYT1_PACKEDBYTES) +#define CRYPTO_SECRETKEYBYTES (2 * SEEDBYTES + TRBYTES \ + + PARAM_L * POLYETA1_PACKEDBYTES \ + + PARAM_K * POLYETA2_PACKEDBYTES \ + + PARAM_K * POLYT0_PACKEDBYTES) + +#if ALGORITHM == ALGO_MLDSA +/* ML-DSA sig format: c_tilde || 
z_packed || hints_bitmap */
+#define CRYPTO_BYTES (CTILDEBYTES \
+                      + PARAM_L * POLYZ_PACKEDBYTES \
+                      + PARAM_OMEGA + PARAM_K)
+#elif ALGORITHM == ALGO_AIGIS
+/* Aigis sig format: z_packed || hints_bitmap || challenge_poly (N/8+8 bytes) */
+#define CHALLENGE_POLY_PACKEDBYTES (PARAM_N / 8 + 8) /* 40 bytes: bitmap + signs */
+#define CRYPTO_BYTES (PARAM_L * POLYZ_PACKEDBYTES \
+                      + PARAM_OMEGA + PARAM_K \
+                      + CHALLENGE_POLY_PACKEDBYTES)
+#endif
+
+#endif /* PARAMS_H */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/poly.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/poly.cuh
new file mode 100644
index 000000000..994177be2
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/poly.cuh
@@ -0,0 +1,817 @@
+#ifndef POLY_CUH
+#define POLY_CUH
+
+#include /* NOTE(review): header name missing here - presumably <stdint.h>; confirm */
+#include "params.h"
+#include "ntt.cuh"
+#include "reduce.cuh"
+#include "rounding.cuh"
+#include "symmetric.cuh"
+
+/* A polynomial with PARAM_N int32_t coefficients. */
+typedef struct { int32_t coeffs[PARAM_N]; } poly;
+
+/* ==== Basic arithmetic ====
+ * Each helper below applies the named scalar routine (reduce32, caddq,
+ * freeze2q, freeze4q) to every coefficient of a, in place. */
+static __device__ void poly_reduce(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] = reduce32(a->coeffs[i]);
+}
+static __device__ void poly_caddq(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] = caddq(a->coeffs[i]);
+}
+static __device__ void poly_freeze2q(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] = freeze2q(a->coeffs[i]);
+}
+static __device__ void poly_freeze4q(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] = freeze4q(a->coeffs[i]);
+}
+/* c = a + b coefficient-wise; no modular reduction is performed. */
+static __device__ void poly_add(poly *c, const poly *a, const poly *b) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) c->coeffs[i] = a->coeffs[i] + b->coeffs[i];
+}
+/* Unified sub: ML-DSA (COEFF_BIAS=0) → a-b; Aigis (COEFF_BIAS=Q) → a+2Q-b */
+static __device__ void poly_sub(poly *c, const poly *a, const poly *b) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) c->coeffs[i] =
a->coeffs[i] + 2 * COEFF_BIAS - b->coeffs[i];
+}
+#if ALGORITHM == ALGO_AIGIS
+/* In-place negation for Aigis: a = 2Q - a, coefficient-wise. */
+static __device__ void poly_neg(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] = 2 * PARAM_Q - a->coeffs[i];
+}
+#endif
+/* Multiply every coefficient by 2^PARAM_D (left shift), in place. */
+static __device__ void poly_shiftl(poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i) a->coeffs[i] <<= PARAM_D;
+}
+/* Thin wrappers applying ntt() / invntt_tomont() to the coefficient array. */
+static __device__ void poly_ntt(poly *a) { ntt(a->coeffs); }
+static __device__ void poly_invntt_tomont(poly *a) { invntt_tomont(a->coeffs); }
+/* c[i] = montgomery_reduce(a[i] * b[i]) for every coefficient. */
+static __device__ void poly_pointwise_montgomery(poly *c, const poly *a, const poly *b) {
+    for (unsigned int i = 0; i < PARAM_N; ++i)
+        c->coeffs[i] = montgomery_reduce((int64_t)a->coeffs[i] * b->coeffs[i]);
+}
+/* Coefficient-wise power2round: a1 receives the return value, a0 the
+ * value written through the pointer argument. */
+static __device__ void poly_power2round(poly *a1, poly *a0, const poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i)
+        a1->coeffs[i] = power2round(&a0->coeffs[i], a->coeffs[i]);
+}
+/* Coefficient-wise decompose: a1 receives the return value, a0 the
+ * value written through the pointer argument. */
+static __device__ void poly_decompose(poly *a1, poly *a0, const poly *a) {
+    for (unsigned int i = 0; i < PARAM_N; ++i)
+        a1->coeffs[i] = decompose(&a0->coeffs[i], a->coeffs[i]);
+}
+/* Coefficient-wise make_hint; returns the sum of the produced hint values. */
+static __device__ unsigned int poly_make_hint(poly *h, const poly *a0, const poly *a1) {
+    unsigned int s = 0;
+    for (unsigned int i = 0; i < PARAM_N; ++i) {
+        h->coeffs[i] = make_hint(a0->coeffs[i], a1->coeffs[i]);
+        s += h->coeffs[i];
+    }
+    return s;
+}
+/* Coefficient-wise use_hint: b[i] = use_hint(a[i], h[i]). */
+static __device__ void poly_use_hint(poly *b, const poly *a, const poly *h) {
+    for (unsigned int i = 0; i < PARAM_N; ++i)
+        b->coeffs[i] = use_hint(a->coeffs[i], h->coeffs[i]);
+}
+
+/* ---- chknorm: check if any coeff has |coeff| >= B ----
+ * Returns 1 if the bound is violated (or, for ML-DSA, if B itself is
+ * larger than (Q-1)/8), otherwise 0. */
+#if ALGORITHM == ALGO_MLDSA
+static __device__ int poly_chknorm(const poly *a, int32_t B) {
+    if (B > (PARAM_Q - 1) / 8) return 1;
+    for (unsigned int i = 0; i < PARAM_N; ++i) {
+        /* branch-free absolute value: t is all-ones when the coefficient
+         * is negative, so 2*coeff is subtracted only in that case */
+        int32_t t = a->coeffs[i] >> 31;
+        t = a->coeffs[i] - (t & 2 * a->coeffs[i]);
+        if (t >= B) return 1;
+    }
+    return 0;
+}
+#elif ALGORITHM == ALGO_AIGIS
+/* Aigis: unsigned coeff ∈ [0,Q), distance = |(Q-1)/2 - coeff| */
+static __device__ int
poly_chknorm(const poly *a, int32_t B) { + for (unsigned int i = 0; i < PARAM_N; ++i) { + int32_t t = (PARAM_Q - 1) / 2 - a->coeffs[i]; + t ^= (t >> 31); + t = (PARAM_Q - 1) / 2 - t; + if (t >= B) return 1; + } + return 0; +} +#endif + +/* ================================================================ + * Uniform rejection sampling for matrix A + * ================================================================ */ +static __device__ unsigned int rej_uniform(int32_t *a, unsigned int len, + const uint8_t *buf, unsigned int buflen) { + unsigned int ctr = 0, pos = 0; + while (ctr < len && pos + 3 <= buflen) { + uint32_t t = buf[pos++] | ((uint32_t)buf[pos++] << 8) | ((uint32_t)buf[pos++] << 16); + t &= (1u << PARAM_QBITS) - 1; + if (t < (uint32_t)PARAM_Q) a[ctr++] = (int32_t)t; + } + return ctr; +} + +#define POLY_UNIFORM_NBLOCKS ((768 + STREAM128_BLOCKBYTES - 1) / STREAM128_BLOCKBYTES) + +/* Unified poly_uniform: 共享函数体, 仅 stream init 按算法分流 */ +static __device__ __noinline__ void poly_uniform(poly *a, const uint8_t seed[SEEDBYTES], uint16_t nonce) { + unsigned int ctr, off; + unsigned int buflen = POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES; + uint8_t buf[POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES + 2]; + stream128_state state; +#if ALGORITHM == ALGO_MLDSA + stream128_init(&state, seed, nonce); +#elif ALGORITHM == ALGO_AIGIS + aigis_shake128_stream_init(&state, seed, (uint8_t)nonce); +#endif + stream128_squeezeblocks(buf, POLY_UNIFORM_NBLOCKS, &state); + ctr = rej_uniform(a->coeffs, PARAM_N, buf, buflen); + while (ctr < PARAM_N) { + off = buflen % 3; + for (unsigned int i = 0; i < off; ++i) buf[i] = buf[buflen - off + i]; + stream128_squeezeblocks(buf + off, 1, &state); + buflen = STREAM128_BLOCKBYTES + off; + ctr += rej_uniform(a->coeffs + ctr, PARAM_N - ctr, buf, buflen); + } +} + +static __device__ __noinline__ void poly_uniform_to(coeff_t *a, const uint8_t seed[SEEDBYTES], uint16_t nonce) { + unsigned int ctr, off; + unsigned int buflen = POLY_UNIFORM_NBLOCKS * 
STREAM128_BLOCKBYTES; + uint8_t buf[POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES + 2]; + stream128_state state; +#if ALGORITHM == ALGO_MLDSA + stream128_init(&state, seed, nonce); +#elif ALGORITHM == ALGO_AIGIS + aigis_shake128_stream_init(&state, seed, (uint8_t)nonce); +#endif + stream128_squeezeblocks(buf, POLY_UNIFORM_NBLOCKS, &state); + ctr = rej_uniform(a, PARAM_N, buf, buflen); + while (ctr < PARAM_N) { + off = buflen % 3; /* 0..2 bytes that did not form a full 3-byte candidate */ + for (unsigned int i = 0; i < off; ++i) buf[i] = buf[buflen - off + i]; + stream128_squeezeblocks(buf + off, 1, &state); + buflen = STREAM128_BLOCKBYTES + off; + ctr += rej_uniform(a + ctr, PARAM_N - ctr, buf, buflen); + } +} + +/* ================================================================ + * Eta rejection sampling (s1 and s2) + * The rej functions are selected per algorithm; poly_uniform_eta shares one skeleton + * ================================================================ */ + +#if ALGORITHM == ALGO_MLDSA +static __device__ unsigned int rej_eta_mldsa_to(int32_t *a, unsigned int len, + const uint8_t *buf, + unsigned int buflen, + int eta) { /* eta is a runtime argument: 2 selects the mod-5 rule, any other value the b<9 rule; returns the number of coefficients written */ + unsigned int ctr = 0, pos = 0; + while (ctr < len && pos < buflen) { + uint32_t t0 = buf[pos] & 0x0F; + uint32_t t1 = buf[pos++] >> 4; + if (eta == 2) { + if (t0 < 15) { + t0 = t0 - (205 * t0 >> 10) * 5; /* t0 mod 5, computed without a division */ + a[ctr++] = 2 - (int32_t)t0; + } + if (t1 < 15 && ctr < len) { + t1 = t1 - (205 * t1 >> 10) * 5; + a[ctr++] = 2 - (int32_t)t1; + } + } else { + if (t0 < 9) a[ctr++] = 4 - (int32_t)t0; + if (t1 < 9 && ctr < len) a[ctr++] = 4 - (int32_t)t1; + } + } + return ctr; +} + +/* ML-DSA: Output CENTERED int32 in [-ETA, ETA] + * FIPS 204 Algorithm 15 (CoeffFromHalfByte): + * ETA==2: accept b<15, coeff = 2 - (b mod 5) + * ETA==4: accept b<9, coeff = 4 - b */ +static __device__ unsigned int rej_eta_val(int32_t *a, unsigned int len, + const uint8_t *buf, unsigned int buflen) { + unsigned int ctr = 0, pos = 0; + while (ctr < len && pos < buflen) { + uint32_t t0 = buf[pos] & 0x0F; + uint32_t t1 = buf[pos++] >> 4; +#if PARAM_ETA_S1 == 2 + if (t0 < 15)
{ + t0 = t0 - (205*t0 >> 10)*5; + a[ctr++] = 2 - (int32_t)t0; + } + if (t1 < 15 && ctr < len) { + t1 = t1 - (205*t1 >> 10)*5; + a[ctr++] = 2 - (int32_t)t1; + } +#elif PARAM_ETA_S1 == 4 + if (t0 < 9) a[ctr++] = 4 - (int32_t)t0; + if (t1 < 9 && ctr < len) a[ctr++] = 4 - (int32_t)t1; +#endif + } + return ctr; +} +/* Dispatch macros for unified poly_uniform_eta. NOTE(review): rej_eta2 also expands to rej_eta_val, whose acceptance rule is selected by PARAM_ETA_S1 rather than PARAM_ETA_S2 — this relies on the two being equal for ML-DSA; confirm in params.h */ +#define rej_eta1(a, len, buf, buflen) rej_eta_val(a, len, buf, buflen) +#define rej_eta2(a, len, buf, buflen) rej_eta_val(a, len, buf, buflen) + +#elif ALGORITHM == ALGO_AIGIS +/* Aigis: Output UNSIGNED Q+ETA-t + * ETA1=1: 2-bit extraction (4 values per byte) + * ETA1=2: 3-bit extraction (8 values per 3 bytes) — matches CPU reference + * ETA1=3: 3-bit extraction (same structure) */ +static __device__ unsigned int rej_eta1_aigis(int32_t *a, unsigned int len, + const uint8_t *buf, unsigned int buflen) { /* accepts a field t when t <= 2*PARAM_ETA_S1 and stores PARAM_Q + PARAM_ETA_S1 - t; returns the number of coefficients written */ + unsigned int ctr = 0, pos = 0; +#if PARAM_ETA_S1 == 1 + while (ctr < len && pos < buflen) { + uint32_t t0 = buf[pos] & 0x03; + uint32_t t1 = (buf[pos] >> 2) & 0x03; + uint32_t t2 = (buf[pos] >> 4) & 0x03; + uint32_t t3 = (buf[pos++] >> 6) & 0x03; + if (t0 <= 2u * PARAM_ETA_S1) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t1; + if (t2 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t2; + if (t3 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t3; + } +#elif PARAM_ETA_S1 == 2 || PARAM_ETA_S1 == 3 + /* 3-bit extraction: 8 values from every 3 bytes */ + while (ctr < len && pos + 3 <= buflen) { + uint32_t t0 = buf[pos] & 0x07; + uint32_t t1 = (buf[pos] >> 3) & 0x07; + uint32_t t2 = (buf[pos] >> 6) | ((uint32_t)(buf[pos + 1] & 0x01) << 2); + uint32_t t3 = (buf[pos + 1] >> 1) & 0x07; + uint32_t t4 = (buf[pos + 1] >> 4) & 0x07; + uint32_t t5 = (buf[pos + 1] >> 7) | ((uint32_t)(buf[pos + 2] & 0x03) << 1); + uint32_t t6 = (buf[pos + 2] >> 2) & 0x07; + uint32_t t7 =
buf[pos + 2] >> 5; + pos += 3; + if (t0 <= 2u * PARAM_ETA_S1) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t1; + if (t2 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t2; + if (t3 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t3; + if (t4 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t4; + if (t5 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t5; + if (t6 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t6; + if (t7 <= 2u * PARAM_ETA_S1 && ctr < len) a[ctr++] = PARAM_Q + PARAM_ETA_S1 - (int32_t)t7; + } +#endif + return ctr; +} +/* rej_eta2_aigis: exact mirror of CPU rej_eta2() — two do-while loops, returns pos (byte position); no buffer length is passed, so reads of buf are not bounds-checked here */ +static __device__ unsigned int rej_eta2_aigis(int32_t *a, unsigned int len, + const uint8_t *buf) { + unsigned int ctr = 0, pos = 0; + uint8_t t0, t1; + + /* Fast loop: no ctr check on t1 */ + do { +#if PARAM_ETA_S2 == 3 + t0 = buf[pos] & 0x07; + t1 = buf[pos++] >> 5; +#else + t0 = buf[pos] & 0x0F; + t1 = buf[pos++] >> 4; +#endif + if (t0 <= 2u * PARAM_ETA_S2) + a[ctr++] = PARAM_Q + PARAM_ETA_S2 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S2) + a[ctr++] = PARAM_Q + PARAM_ETA_S2 - (int32_t)t1; + } while (ctr < len - 2); /* stops with at most len-1 coefficients written, so the unchecked t1 store above stays in range */ + + /* Slow loop: ctr check on t1 */ + do { +#if PARAM_ETA_S2 == 3 + t0 = buf[pos] & 0x07; + t1 = buf[pos++] >> 5; +#else + t0 = buf[pos] & 0x0F; + t1 = buf[pos++] >> 4; +#endif + if (t0 <= 2u * PARAM_ETA_S2) + a[ctr++] = PARAM_Q + PARAM_ETA_S2 - (int32_t)t0; + if (t1 <= 2u * PARAM_ETA_S2 && ctr < len) + a[ctr++] = PARAM_Q + PARAM_ETA_S2 - (int32_t)t1; + } while (ctr < len); + + return pos; +} +/* Dispatch macros for unified poly_uniform_eta */ +#define rej_eta1(a, len, buf, buflen) rej_eta1_aigis(a, len, buf, buflen) +#define rej_eta2(a, len, buf) rej_eta2_aigis(a, len,
buf) +#endif /* rej_eta */ + +/* Block count macros: match CPU reference (FIPS 204) */ +#if ALGORITHM == ALGO_MLDSA +#if PARAM_ETA_S1 == 2 +#define POLY_UNIFORM_ETA1_NBLOCKS ((136 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) +#elif PARAM_ETA_S1 == 4 +#define POLY_UNIFORM_ETA1_NBLOCKS ((227 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) +#endif +#if PARAM_ETA_S2 == 2 +#define POLY_UNIFORM_ETA2_NBLOCKS ((136 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) +#elif PARAM_ETA_S2 == 4 +#define POLY_UNIFORM_ETA2_NBLOCKS ((227 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) +#endif +#elif ALGORITHM == ALGO_AIGIS +#define POLY_UNIFORM_ETA1_NBLOCKS 2 +#define POLY_UNIFORM_ETA2_NBLOCKS 3 +#endif + +/* Unified poly_uniform_eta_s1: shared skeleton; only the stream init and the rej function are selected per algorithm */ +static __device__ __noinline__ void poly_uniform_eta_s1(poly *a, + const uint8_t *seed, + uint16_t nonce) { + uint8_t buf[POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; +#if ALGORITHM == ALGO_MLDSA + stream256_init(&state, seed, nonce); +#elif ALGORITHM == ALGO_AIGIS + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); +#endif + stream256_squeezeblocks(buf, POLY_UNIFORM_ETA1_NBLOCKS, &state); + unsigned int ctr = rej_eta1(a->coeffs, PARAM_N, buf, + POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES); + while (ctr < PARAM_N) { /* not enough accepted values yet: squeeze one more block and continue at a->coeffs + ctr */ + stream256_squeezeblocks(buf, 1, &state); + ctr += rej_eta1(a->coeffs + ctr, PARAM_N - ctr, buf, STREAM256_BLOCKBYTES); + } +} + +static __device__ __noinline__ void poly_uniform_eta_s1_to(coeff_t *a, + const uint8_t *seed, + uint16_t nonce) { /* same sampling as poly_uniform_eta_s1, writing into a raw coefficient array instead of a poly */ + uint8_t buf[POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; +#if ALGORITHM == ALGO_MLDSA + stream256_init(&state, seed, nonce); +#elif ALGORITHM == ALGO_AIGIS + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); +#endif + stream256_squeezeblocks(buf, POLY_UNIFORM_ETA1_NBLOCKS, &state); +#if ALGORITHM == ALGO_MLDSA + unsigned int ctr = rej_eta_mldsa_to(a, PARAM_N, buf, +
POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES, + PARAM_ETA_S1); + while (ctr < PARAM_N) { + stream256_squeezeblocks(buf, 1, &state); + ctr += rej_eta_mldsa_to(a + ctr, PARAM_N - ctr, buf, + STREAM256_BLOCKBYTES, PARAM_ETA_S1); + } +#else + unsigned int ctr = rej_eta1(a, PARAM_N, buf, + POLY_UNIFORM_ETA1_NBLOCKS * STREAM256_BLOCKBYTES); + while (ctr < PARAM_N) { + stream256_squeezeblocks(buf, 1, &state); + ctr += rej_eta1(a + ctr, PARAM_N - ctr, buf, STREAM256_BLOCKBYTES); + } +#endif +} + +static __device__ __noinline__ void poly_uniform_eta_s2(poly *a, + const uint8_t *seed, + uint16_t nonce) { /* ML-DSA: block-wise rejection loop; Aigis: fixed two-block squeeze followed by rej_eta2, which takes no buffer length */ + uint8_t buf[POLY_UNIFORM_ETA2_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; +#if ALGORITHM == ALGO_MLDSA + stream256_init(&state, seed, nonce); + stream256_squeezeblocks(buf, POLY_UNIFORM_ETA2_NBLOCKS, &state); + unsigned int ctr = rej_eta2(a->coeffs, PARAM_N, buf, + POLY_UNIFORM_ETA2_NBLOCKS * STREAM256_BLOCKBYTES); + while (ctr < PARAM_N) { + stream256_squeezeblocks(buf, 1, &state); + ctr += rej_eta2(a->coeffs + ctr, PARAM_N - ctr, buf, STREAM256_BLOCKBYTES); + } +#elif ALGORITHM == ALGO_AIGIS + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); + stream256_squeezeblocks(buf, 2, &state); /* only the first two blocks of buf are filled here; the ETA2=5 branch squeezes a third on demand */ + +#if PARAM_ETA_S2 == 3 + /* ETA2=3: single pass, probability of needing >2 blocks is < 2^{-378} */ + rej_eta2(a->coeffs, PARAM_N, buf); + +#elif PARAM_ETA_S2 == 5 + /* ETA2=5: two-pass split at 223 — exactly mirrors CPU poly_uniform_eta2() */ + { + unsigned int pos = rej_eta2(a->coeffs, 223, buf); /* pos = number of bytes consumed by the first pass */ + + if (2u * STREAM256_BLOCKBYTES - pos < 85) { /* fewer than 85 unread bytes remain in the first two blocks: append a third block */ + stream256_squeezeblocks(buf + 2 * STREAM256_BLOCKBYTES, 1, &state); + } + + rej_eta2(&a->coeffs[223], 33, &buf[pos]); + } +#endif +#endif +} + +static __device__ __noinline__ void poly_uniform_eta_s2_to(coeff_t *a, + const uint8_t *seed, + uint16_t nonce) { + uint8_t buf[POLY_UNIFORM_ETA2_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; +#if ALGORITHM == ALGO_MLDSA + stream256_init(&state, seed, nonce); +
stream256_squeezeblocks(buf, POLY_UNIFORM_ETA2_NBLOCKS, &state); + unsigned int ctr = rej_eta_mldsa_to(a, PARAM_N, buf, + POLY_UNIFORM_ETA2_NBLOCKS * STREAM256_BLOCKBYTES, + PARAM_ETA_S2); + while (ctr < PARAM_N) { + stream256_squeezeblocks(buf, 1, &state); + ctr += rej_eta_mldsa_to(a + ctr, PARAM_N - ctr, buf, + STREAM256_BLOCKBYTES, PARAM_ETA_S2); + } +#elif ALGORITHM == ALGO_AIGIS + aigis_shake256_eta_init(&state, seed, (uint8_t)nonce); + stream256_squeezeblocks(buf, 2, &state); /* only the first two blocks of buf are filled here */ + +#if PARAM_ETA_S2 == 3 + rej_eta2(a, PARAM_N, buf); +#elif PARAM_ETA_S2 == 5 + { + unsigned int pos = rej_eta2(a, 223, buf); /* pos = number of bytes consumed by the first pass */ + + if (2u * STREAM256_BLOCKBYTES - pos < 85) { /* fewer than 85 unread bytes remain: append a third block */ + stream256_squeezeblocks(buf + 2 * STREAM256_BLOCKBYTES, 1, &state); + } + + rej_eta2(&a[223], 33, &buf[pos]); + } +#endif +#endif +} + +/* ================================================================ + * gamma1 uniform mask vector y + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA +/* ML-DSA: deterministic unpack from SHAKE stream (GAMMA1-coeff encoding) */ +#define POLY_UNIFORM_GAMMA1_NBLOCKS \ + ((POLYZ_PACKEDBYTES + STREAM256_BLOCKBYTES - 1) / STREAM256_BLOCKBYTES) + +static __device__ void polyz_unpack(poly *r, const uint8_t *a); /* forward decl */ + +static __device__ void poly_uniform_gamma1(poly *a, const uint8_t seed[CRHBYTES], uint16_t nonce) { /* squeezes enough whole blocks to cover POLYZ_PACKEDBYTES, then decodes them with polyz_unpack() */ + uint8_t buf[POLY_UNIFORM_GAMMA1_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; + stream256_init(&state, seed, nonce); + stream256_squeezeblocks(buf, POLY_UNIFORM_GAMMA1_NBLOCKS, &state); + polyz_unpack(a, buf); +} + +#elif ALGORITHM == ALGO_AIGIS +/* Aigis: rejection sampling, output Q+GAMMA1-1-t, seed = key||hash (SEEDBYTES+CRHBYTES) */ +#define POLY_UNIFORM_GAMMA1_NBLOCKS 5 /* 5 SHAKE256 blocks is conservative */ + +static __device__ __noinline__ void poly_uniform_gamma1(poly *a, + const uint8_t seed[SEEDBYTES + CRHBYTES], + uint16_t nonce) { + unsigned int ctr = 0, pos = 0; + uint32_t t0, t1; +
uint8_t buf[POLY_UNIFORM_GAMMA1_NBLOCKS * STREAM256_BLOCKBYTES]; + stream256_state state; + aigis_shake256_gamma1_init(&state, seed, nonce); + stream256_squeezeblocks(buf, POLY_UNIFORM_GAMMA1_NBLOCKS, &state); + + while (ctr < PARAM_N) { /* each iteration consumes 5 bytes and yields two 18-bit candidates */ + if (pos + 5 > POLY_UNIFORM_GAMMA1_NBLOCKS * STREAM256_BLOCKBYTES) { + /* Squeeze more blocks if needed (very rare). The whole buffer is refilled because the bound above spans all POLY_UNIFORM_GAMMA1_NBLOCKS blocks; refilling a single block would make later iterations re-read bytes that were already consumed. */ + stream256_squeezeblocks(buf, POLY_UNIFORM_GAMMA1_NBLOCKS, &state); + pos = 0; + } + t0 = buf[pos]; + t0 |= (uint32_t)buf[pos + 1] << 8; + t0 |= (uint32_t)buf[pos + 2] << 16; + t1 = buf[pos + 2] >> 4; + t1 |= (uint32_t)buf[pos + 3] << 4; + t1 |= (uint32_t)buf[pos + 4] << 12; + t0 &= 0x3FFFF; + t1 &= 0x3FFFF; + pos += 5; + if (t0 <= 2u * (uint32_t)PARAM_GAMMA1) + a->coeffs[ctr++] = PARAM_Q + PARAM_GAMMA1 - 1 - (int32_t)t0; + if (t1 <= 2u * (uint32_t)PARAM_GAMMA1 && ctr < PARAM_N) + a->coeffs[ctr++] = PARAM_Q + PARAM_GAMMA1 - 1 - (int32_t)t1; + } +} +#endif + +/* ================================================================ + * Challenge polynomial + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA +/* ML-DSA: absorb CTILDEBYTES seed, range N-TAU..N, coeffs {-1,+1} */ +static __device__ __noinline__ void poly_challenge(poly *c, const uint8_t seed[CTILDEBYTES]) { + unsigned int i, b, pos; + uint64_t signs; + uint8_t buf[SHAKE256_RATE]; + keccak_state state; + shake256_init(&state); + shake256_absorb(&state, seed, CTILDEBYTES); + shake256_finalize(&state); + shake256_squeezeblocks(buf, 1, &state); + + signs = 0; + for (i = 0; i < 8; ++i) signs |= (uint64_t)buf[i] << 8 * i; /* first 8 squeezed bytes form the little-endian sign word */ + pos = 8; + for (i = 0; i < PARAM_N; ++i) c->coeffs[i] = 0; + for (i = PARAM_N - PARAM_TAU; i < PARAM_N; ++i) { + do { + if (pos >= SHAKE256_RATE) { shake256_squeezeblocks(buf, 1, &state); pos = 0; } + b = buf[pos++]; + } while (b > i); /* rejection: redraw until b <= i */ + c->coeffs[i] = c->coeffs[b]; + c->coeffs[b] = 1 - 2 * (int32_t)(signs & 1); /* +1 or -1 from the next sign bit */ + signs >>= 1; + } +} + +#elif ALGORITHM == ALGO_AIGIS +/* Aigis: absorb mu(CRHBYTES) + packed_w1, range
196..255 (60 coeffs), coeffs {1, Q-1} */ +static __device__ __noinline__ void poly_challenge(poly *c, + const uint8_t mu[CRHBYTES], + const uint8_t *packed_w1, + unsigned int w1_len) { + unsigned int i, b, pos; + uint64_t signs, mask; + uint8_t buf[SHAKE256_RATE]; + keccak_state state; + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, packed_w1, w1_len); + shake256_finalize(&state); + shake256_squeezeblocks(buf, 1, &state); + + signs = 0; + for (i = 0; i < 8; ++i) signs |= (uint64_t)buf[i] << 8 * i; /* first 8 squeezed bytes form the little-endian sign word */ + pos = 8; + mask = 1; + for (i = 0; i < PARAM_N; ++i) c->coeffs[i] = 0; + /* Aigis: indices 196..255 (fixed TAU=60 non-zero positions) */ + for (i = 196; i < 256; ++i) { + do { + if (pos >= SHAKE256_RATE) { shake256_squeezeblocks(buf, 1, &state); pos = 0; } + b = buf[pos++]; + } while (b > i); /* rejection: redraw until b <= i */ + c->coeffs[i] = c->coeffs[b]; + c->coeffs[b] = (signs & mask) ? PARAM_Q - 1 : 1; /* a set sign bit stores PARAM_Q - 1, a clear bit stores 1 */ + mask <<= 1; + } +} +#endif + +/* ================================================================ + * Packing — COEFF_BIAS removes the per-algorithm fork for eta/t0 + * ML-DSA (COEFF_BIAS=0): pack as (ETA - coeff) + * Aigis (COEFF_BIAS=Q): pack as (Q + ETA - coeff) + * ================================================================ */ + +/* polyeta_s1: SETA1BITS bits per coeff */ +static __device__ void polyeta_s1_pack(uint8_t *r, const poly *a) { + unsigned int i; uint8_t t[8]; +#if SETA1BITS == 2 + for (i = 0; i < PARAM_N / 4; ++i) { + t[0]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[4*i+0]); t[1]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[4*i+1]); + t[2]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[4*i+2]); t[3]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[4*i+3]); + r[i] = t[0] | (t[1]<<2) | (t[2]<<4) | (t[3]<<6); + } +#elif SETA1BITS == 3 + for (i = 0; i < PARAM_N / 8; ++i) { + t[0]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+0]); t[1]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+1]); + t[2]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+2]);
t[3]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+3]); + t[4]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+4]); t[5]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+5]); + t[6]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+6]); t[7]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S1-a->coeffs[8*i+7]); + r[3*i+0] = t[0] | (t[1]<<3) | (t[2]<<6); + r[3*i+1] = (t[2]>>2) | (t[3]<<1) | (t[4]<<4) | (t[5]<<7); + r[3*i+2] = (t[5]>>1) | (t[6]<<2) | (t[7]<<5); + } +#elif SETA1BITS == 4 + for (i = 0; i < PARAM_N / 2; ++i) { /* NOTE(review): this branch omits COEFF_BIAS, unlike the 2- and 3-bit branches; equivalent only when COEFF_BIAS is 0 — confirm 4-bit packing is ML-DSA-only */ + t[0]=(uint8_t)(PARAM_ETA_S1-a->coeffs[2*i+0]); t[1]=(uint8_t)(PARAM_ETA_S1-a->coeffs[2*i+1]); + r[i] = t[0] | (t[1]<<4); + } +#endif +} + +static __device__ void polyeta_s1_unpack(poly *r, const uint8_t *a) { /* inverse of polyeta_s1_pack: coeff = COEFF_BIAS + PARAM_ETA_S1 - field (the 4-bit branch has no COEFF_BIAS term) */ + unsigned int i; +#if SETA1BITS == 2 + for (i = 0; i < PARAM_N / 4; ++i) { + r->coeffs[4*i+0] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)(a[i] & 0x03); + r->coeffs[4*i+1] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[i]>>2) & 0x03); + r->coeffs[4*i+2] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[i]>>4) & 0x03); + r->coeffs[4*i+3] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[i]>>6) & 0x03); + } +#elif SETA1BITS == 3 + for (i = 0; i < PARAM_N / 8; ++i) { + r->coeffs[8*i+0] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)(a[3*i+0] & 0x07); + r->coeffs[8*i+1] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[3*i+0]>>3) & 0x07); + r->coeffs[8*i+2] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)(((a[3*i+0]>>6)|(a[3*i+1]<<2)) & 0x07); + r->coeffs[8*i+3] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[3*i+1]>>1) & 0x07); + r->coeffs[8*i+4] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[3*i+1]>>4) & 0x07); + r->coeffs[8*i+5] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)(((a[3*i+1]>>7)|(a[3*i+2]<<1)) & 0x07); + r->coeffs[8*i+6] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[3*i+2]>>2) & 0x07); + r->coeffs[8*i+7] = COEFF_BIAS + PARAM_ETA_S1 - (int32_t)((a[3*i+2]>>5) & 0x07); + } +#elif SETA1BITS == 4 + for (i = 0; i < PARAM_N / 2; ++i) { + r->coeffs[2*i+0] = PARAM_ETA_S1 - (int32_t)(a[i] & 0x0F); + r->coeffs[2*i+1] =
PARAM_ETA_S1 - (int32_t)(a[i] >> 4); + } +#endif +} + +/* polyeta_s2: SETA2BITS bits per coeff */ +static __device__ void polyeta_s2_pack(uint8_t *r, const poly *a) { /* packs PARAM_N coefficients as SETA2BITS-bit fields holding (COEFF_BIAS + PARAM_ETA_S2 - coeff) */ + unsigned int i; uint8_t t[8]; +#if SETA2BITS == 3 + for (i = 0; i < PARAM_N / 8; ++i) { /* 8 coefficients -> 3 bytes */ + t[0]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+0]); t[1]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+1]); + t[2]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+2]); t[3]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+3]); + t[4]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+4]); t[5]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+5]); + t[6]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+6]); t[7]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[8*i+7]); + r[3*i+0] = t[0] | (t[1]<<3) | (t[2]<<6); + r[3*i+1] = (t[2]>>2) | (t[3]<<1) | (t[4]<<4) | (t[5]<<7); + r[3*i+2] = (t[5]>>1) | (t[6]<<2) | (t[7]<<5); + } +#elif SETA2BITS == 4 + for (i = 0; i < PARAM_N / 2; ++i) { /* 2 coefficients -> 1 byte */ + t[0]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[2*i+0]); t[1]=(uint8_t)(COEFF_BIAS+PARAM_ETA_S2-a->coeffs[2*i+1]); + r[i] = t[0] | (t[1]<<4); + } +#endif +} + +static __device__ void polyeta_s2_unpack(poly *r, const uint8_t *a) { /* inverse of polyeta_s2_pack: coeff = COEFF_BIAS + PARAM_ETA_S2 - field */ + unsigned int i; +#if SETA2BITS == 3 + for (i = 0; i < PARAM_N / 8; ++i) { + r->coeffs[8*i+0] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)(a[3*i+0] & 0x07); + r->coeffs[8*i+1] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)((a[3*i+0]>>3) & 0x07); + r->coeffs[8*i+2] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)(((a[3*i+0]>>6)|(a[3*i+1]<<2)) & 0x07); + r->coeffs[8*i+3] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)((a[3*i+1]>>1) & 0x07); + r->coeffs[8*i+4] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)((a[3*i+1]>>4) & 0x07); + r->coeffs[8*i+5] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)(((a[3*i+1]>>7)|(a[3*i+2]<<1)) & 0x07); + r->coeffs[8*i+6] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)((a[3*i+2]>>2) & 0x07); + r->coeffs[8*i+7] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)((a[3*i+2]>>5) & 0x07); + } +#elif SETA2BITS == 4 + for (i = 0; i < PARAM_N / 2; ++i) {
+ r->coeffs[2*i+0] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)(a[i] & 0x0F); + r->coeffs[2*i+1] = COEFF_BIAS + PARAM_ETA_S2 - (int32_t)(a[i] >> 4); + } +#endif +} + +/* polyt1: POLYT1_PACKED_BITS bits per coeff (10 for ML-DSA, 8 for Aigis) */ +static __device__ void polyt1_pack(uint8_t *r, const poly *a) { /* 10-bit mode: 4 coefficients -> 5 bytes; 8-bit mode: one byte per coefficient */ +#if POLYT1_PACKED_BITS == 10 + for (unsigned int i = 0; i < PARAM_N / 4; ++i) { + r[5*i+0] = (uint8_t)(a->coeffs[4*i+0]); + r[5*i+1] = (uint8_t)((a->coeffs[4*i+0]>>8) | (a->coeffs[4*i+1]<<2)); + r[5*i+2] = (uint8_t)((a->coeffs[4*i+1]>>6) | (a->coeffs[4*i+2]<<4)); + r[5*i+3] = (uint8_t)((a->coeffs[4*i+2]>>4) | (a->coeffs[4*i+3]<<6)); + r[5*i+4] = (uint8_t)(a->coeffs[4*i+3]>>2); + } +#elif POLYT1_PACKED_BITS == 8 + for (unsigned int i = 0; i < PARAM_N; ++i) r[i] = (uint8_t)a->coeffs[i]; +#endif +} + +static __device__ void polyt1_unpack(poly *r, const uint8_t *a) { /* inverse of polyt1_pack; the 10-bit fields are masked with 0x3FF */ +#if POLYT1_PACKED_BITS == 10 + for (unsigned int i = 0; i < PARAM_N / 4; ++i) { + r->coeffs[4*i+0] = ((uint32_t)a[5*i+0] | ((uint32_t)a[5*i+1]<<8)) & 0x3FF; + r->coeffs[4*i+1] = (((uint32_t)a[5*i+1]>>2) | ((uint32_t)a[5*i+2]<<6)) & 0x3FF; + r->coeffs[4*i+2] = (((uint32_t)a[5*i+2]>>4) | ((uint32_t)a[5*i+3]<<4)) & 0x3FF; + r->coeffs[4*i+3] = (((uint32_t)a[5*i+3]>>6) | ((uint32_t)a[5*i+4]<<2)) & 0x3FF; + } +#elif POLYT1_PACKED_BITS == 8 + for (unsigned int i = 0; i < PARAM_N; ++i) r->coeffs[i] = a[i]; +#endif +} + +/* polyt0: D bits per coeff, unified with COEFF_BIAS + * ML-DSA (COEFF_BIAS=0): stored as (2^{D-1} - coeff) + * Aigis (COEFF_BIAS=Q): stored as (Q + 2^{D-1} - coeff) */ +static __device__ void polyt0_pack(uint8_t *r, const poly *a) { + unsigned int i; uint32_t t[8]; +#if PARAM_D == 13 + for (i = 0; i < PARAM_N / 8; ++i) { /* 8 coefficients -> 13 bytes */ + for (int j = 0; j < 8; j++) t[j] = COEFF_BIAS + (1 << (PARAM_D-1)) - a->coeffs[8*i+j]; + r[13*i+ 0] = t[0]; r[13*i+ 1] = t[0]>> 8; + r[13*i+ 1] |= t[1]<< 5; r[13*i+ 2] = t[1]>> 3; + r[13*i+ 3] = t[1]>>11; r[13*i+ 3] |= t[2]<< 2; + r[13*i+ 4] = t[2]>> 6; r[13*i+ 4] |= t[3]<< 7; +
r[13*i+ 5] = t[3]>> 1; r[13*i+ 6] = t[3]>> 9; + r[13*i+ 6] |= t[4]<< 4; r[13*i+ 7] = t[4]>> 4; + r[13*i+ 8] = t[4]>>12; r[13*i+ 8] |= t[5]<< 1; + r[13*i+ 9] = t[5]>> 7; r[13*i+ 9] |= t[6]<< 6; + r[13*i+10] = t[6]>> 2; r[13*i+11] = t[6]>>10; + r[13*i+11] |= t[7]<< 3; r[13*i+12] = t[7]>> 5; + } +#elif PARAM_D == 14 + for (i = 0; i < PARAM_N / 4; ++i) { /* 4 coefficients -> 7 bytes */ + for (int j = 0; j < 4; j++) t[j] = COEFF_BIAS + (1 << (PARAM_D-1)) - a->coeffs[4*i+j]; + r[7*i+0] = t[0]; r[7*i+1] = t[0]>> 8; + r[7*i+1] |= t[1]<< 6; r[7*i+2] = t[1]>> 2; + r[7*i+3] = t[1]>>10; r[7*i+3] |= t[2]<< 4; + r[7*i+4] = t[2]>> 4; r[7*i+5] = t[2]>>12; + r[7*i+5] |= t[3]<< 2; r[7*i+6] = t[3]>> 6; + } +#endif +} + +static __device__ void polyt0_unpack(poly *r, const uint8_t *a) { /* inverse of polyt0_pack: extracts each D-bit field, then maps it to COEFF_BIAS + 2^(PARAM_D-1) - field */ + unsigned int i; +#if PARAM_D == 13 + for (i = 0; i < PARAM_N / 8; ++i) { + r->coeffs[8*i+0] = (uint32_t)a[13*i+0] | ((uint32_t)a[13*i+1]<<8); r->coeffs[8*i+0] &= 0x1FFF; + r->coeffs[8*i+1] = (uint32_t)a[13*i+1]>>5 | ((uint32_t)a[13*i+2]<<3) | ((uint32_t)a[13*i+3]<<11); r->coeffs[8*i+1] &= 0x1FFF; + r->coeffs[8*i+2] = (uint32_t)a[13*i+3]>>2 | ((uint32_t)a[13*i+4]<<6); r->coeffs[8*i+2] &= 0x1FFF; + r->coeffs[8*i+3] = (uint32_t)a[13*i+4]>>7 | ((uint32_t)a[13*i+5]<<1) | ((uint32_t)a[13*i+6]<<9); r->coeffs[8*i+3] &= 0x1FFF; + r->coeffs[8*i+4] = (uint32_t)a[13*i+6]>>4 | ((uint32_t)a[13*i+7]<<4) | ((uint32_t)a[13*i+8]<<12); r->coeffs[8*i+4] &= 0x1FFF; + r->coeffs[8*i+5] = (uint32_t)a[13*i+8]>>1 | ((uint32_t)a[13*i+9]<<7); r->coeffs[8*i+5] &= 0x1FFF; + r->coeffs[8*i+6] = (uint32_t)a[13*i+9]>>6 | ((uint32_t)a[13*i+10]<<2) | ((uint32_t)a[13*i+11]<<10); r->coeffs[8*i+6] &= 0x1FFF; + r->coeffs[8*i+7] = (uint32_t)a[13*i+11]>>3 | ((uint32_t)a[13*i+12]<<5); r->coeffs[8*i+7] &= 0x1FFF; + for (int j = 0; j < 8; j++) r->coeffs[8*i+j] = COEFF_BIAS + (1 << (PARAM_D-1)) - (int32_t)r->coeffs[8*i+j]; + } +#elif PARAM_D == 14 + for (i = 0; i < PARAM_N / 4; ++i) { + r->coeffs[4*i+0] = (uint32_t)a[7*i+0] | (((uint32_t)a[7*i+1]&0x3F)<<8); r->coeffs[4*i+0] &=
0x3FFF; + r->coeffs[4*i+1] = (uint32_t)a[7*i+1]>>6 | ((uint32_t)a[7*i+2]<<2) | (((uint32_t)a[7*i+3]&0x0F)<<10); r->coeffs[4*i+1] &= 0x3FFF; + r->coeffs[4*i+2] = (uint32_t)a[7*i+3]>>4 | ((uint32_t)a[7*i+4]<<4) | (((uint32_t)a[7*i+5]&0x03)<<12); r->coeffs[4*i+2] &= 0x3FFF; + r->coeffs[4*i+3] = (uint32_t)a[7*i+5]>>2 | ((uint32_t)a[7*i+6]<<6); r->coeffs[4*i+3] &= 0x3FFF; + for (int j = 0; j < 4; j++) r->coeffs[4*i+j] = COEFF_BIAS + (1 << (PARAM_D-1)) - (int32_t)r->coeffs[4*i+j]; + } +#endif +} + +/* polyz: unified with Z_BIAS/Z_FIXUP + * ML-DSA (Z_BIAS=GAMMA1): t = GAMMA1 - coeff + * Aigis (Z_BIAS=GAMMA1-1): t = GAMMA1-1 - coeff; negative values get +Q */ +static __device__ void polyz_pack(uint8_t *r, const poly *a) { /* GAMMA1 = 2^17: 4 coefficients -> 9 bytes; GAMMA1 = 2^19: 2 coefficients -> 5 bytes */ + unsigned int i; uint32_t t[4]; +#if PARAM_GAMMA1 == (1 << 17) + for (i = 0; i < PARAM_N / 4; ++i) { + t[0]=Z_BIAS-a->coeffs[4*i+0]; Z_FIXUP(t[0]); + t[1]=Z_BIAS-a->coeffs[4*i+1]; Z_FIXUP(t[1]); + t[2]=Z_BIAS-a->coeffs[4*i+2]; Z_FIXUP(t[2]); + t[3]=Z_BIAS-a->coeffs[4*i+3]; Z_FIXUP(t[3]); + r[9*i+0]=t[0]; r[9*i+1]=t[0]>>8; + r[9*i+2]=t[0]>>16; r[9*i+2]|=t[1]<<2; + r[9*i+3]=t[1]>>6; r[9*i+4]=t[1]>>14; + r[9*i+4]|=t[2]<<4; r[9*i+5]=t[2]>>4; + r[9*i+6]=t[2]>>12; r[9*i+6]|=t[3]<<6; + r[9*i+7]=t[3]>>2; r[9*i+8]=t[3]>>10; + } +#elif PARAM_GAMMA1 == (1 << 19) + for (i = 0; i < PARAM_N / 2; ++i) { + t[0]=Z_BIAS-a->coeffs[2*i+0]; Z_FIXUP(t[0]); + t[1]=Z_BIAS-a->coeffs[2*i+1]; Z_FIXUP(t[1]); + r[5*i+0]=t[0]; r[5*i+1]=t[0]>>8; + r[5*i+2]=t[0]>>16; r[5*i+2]|=t[1]<<4; + r[5*i+3]=t[1]>>4; r[5*i+4]=t[1]>>12; + } +#endif +} + +static __device__ void polyz_unpack(poly *r, const uint8_t *a) { /* inverse of polyz_pack: extracts the 18- or 20-bit field, then applies Z_BIAS - field and Z_FIXUP */ + unsigned int i; +#if PARAM_GAMMA1 == (1 << 17) + for (i = 0; i < PARAM_N / 4; ++i) { + r->coeffs[4*i+0]=((uint32_t)a[9*i+0]|((uint32_t)a[9*i+1]<<8)|((uint32_t)a[9*i+2]<<16))&0x3FFFF; + r->coeffs[4*i+1]=(((uint32_t)a[9*i+2]>>2)|((uint32_t)a[9*i+3]<<6)|((uint32_t)a[9*i+4]<<14))&0x3FFFF; + r->coeffs[4*i+2]=(((uint32_t)a[9*i+4]>>4)|((uint32_t)a[9*i+5]<<4)|((uint32_t)a[9*i+6]<<12))&0x3FFFF; +
r->coeffs[4*i+3]=(((uint32_t)a[9*i+6]>>6)|((uint32_t)a[9*i+7]<<2)|((uint32_t)a[9*i+8]<<10))&0x3FFFF; + r->coeffs[4*i+0]=Z_BIAS-(int32_t)r->coeffs[4*i+0]; Z_FIXUP(r->coeffs[4*i+0]); + r->coeffs[4*i+1]=Z_BIAS-(int32_t)r->coeffs[4*i+1]; Z_FIXUP(r->coeffs[4*i+1]); + r->coeffs[4*i+2]=Z_BIAS-(int32_t)r->coeffs[4*i+2]; Z_FIXUP(r->coeffs[4*i+2]); + r->coeffs[4*i+3]=Z_BIAS-(int32_t)r->coeffs[4*i+3]; Z_FIXUP(r->coeffs[4*i+3]); + } +#elif PARAM_GAMMA1 == (1 << 19) + for (i = 0; i < PARAM_N / 2; ++i) { + r->coeffs[2*i+0]=((uint32_t)a[5*i+0]|((uint32_t)a[5*i+1]<<8)|((uint32_t)a[5*i+2]<<16))&0xFFFFF; + r->coeffs[2*i+1]=(((uint32_t)a[5*i+2]>>4)|((uint32_t)a[5*i+3]<<4)|((uint32_t)a[5*i+4]<<12))&0xFFFFF; + r->coeffs[2*i+0]=Z_BIAS-(int32_t)r->coeffs[2*i+0]; Z_FIXUP(r->coeffs[2*i+0]); + r->coeffs[2*i+1]=Z_BIAS-(int32_t)r->coeffs[2*i+1]; Z_FIXUP(r->coeffs[2*i+1]); + } +#endif +} + +/* polyw1: encode coeff in [0, N_W1) */ +static __device__ void polyw1_pack(uint8_t *r, const poly *a) { /* the field width is selected at compile time from PARAM_GAMMA2; coefficients are OR-ed together without masking */ + unsigned int i; +#if PARAM_GAMMA2 == (PARAM_Q - 1) / 88 /* N_W1=44: 6 bits, 4 per 3 bytes */ + for (i = 0; i < PARAM_N / 4; ++i) { + r[3*i+0] = a->coeffs[4*i+0] | (a->coeffs[4*i+1]<<6); + r[3*i+1] = (a->coeffs[4*i+1]>>2) | (a->coeffs[4*i+2]<<4); + r[3*i+2] = (a->coeffs[4*i+2]>>4) | (a->coeffs[4*i+3]<<2); + } +#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 32 /* N_W1=16: 4 bits (nibble), 2 per byte */ + for (i = 0; i < PARAM_N / 2; ++i) + r[i] = (uint8_t)(a->coeffs[2*i+0] | (a->coeffs[2*i+1]<<4)); +#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 12 /* N_W1=6: 3 bits, 8 per 3 bytes */ + for (i = 0; i < PARAM_N / 8; ++i) { + r[3*i+0] = a->coeffs[8*i+0] | (a->coeffs[8*i+1]<<3) | (a->coeffs[8*i+2]<<6); + r[3*i+1] = (a->coeffs[8*i+2]>>2) | (a->coeffs[8*i+3]<<1) | (a->coeffs[8*i+4]<<4) | (a->coeffs[8*i+5]<<7); + r[3*i+2] = (a->coeffs[8*i+5]>>1) | (a->coeffs[8*i+6]<<2) | (a->coeffs[8*i+7]<<5); + } +#endif +} + +#endif /* POLY_CUH */ diff --git
a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/polyvec.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/polyvec.cuh new file mode 100644 index 000000000..95449fc40 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/polyvec.cuh @@ -0,0 +1,155 @@ +#ifndef POLYVEC_CUH +#define POLYVEC_CUH + +#include "params.h" +#include "poly.cuh" + +/* Vectors of length L (s1, y, z) */ +typedef struct { poly vec[PARAM_L]; } polyvecl; +/* Vectors of length K (s2, t, w, ...) */ +typedef struct { poly vec[PARAM_K]; } polyveck; + +/* ---- polyvecl helpers: each applies the matching poly_* routine to all PARAM_L components ---- */ +static __device__ void polyvecl_add(polyvecl *w, const polyvecl *u, const polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_add(&w->vec[i], &u->vec[i], &v->vec[i]); +} +static __device__ void polyvecl_ntt(polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_ntt(&v->vec[i]); +} +static __device__ void polyvecl_invntt_tomont(polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_invntt_tomont(&v->vec[i]); +} +static __device__ void polyvecl_pointwise_poly_montgomery(polyvecl *r, const poly *a, const polyvecl *v) { /* multiplies every component of v by the single polynomial a */ + for (int i = 0; i < PARAM_L; ++i) poly_pointwise_montgomery(&r->vec[i], a, &v->vec[i]); +} +static __device__ void polyvecl_reduce(polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_reduce(&v->vec[i]); +} +static __device__ void polyvecl_freeze2q(polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_freeze2q(&v->vec[i]); +} +static __device__ void polyvecl_freeze4q(polyvecl *v) { + for (int i = 0; i < PARAM_L; ++i) poly_freeze4q(&v->vec[i]); +} +static __device__ int polyvecl_chknorm(const polyvecl *v, int32_t bound) { /* returns 1 as soon as any component fails poly_chknorm(), otherwise 0 */ + for (int i = 0; i < PARAM_L; ++i) if (poly_chknorm(&v->vec[i], bound)) return 1; + return 0; +} + +/* ---------------------------------------------------------------- */ +static __device__ void polyveck_add(polyveck *w, const polyveck *u,
const polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_add(&w->vec[i], &u->vec[i], &v->vec[i]); +} +static __device__ void polyveck_sub(polyveck *w, const polyveck *u, const polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_sub(&w->vec[i], &u->vec[i], &v->vec[i]); +} +#if ALGORITHM == ALGO_AIGIS +static __device__ void polyveck_neg(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_neg(&v->vec[i]); +} +#endif +static __device__ void polyveck_ntt(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_ntt(&v->vec[i]); +} +static __device__ void polyveck_invntt_tomont(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_invntt_tomont(&v->vec[i]); +} +static __device__ void polyveck_pointwise_poly_montgomery(polyveck *r, const poly *a, const polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_pointwise_montgomery(&r->vec[i], a, &v->vec[i]); +} +static __device__ void polyveck_reduce(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_reduce(&v->vec[i]); +} +static __device__ void polyveck_caddq(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_caddq(&v->vec[i]); +} +static __device__ void polyveck_freeze2q(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_freeze2q(&v->vec[i]); +} +static __device__ void polyveck_freeze4q(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_freeze4q(&v->vec[i]); +} +static __device__ void polyveck_shiftl(polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_shiftl(&v->vec[i]); +} +static __device__ void polyveck_power2round(polyveck *v1, polyveck *v0, const polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_power2round(&v1->vec[i], &v0->vec[i], &v->vec[i]); +} +static __device__ void polyveck_decompose(polyveck *v1, polyveck *v0, const polyveck *v) { + for (int i = 0; i < PARAM_K; ++i) poly_decompose(&v1->vec[i], &v0->vec[i], &v->vec[i]); +} +static __device__ unsigned int polyveck_make_hint(polyveck *h, const polyveck *v0, const polyveck *v1) { + unsigned int s = 0; + for (int i = 0; 
i < PARAM_K; ++i) s += poly_make_hint(&h->vec[i], &v0->vec[i], &v1->vec[i]); + return s; +} +static __device__ __noinline__ void polyveck_use_hint(polyveck *w, const polyveck *v, const polyveck *h) { + for (int i = 0; i < PARAM_K; ++i) poly_use_hint(&w->vec[i], &v->vec[i], &h->vec[i]); +} +static __device__ int polyveck_chknorm(const polyveck *v, int32_t bound) { + for (int i = 0; i < PARAM_K; ++i) if (poly_chknorm(&v->vec[i], bound)) return 1; + return 0; +} + +/* ---------------------------------------------------------------- + * Matrix-vector: w = A*v (both in NTT domain, results accumulated) + * ---------------------------------------------------------------- */ +static __device__ __noinline__ void polyveck_accumulate_matvecntt( + polyveck *w, const polyvecl row[PARAM_K], const polyvecl *v) +{ + poly t; + for (int i = 0; i < PARAM_K; ++i) { + poly_pointwise_montgomery(&w->vec[i], &row[i].vec[0], &v->vec[0]); + for (int j = 1; j < PARAM_L; ++j) { + poly_pointwise_montgomery(&t, &row[i].vec[j], &v->vec[j]); + poly_add(&w->vec[i], &w->vec[i], &t); + } +#if ALGORITHM == ALGO_AIGIS + /* Aigis: accumulated sum in (-L*Q, L*Q); reduce to [0,Q) */ + for (unsigned int c = 0; c < PARAM_N; ++c) + w->vec[i].coeffs[c] = barrat_reduce(w->vec[i].coeffs[c]); +#endif + } +} + +/* ---------------------------------------------------------------- + * Matrix expansion from rho seed (unified via MATRIX_NONCE macro) + * ---------------------------------------------------------------- */ +static __device__ __noinline__ void polyvec_matrix_expand(polyvecl mat[PARAM_K], + const uint8_t rho[SEEDBYTES]) { + for (int i = 0; i < PARAM_K; ++i) + for (int j = 0; j < PARAM_L; ++j) + poly_uniform(&mat[i].vec[j], rho, MATRIX_NONCE(i, j)); +} + +/* ---------------------------------------------------------------- + * Uniform eta sampling for s1/s2 vectors (unified signature) + * ---------------------------------------------------------------- */ +static __device__ void 
polyvecl_uniform_eta_s1(polyvecl *v, const uint8_t *seed, + uint16_t nonce) { + for (int i = 0; i < PARAM_L; ++i) poly_uniform_eta_s1(&v->vec[i], seed, nonce++); +} +static __device__ void polyveck_uniform_eta_s2(polyveck *v, const uint8_t *seed, + uint16_t nonce) { + for (int i = 0; i < PARAM_K; ++i) poly_uniform_eta_s2(&v->vec[i], seed, nonce++); +} + +/* ---------------------------------------------------------------- + * Uniform gamma1 sampling for y (unified via GAMMA1_NONCE macro) + * ---------------------------------------------------------------- */ +static __device__ void polyvecl_uniform_gamma1(polyvecl *v, const uint8_t *seed, + uint16_t nonce) { + for (int i = 0; i < PARAM_L; ++i) + poly_uniform_gamma1(&v->vec[i], seed, GAMMA1_NONCE(nonce, i)); +} + +/* ---------------------------------------------------------------- + * Pack w1 hint bitmap into flat byte array + * ---------------------------------------------------------------- */ +static __device__ __noinline__ void polyveck_pack_w1(uint8_t r[PARAM_K * POLYW1_PACKEDBYTES], + const polyveck *w1) { + for (int i = 0; i < PARAM_K; ++i) + polyw1_pack(r + i * POLYW1_PACKEDBYTES, &w1->vec[i]); +} + +#endif /* POLYVEC_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/reduce.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/reduce.cuh new file mode 100644 index 000000000..ef15cfc2c --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/reduce.cuh @@ -0,0 +1,136 @@ +/* + * reduce.cuh — 统一约化函数 + * + * 对两种算法均使用 int32_t 系数。 + * montgomery_reduce 使用无符号乘法取低32位的技巧规避有符号溢出 UB, + * 数学结果与标准 Montgomery 约化完全等价。 + * + * ML-DSA: 系数中心化 (-Q/2, Q/2], 使用 reduce32 + caddq + * Aigis: 系数无符号 [0, Q), 使用 freeze2q/freeze4q + */ + +#ifndef REDUCE_CUH +#define REDUCE_CUH + +#include +#include "params.h" + +/* + * montgomery_reduce(a) + * 输入: a ∈ (-Q*2^32, Q*2^32) + * 输出: a * R^{-1} mod Q, 结果 ∈ (-Q, Q) (R = 2^32) + * + * 
对 ML-DSA 和 Aigis 均正确。 + */ +static __device__ __forceinline__ int32_t montgomery_reduce(int64_t a) { + uint32_t t = (uint32_t)(int32_t)a * MONT_QINV; /* uint32 wraparound: defined */ + return (int32_t)((a - (int64_t)t * PARAM_Q) >> 32); +} + +/* ---- ML-DSA centered reduction ---- */ + +/* + * reduce32(a): 中心化约化至 (-Q/2, Q/2] + */ +static __device__ __forceinline__ int32_t reduce32(int32_t a) { + int32_t t = (a + (1 << (PARAM_QBITS - 1))) >> PARAM_QBITS; + return a - t * PARAM_Q; +} + +/* + * caddq(a): 如果 a < 0, 加上 Q, 使其进入 [0, Q) + */ +static __device__ __forceinline__ int32_t caddq(int32_t a) { + a += (a >> 31) & PARAM_Q; + return a; +} + +/* + * freeze(a): 完全约化至 [0, Q) (reduce32 + caddq) + */ +static __device__ __forceinline__ int32_t freeze(int32_t a) { + return caddq(reduce32(a)); +} + +/* ---- Aigis unsigned reduction [0, Q) ---- */ +/* GPU 使用有符号 Montgomery 运算, 中间值可能为负数。 + * 因此 freeze2q/freeze4q 必须先处理负值输入。 + * freeze2q: 输入 a ∈ (-2Q, 2Q), 输出 [0, Q) + * freeze4q: 输入 a ∈ (-4Q, 4Q), 输出 [0, Q) + */ + +static __device__ __forceinline__ int32_t freeze2q(int32_t a) { + a += (a >> 31) & (2 * PARAM_Q); /* 负值加 2Q → [0, 4Q) */ + a -= PARAM_Q; + a += (a >> 31) & PARAM_Q; + return a; +} + +static __device__ __forceinline__ int32_t freeze4q(int32_t a) { + a += (a >> 31) & (4 * PARAM_Q); /* 负值加 4Q → [0, 8Q) */ + a -= 2 * PARAM_Q; + a += (a >> 31) & (2 * PARAM_Q); + a -= PARAM_Q; + a += (a >> 31) & PARAM_Q; + return a; +} + +#if ALGORITHM == ALGO_AIGIS +/* + * barrat_reduce(a): GPU 有符号版 — 使用 reduce32 保证正确性 + * 输入: 任意 int32_t, 输出: [0, 2Q) 大致 + */ +static __device__ __forceinline__ int32_t barrat_reduce(int32_t a) { + /* reduce32 在 GPU 有符号算术下安全, 输出 (-Q/2, Q/2] */ + return caddq(reduce32(a)); +} +#endif /* ALGO_AIGIS */ + +/* ================================================================ + * 统一系数运算包装 — batch kernel 使用 + * 通过 coeff_t / coeff2_t 实现类型无关的批量运算 + * ================================================================ */ + +/* Montgomery multiply: c = a * b * R^{-1} mod Q 
*/ +static __device__ __forceinline__ coeff_t coeff_fqmul(coeff_t a, coeff_t b) { + return montgomery_reduce((coeff2_t)a * b); +} + +/* 模减法: 保持在 lazy-reduced 范围 */ +static __device__ __forceinline__ coeff_t coeff_sub(coeff_t a, coeff_t b) { +#if ALGORITHM == ALGO_AIGIS + /* Aigis 使用 int32_t 但系数 [0,Q), 减法后可能为负 → 加 2Q 保正 */ + return a + 2 * PARAM_Q - b; +#else + return a - b; /* ML-DSA: signed 直接减 */ +#endif +} + +/* 轻量约化至 ~(-Q, Q) 或 ~[0, 2Q) */ +static __device__ __forceinline__ coeff_t coeff_reduce(coeff_t a) { +#if ALGORITHM == ALGO_AIGIS + return barrat_reduce(a); +#else + return reduce32(a); +#endif +} + +/* 归一化至 [0, Q) */ +static __device__ __forceinline__ coeff_t coeff_normalize(coeff_t a) { +#if ALGORITHM == ALGO_AIGIS + return freeze2q(a); +#else + return caddq(a); +#endif +} + +/* 宽范围归一化至 [0, Q) */ +static __device__ __forceinline__ coeff_t coeff_freeze_wide(coeff_t a) { +#if ALGORITHM == ALGO_AIGIS + return freeze4q(a); +#else + return caddq(reduce32(a)); +#endif +} + +#endif /* REDUCE_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/rounding.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/rounding.cuh new file mode 100644 index 000000000..b5e2c0b50 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/rounding.cuh @@ -0,0 +1,134 @@ +/* + * rounding.cuh + * + * ML-DSA: 中心化系数 (-Q/2, Q/2] + * Aigis: 无符号系数 [0, Q), a0 存为 Q+t 偏置形式 + */ + +#ifndef ROUNDING_CUH +#define ROUNDING_CUH + +#include +#include "params.h" +#include "reduce.cuh" + +/* ================================================================ + * power2round + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __forceinline__ int32_t power2round(int32_t *a0, int32_t a) { + int32_t a1 = (a + (1 << (PARAM_D - 1)) - 1) >> PARAM_D; + *a0 = a - (a1 << PARAM_D); + return a1; +} + +#elif ALGORITHM == ALGO_AIGIS + +/* 
Aigis: unsigned input a ∈ [0,Q), output a0 = Q + t (biased), a1 = (a-t)>>D */ +static __device__ __forceinline__ int32_t power2round(int32_t *a0, int32_t a) { + int32_t t; + t = a & ((1 << PARAM_D) - 1); + t -= (1 << (PARAM_D - 1)) + 1; + t += (t >> 31) & (1 << PARAM_D); + t -= (1 << (PARAM_D - 1)) - 1; + *a0 = PARAM_Q + t; + a = (a - t) >> PARAM_D; + return a; +} + +#endif /* power2round */ + +/* ================================================================ + * decompose + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __forceinline__ int32_t decompose(int32_t *a0, int32_t a) { + int32_t a1; +#if PARAM_GAMMA2 == (PARAM_Q - 1) / 32 + a1 = (a + 127) >> 7; + a1 = (a1 * 1025 + (1 << 21)) >> 22; + a1 &= 15; + *a0 = a - a1 * 2 * PARAM_GAMMA2; + *a0 -= (((PARAM_Q - 1) / 2 - *a0) >> 31) & PARAM_Q; +#elif PARAM_GAMMA2 == (PARAM_Q - 1) / 88 + a1 = (a + 127) >> 7; + a1 = (a1 * 11275 + (1 << 23)) >> 24; + a1 ^= ((43 - a1) >> 31) & a1; + *a0 = a - a1 * 2 * PARAM_GAMMA2; + *a0 -= (((PARAM_Q - 1) / 2 - *a0) >> 31) & PARAM_Q; +#endif + return a1; +} + +#elif ALGORITHM == ALGO_AIGIS + +/* Aigis: unsigned a ∈ [0,Q), ALPHA=2*GAMMA2, (Q-1)=6*ALPHA + * Output: a1 ∈ [0, N_W1), a0 = Q + t (biased, centered around Q) */ +static __device__ __forceinline__ int32_t decompose(int32_t *a0, int32_t a) { + int32_t t, u; + const int32_t ALPHA = 2 * PARAM_GAMMA2; + +#if PARAM_Q == 2021377 + u = ((int32_t)((uint32_t)a * 3u) >> 20) + 1; +#elif PARAM_Q == 3870721 + u = ((int32_t)((uint32_t)a * 3u) >> 21) + 1; +#endif + t = a - u * ALPHA; + u -= (t >> 31) & 1; + t += (t >> 31) & ALPHA; + t -= ALPHA / 2 + 1; + t += (t >> 31) & ALPHA; + t -= ALPHA / 2 - 1; + u += (t >> 31) & 1; + int32_t a1 = u; + if (a1 == N_W1) { *a0 = PARAM_Q + t - 1; a1 = 0; } + else { *a0 = PARAM_Q + t; } + return a1; +} + +#endif /* decompose */ + +/* ================================================================ + * make_hint / use_hint + * 
================================================================ */
+#if ALGORITHM == ALGO_MLDSA
+
+/* Returns 1 when a0 lies outside [-GAMMA2, GAMMA2], or equals -GAMMA2 with
+ * a non-zero a1; otherwise 0. */
+static __device__ __forceinline__ int32_t make_hint(int32_t a0, int32_t a1) {
+    if (a0 > PARAM_GAMMA2 || a0 < -PARAM_GAMMA2 ||
+        (a0 == -PARAM_GAMMA2 && a1 != 0))
+        return 1;
+    return 0;
+}
+
+/* Returns the high part of a from decompose(); when hint != 0 it is moved
+ * one step up (a0 > 0) or down (otherwise), wrapping modulo N_W1. */
+static __device__ __forceinline__ int32_t use_hint(int32_t a, int32_t hint) {
+    int32_t a0, a1;
+    a1 = decompose(&a0, a);
+    if (hint == 0) return a1;
+    if (a0 > 0) return (a1 + 1 >= N_W1) ? 0 : a1 + 1;
+    else return (a1 - 1 < 0) ? N_W1 - 1 : a1 - 1;
+}
+
+#elif ALGORITHM == ALGO_AIGIS
+
+/* Aigis make_hint: comparison-based — hint=1 iff decompose(a) ≠ decompose(freeze4q(a+b)) */
+static __device__ __forceinline__ int32_t make_hint(int32_t a, int32_t b) {
+    int32_t t;
+    return decompose(&t, a) != decompose(&t, freeze4q(a + b));
+}
+
+/* Aigis use_hint: unsigned a ∈ [0,Q). decompose() stores a0 biased as Q + t,
+ * so a0 > Q means the centred remainder t is positive: step a1 up in that
+ * case and down otherwise, wrapping at (Q-1)/(2*GAMMA2) - 1.
+ * NOTE(review): that wrap bound is presumably N_W1 - 1 — confirm in params.h. */
+static __device__ __forceinline__ int32_t use_hint(int32_t a, int32_t hint) {
+    int32_t a0, a1;
+    a1 = decompose(&a0, a);
+    if (hint == 0) return a1;
+    if (a0 > PARAM_Q)
+        return (a1 == (PARAM_Q - 1) / (2 * PARAM_GAMMA2) - 1) ? 0 : a1 + 1;
+    else
+        return (a1 == 0) ? (PARAM_Q - 1) / (2 * PARAM_GAMMA2) - 1 : a1 - 1;
+}
+
+#endif /* make_hint / use_hint */
+
+#endif /* ROUNDING_CUH */
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/run_sig_policy_smoke.sh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/run_sig_policy_smoke.sh
new file mode 100644
index 000000000..2da2df206
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/run_sig_policy_smoke.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+# Runs every built *_amd signature binary once and collects the sign-policy
+# lines of its output into amd_results/policy_smoke/policy_smoke_b<batch>.txt.
+# Usage: run_sig_policy_smoke.sh [batch]      (batch defaults to 128)
+# Exits 1 if a binary exits non-zero or its log contains "FAIL".
+set -euo pipefail
+
+export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"
+
+out_dir="amd_results/policy_smoke"
+mkdir -p "${out_dir}"
+
+targets=(
+  mldsa44_amd
+  mldsa65_amd
+  mldsa87_amd
+  aigis1_amd
+  aigis2_amd
+  aigis3_amd
+)
+
+batch="${1:-128}"
+summary="${out_dir}/policy_smoke_b${batch}.txt"
+: > "${summary}"
+
+echo "[policy-smoke] batch=${batch}" | tee -a "${summary}"
+
+for exe in "${targets[@]}"; do
+  if [[ ! -x "./${exe}" ]]; then
+    echo "[skip] ./${exe} not found or not executable" | tee -a "${summary}"
+    continue
+  fi
+
+  log="${out_dir}/${exe}_b${batch}.log"
+  echo "[run] ${exe} batch=${batch}" | tee -a "${summary}"
+  # With pipefail a non-zero exit of the binary fails this pipeline. Test it
+  # explicitly so the failure is written to the summary instead of `set -e`
+  # terminating the script without a message.
+  if ! stdbuf -oL -eL "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle \
+    2>&1 | tee "${log}"; then
+    echo "[policy-smoke] FAIL: ${exe} exited non-zero; see ${log}" | tee -a "${summary}" >&2
+    exit 1
+  fi
+
+  # grep exits 1 when no line matches; under `set -e -o pipefail` that would
+  # abort the whole run, so an empty match is tolerated here.
+  { grep -E "ROCm sign policy|monolithic-precomp|decomp-cp-fuse|decomp-tail|yhat-copy-fuse|decomp-adaptive|rationale|\\[Sign\\] correctness| Sign[[:space:]]+" "${log}" || true; } \
+    | sed "s/^/[${exe}] /" | tee -a "${summary}"
+
+  if grep -q "FAIL" "${log}"; then
+    echo "[policy-smoke] FAIL detected in ${log}" | tee -a "${summary}" >&2
+    exit 1
+  fi
+done
+
+echo "[policy-smoke] PASS; summary=${summary}"
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/sign.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/sign.cuh
new file mode 100644
index 000000000..a3f9a2fd4
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/sign.cuh
@@ -0,0 +1,790 @@
+#ifndef SIGN_CUH
+#define SIGN_CUH
+
+#include <stdint.h>
+#include <string.h>
+#include "params.h"
+#include "packing.cuh"
+#include "polyvec.cuh"
+#include "poly.cuh"
+#include "symmetric.cuh"
+#include "fips202.cuh"
+
+/* ================================================================
+ * KEY GENERATION (unified skeleton, 4 inner #if blocks)
+ *
+ * Derives rho / eta_seed / key from the SEEDBYTES-byte seed, expands the
+ * matrix, samples s1 and s2, computes t = A*s1 + s2, splits it with
+ * power2round and packs pk and sk. Always returns 0.
+ * ================================================================ */
+static __device__ __noinline__ int crypto_sign_keypair(
+    uint8_t *pk, uint8_t *sk, const uint8_t *seed)
+{
+    polyvecl mat[PARAM_K];
+    polyvecl s1, s1hat;
+    polyveck s2, t1, t0;
+
+    /* ---- seed derivation (algo-specific) ---- */
+#if ALGORITHM == ALGO_MLDSA
+    uint8_t seedbuf[2 * SEEDBYTES + CRHBYTES];
+    uint8_t tr[TRBYTES];
+    memcpy(seedbuf, seed, SEEDBYTES);
+    seedbuf[SEEDBYTES] = (uint8_t)PARAM_K;
+    seedbuf[SEEDBYTES + 1] = (uint8_t)PARAM_L;
+    shake256(seedbuf, 2 * SEEDBYTES + CRHBYTES, seedbuf, SEEDBYTES + 2);
+    const
uint8_t *rho = seedbuf; + const uint8_t *eta_seed = seedbuf + SEEDBYTES; /* rhoprime */ + const uint8_t *key = seedbuf + SEEDBYTES + CRHBYTES; +#elif ALGORITHM == ALGO_AIGIS + uint8_t buf[3 * SEEDBYTES + CRHBYTES]; + shake256(buf, 3 * SEEDBYTES, seed, SEEDBYTES); + const uint8_t *eta_seed = buf; /* sampling_seed */ + const uint8_t *rho = buf + SEEDBYTES; + const uint8_t *key = buf + 2 * SEEDBYTES; +#endif + + /* ---- shared: expand A, sample s1/s2 ---- */ + polyvec_matrix_expand(mat, rho); + polyvecl_uniform_eta_s1(&s1, eta_seed, 0); + polyveck_uniform_eta_s2(&s2, eta_seed, PARAM_L); + + /* ---- shared: t = A*NTT(s1) + s2, then power2round ---- */ + s1hat = s1; + polyvecl_ntt(&s1hat); + polyveck_accumulate_matvecntt(&t1, mat, &s1hat); +#if ALGORITHM == ALGO_MLDSA + polyveck_reduce(&t1); +#endif + polyveck_invntt_tomont(&t1); + polyveck_add(&t1, &t1, &s2); +#if ALGORITHM == ALGO_MLDSA + polyveck_caddq(&t1); +#elif ALGORITHM == ALGO_AIGIS + polyveck_freeze4q(&t1); +#endif + polyveck_power2round(&t1, &t0, &t1); + pack_pk(pk, rho, &t1); + + /* ---- hash pk and pack sk (algo-specific) ---- */ +#if ALGORITHM == ALGO_MLDSA + shake256(tr, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + pack_sk(sk, rho, key, tr, &s1, &s2, &t0); +#elif ALGORITHM == ALGO_AIGIS + shake256(buf + 3 * SEEDBYTES, CRHBYTES, pk, CRYPTO_PUBLICKEYBYTES); + pack_sk(sk, rho, key, buf + 3 * SEEDBYTES, &s1, &s2, &t0); +#endif + + return 0; +} + +/* ================================================================ + * SIGNATURE + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __noinline__ int crypto_sign_signature( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd_in, + const uint8_t *sk) +{ + unsigned int n; + uint8_t seedbuf[2 * SEEDBYTES + TRBYTES + 2 * CRHBYTES]; + uint8_t *rho, *tr, *key, *mu, *rhoprime; + uint16_t nonce = 0; + polyvecl mat[PARAM_K], s1, y, z; + 
polyveck t0, s2, w1, w0, h; + poly cp; + keccak_state state; + + rho = seedbuf; + tr = rho + SEEDBYTES; + key = tr + TRBYTES; + mu = key + SEEDBYTES; + rhoprime = mu + CRHBYTES; + unpack_sk(rho, key, tr, &s1, &s2, &t0, sk); + + shake256_init(&state); + shake256_absorb(&state, tr, TRBYTES); + shake256_absorb(&state, pre, prelen); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + shake256_init(&state); + shake256_absorb(&state, key, SEEDBYTES); +#if RNDBYTES > 0 + shake256_absorb(&state, rnd_in, RNDBYTES); +#endif + shake256_absorb(&state, mu, CRHBYTES); + shake256_finalize(&state); + shake256_squeeze(rhoprime, CRHBYTES, &state); + + polyvec_matrix_expand(mat, rho); + polyvecl_ntt(&s1); + polyveck_ntt(&s2); + polyveck_ntt(&t0); + +rej: + polyvecl_uniform_gamma1(&y, rhoprime, nonce++); + + z = y; + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w1, mat, &z); + polyveck_reduce(&w1); + polyveck_invntt_tomont(&w1); + + polyveck_reduce(&w1); + polyveck_caddq(&w1); + polyveck_decompose(&w1, &w0, &w1); + polyveck_pack_w1(sig, &w1); + + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, sig, PARAM_K * POLYW1_PACKEDBYTES); + shake256_finalize(&state); + shake256_squeeze(sig, CTILDEBYTES, &state); + poly_challenge(&cp, sig); + poly_ntt(&cp); + + polyvecl_pointwise_poly_montgomery(&z, &cp, &s1); + polyvecl_invntt_tomont(&z); + polyvecl_add(&z, &z, &y); + polyvecl_reduce(&z); + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + goto rej; + + polyveck_pointwise_poly_montgomery(&h, &cp, &s2); + polyveck_invntt_tomont(&h); + polyveck_sub(&w0, &w0, &h); + polyveck_reduce(&w0); + if (polyveck_chknorm(&w0, PARAM_GAMMA2 - PARAM_BETA2)) + goto rej; + + polyveck_pointwise_poly_montgomery(&h, &cp, &t0); + polyveck_invntt_tomont(&h); + polyveck_reduce(&h); + if (polyveck_chknorm(&h, PARAM_GAMMA2)) + goto rej; + + polyveck_add(&w0, &w0, &h); + n = polyveck_make_hint(&h, &w0, &w1); + 
if (n > PARAM_OMEGA) + goto rej; + + pack_sig(sig, sig, &z, &h); + *siglen = CRYPTO_BYTES; + return 0; +} + +#elif ALGORITHM == ALGO_AIGIS + +static __device__ __noinline__ int crypto_sign_signature( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, + const uint8_t *rnd_in, + const uint8_t *sk) +{ + unsigned int n; + uint8_t rho[SEEDBYTES], key[SEEDBYTES], hash_pk[TRBYTES]; + uint8_t mu[CRHBYTES]; + uint8_t key_mu[SEEDBYTES + CRHBYTES]; /* gamma1 seed = key || mu */ + uint8_t w1_buf[PARAM_K * POLYW1_PACKEDBYTES]; + uint16_t nonce = 0; + polyvecl mat[PARAM_K], s1, y, z; + polyveck t0, s2, w, w1, wcs2, wcs20, ct0, h, tmp; + poly c, chat; + keccak_state state; + + unpack_sk(rho, key, hash_pk, &s1, &s2, &t0, sk); + + /* mu = shake256(hash_pk || m) */ + shake256_init(&state); + shake256_absorb(&state, hash_pk, TRBYTES); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + /* gamma1 seed = key || mu */ + memcpy(key_mu, key, SEEDBYTES); + memcpy(key_mu + SEEDBYTES, mu, CRHBYTES); + + polyvec_matrix_expand(mat, rho); + polyvecl_ntt(&s1); + polyveck_ntt(&s2); + polyveck_ntt(&t0); + +rej: + + polyvecl_uniform_gamma1(&y, key_mu, nonce); + nonce += PARAM_L; + + z = y; + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w, mat, &z); /* barrat_reduce included */ + polyveck_invntt_tomont(&w); + + polyveck_freeze2q(&w); + polyveck_decompose(&w1, &tmp, &w); + + /* Aigis: challenge from mu || packed_w1 */ + polyveck_pack_w1(w1_buf, &w1); + poly_challenge(&c, mu, w1_buf, PARAM_K * POLYW1_PACKEDBYTES); + + chat = c; + poly_ntt(&chat); + + /* z = chat*s1 + y */ + polyvecl_pointwise_poly_montgomery(&z, &chat, &s1); + polyvecl_invntt_tomont(&z); + polyvecl_add(&z, &z, &y); + polyvecl_freeze4q(&z); + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + goto rej; + + /* wcs2 = w - chat*s2; decompose; check high bits == w1 */ + polyveck_pointwise_poly_montgomery(&wcs2, &chat, &s2); + polyveck_invntt_tomont(&wcs2); 
+ polyveck_sub(&wcs2, &w, &wcs2); + polyveck_freeze4q(&wcs2); + polyveck_decompose(&tmp, &wcs20, &wcs2); + polyveck_freeze2q(&wcs20); + if (polyveck_chknorm(&wcs20, PARAM_GAMMA2 - PARAM_BETA2)) + goto rej; + + { + int _w1_mismatch = 0; + for (unsigned int i = 0; i < PARAM_K && !_w1_mismatch; ++i) + for (unsigned int j = 0; j < PARAM_N && !_w1_mismatch; ++j) + if (tmp.vec[i].coeffs[j] != w1.vec[i].coeffs[j]) + _w1_mismatch = 1; + if (_w1_mismatch) + goto rej; + } + + /* ct0 = chat*t0 */ + polyveck_pointwise_poly_montgomery(&ct0, &chat, &t0); + polyveck_invntt_tomont(&ct0); + polyveck_freeze2q(&ct0); + if (polyveck_chknorm(&ct0, PARAM_GAMMA2)) + goto rej; + + /* make_hint: h = hint(wcs2+ct0, neg(ct0)) */ + polyveck_add(&tmp, &wcs2, &ct0); + polyveck_neg(&ct0); + polyveck_freeze2q(&tmp); + n = polyveck_make_hint(&h, &tmp, &ct0); + if (n > PARAM_OMEGA) + goto rej; + + pack_sig(sig, &z, &h, &c); + *siglen = CRYPTO_BYTES; + return 0; +} + +#endif /* ALGORITHM sign */ + +/* ================================================================ + * VERIFY + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __noinline__ int crypto_sign_verify( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *pk) +{ + unsigned int i; + uint8_t buf[PARAM_K * POLYW1_PACKEDBYTES]; + uint8_t rho[SEEDBYTES]; + uint8_t mu[CRHBYTES]; + uint8_t c[CTILDEBYTES]; + uint8_t c2[CTILDEBYTES]; + poly cp; + polyvecl mat[PARAM_K], z; + polyveck t1, w1, h; + keccak_state state; + + if (siglen != CRYPTO_BYTES) + return -1; + + unpack_pk(rho, &t1, pk); + if (unpack_sig(c, &z, &h, sig)) + return -1; + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + return -1; + + shake256(mu, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + shake256_init(&state); + shake256_absorb(&state, mu, TRBYTES); + shake256_absorb(&state, pre, prelen); + shake256_absorb(&state, m, mlen); + 
shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + poly_challenge(&cp, c); + polyvec_matrix_expand(mat, rho); + + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w1, mat, &z); + + poly_ntt(&cp); + polyveck_shiftl(&t1); + polyveck_ntt(&t1); + polyveck_pointwise_poly_montgomery(&t1, &cp, &t1); + + polyveck_sub(&w1, &w1, &t1); + polyveck_reduce(&w1); + polyveck_invntt_tomont(&w1); + + polyveck_reduce(&w1); + polyveck_caddq(&w1); + polyveck_use_hint(&w1, &w1, &h); + polyveck_pack_w1(buf, &w1); + + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, buf, PARAM_K * POLYW1_PACKEDBYTES); + shake256_finalize(&state); + shake256_squeeze(c2, CTILDEBYTES, &state); + for (i = 0; i < CTILDEBYTES; ++i) + if (c[i] != c2[i]) + return -1; + + return 0; +} + +#elif ALGORITHM == ALGO_AIGIS + +static __device__ __noinline__ int crypto_sign_verify( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pk) +{ + uint8_t rho[SEEDBYTES]; + uint8_t mu[CRHBYTES]; + uint8_t w1_buf[PARAM_K * POLYW1_PACKEDBYTES]; + poly c, cp, chat; + polyvecl mat[PARAM_K], z; + polyveck t1, w1, h, tmp1, tmp2; + keccak_state state; + + if (siglen != CRYPTO_BYTES) + return -1; + + unpack_pk(rho, &t1, pk); + if (unpack_sig(&z, &h, &c, sig)) + return -1; + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + return -1; + + /* mu = shake256(shake256(pk) || m) */ + shake256(mu, CRHBYTES, pk, CRYPTO_PUBLICKEYBYTES); + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + polyvec_matrix_expand(mat, rho); + + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&tmp1, mat, &z); /* barrat_reduce included */ + + chat = c; + poly_ntt(&chat); + polyveck_shiftl(&t1); + polyveck_ntt(&t1); + polyveck_pointwise_poly_montgomery(&tmp2, &chat, &t1); + + polyveck_sub(&tmp1, &tmp1, &tmp2); + polyveck_reduce(&tmp1); /* 
Remove 2*Q bias from poly_sub before INVNTT */ + polyveck_invntt_tomont(&tmp1); + + polyveck_freeze2q(&tmp1); + polyveck_use_hint(&w1, &tmp1, &h); + + /* Recompute challenge and compare coefficients */ + polyveck_pack_w1(w1_buf, &w1); + poly_challenge(&cp, mu, w1_buf, PARAM_K * POLYW1_PACKEDBYTES); + for (unsigned int i = 0; i < PARAM_N; ++i) + if (c.coeffs[i] != cp.coeffs[i]) + return -1; + + return 0; +} + +#endif /* ALGORITHM verify */ + +/* ================================================================ + * PRECOMPUTATION — 预计算结构和函数 + * + * 用于同一密钥的批量签名/验证场景。 + * 预计算内容: + * mat[K][L] — 扩展后的矩阵 A (NTT 域) + * s1_ntt — NTT(s1) (仅签名) + * s2_ntt — NTT(s2) (仅签名) + * t0_ntt — NTT(t0) (仅签名) + * key, tr — 种子材料 + * ================================================================ */ +typedef struct { + polyvecl mat[PARAM_K]; /* 扩展矩阵 A (NTT 域) */ + polyvecl s1_ntt; /* NTT(s1) — 签名用 */ + polyveck s2_ntt; /* NTT(s2) — 签名用 */ + polyveck t0_ntt; /* NTT(t0) — 签名用 */ + uint8_t key[SEEDBYTES]; /* 签名用: rhoprime 推导 */ + uint8_t tr[TRBYTES]; /* 签名/验证用: mu 计算 */ +} precomp_t; + +/* 创建预计算数据: 从 pk/sk 提取并预计算 */ +static __device__ __noinline__ void create_precomp( + precomp_t *pc, + const uint8_t *pk, + const uint8_t *sk) +{ + uint8_t rho[SEEDBYTES]; + /* unpack_sk 直接写入 pc 的存储, 避免额外拷贝 */ + unpack_sk(rho, pc->key, pc->tr, &pc->s1_ntt, &pc->s2_ntt, &pc->t0_ntt, sk); + polyvec_matrix_expand(pc->mat, rho); + polyvecl_ntt(&pc->s1_ntt); + polyveck_ntt(&pc->s2_ntt); + polyveck_ntt(&pc->t0_ntt); +} + +/* ================================================================ + * 预计算签名 — 跳过矩阵扩展和密钥 NTT + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __noinline__ int crypto_sign_signature_precomp_cached( + uint8_t *sig, size_t *siglen, + const uint8_t mu[CRHBYTES], + const uint8_t rhoprime[CRHBYTES], + const precomp_t *pc, + uint16_t nonce_start) +{ + unsigned int n; + uint16_t nonce = nonce_start; + polyvecl y, z; + polyveck w1, w0, 
h; + poly cp; + keccak_state state; + +rej_p_cached: + polyvecl_uniform_gamma1(&y, rhoprime, nonce++); + + z = y; + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w1, pc->mat, &z); + polyveck_reduce(&w1); + polyveck_invntt_tomont(&w1); + polyveck_reduce(&w1); + polyveck_caddq(&w1); + polyveck_decompose(&w1, &w0, &w1); + polyveck_pack_w1(sig, &w1); + + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, sig, PARAM_K * POLYW1_PACKEDBYTES); + shake256_finalize(&state); + shake256_squeeze(sig, CTILDEBYTES, &state); + poly_challenge(&cp, sig); + poly_ntt(&cp); + + polyvecl_pointwise_poly_montgomery(&z, &cp, &pc->s1_ntt); + polyvecl_invntt_tomont(&z); + polyvecl_add(&z, &z, &y); + polyvecl_reduce(&z); + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + goto rej_p_cached; + + polyveck_pointwise_poly_montgomery(&h, &cp, &pc->s2_ntt); + polyveck_invntt_tomont(&h); + polyveck_sub(&w0, &w0, &h); + polyveck_reduce(&w0); + if (polyveck_chknorm(&w0, PARAM_GAMMA2 - PARAM_BETA2)) + goto rej_p_cached; + + polyveck_pointwise_poly_montgomery(&h, &cp, &pc->t0_ntt); + polyveck_invntt_tomont(&h); + polyveck_reduce(&h); + if (polyveck_chknorm(&h, PARAM_GAMMA2)) + goto rej_p_cached; + + polyveck_add(&w0, &w0, &h); + n = polyveck_make_hint(&h, &w0, &w1); + if (n > PARAM_OMEGA) + goto rej_p_cached; + + pack_sig(sig, sig, &z, &h); + *siglen = CRYPTO_BYTES; + return 0; +} + +static __device__ __noinline__ int crypto_sign_signature_precomp( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *rnd_in, + const precomp_t *pc, + uint16_t nonce_start) +{ + uint8_t mu[CRHBYTES], rhoprime[CRHBYTES]; + keccak_state state; + + /* mu = H(tr || pre || m) */ + shake256_init(&state); + shake256_absorb(&state, pc->tr, TRBYTES); + shake256_absorb(&state, pre, prelen); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + /* rhoprime = H(key 
|| rnd || mu) */ + shake256_init(&state); + shake256_absorb(&state, pc->key, SEEDBYTES); +#if RNDBYTES > 0 + shake256_absorb(&state, rnd_in, RNDBYTES); +#endif + shake256_absorb(&state, mu, CRHBYTES); + shake256_finalize(&state); + shake256_squeeze(rhoprime, CRHBYTES, &state); + + return crypto_sign_signature_precomp_cached(sig, siglen, mu, rhoprime, pc, nonce_start); +} + +#elif ALGORITHM == ALGO_AIGIS + +static __device__ __noinline__ int crypto_sign_signature_precomp_cached( + uint8_t *sig, size_t *siglen, + const uint8_t mu[CRHBYTES], + const uint8_t key_mu[SEEDBYTES + CRHBYTES], + const precomp_t *pc, + uint16_t nonce_start) +{ + unsigned int n; + uint8_t w1_buf[PARAM_K * POLYW1_PACKEDBYTES]; + uint16_t nonce = nonce_start; + polyvecl y, z; + polyveck w, w1, wcs2, wcs20, ct0, h, tmp; + poly c, chat; + +rej_p_cached: + polyvecl_uniform_gamma1(&y, key_mu, nonce); + nonce += PARAM_L; + + z = y; + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w, pc->mat, &z); + polyveck_invntt_tomont(&w); + polyveck_freeze2q(&w); + polyveck_decompose(&w1, &tmp, &w); + + polyveck_pack_w1(w1_buf, &w1); + poly_challenge(&c, mu, w1_buf, PARAM_K * POLYW1_PACKEDBYTES); + + chat = c; + poly_ntt(&chat); + + polyvecl_pointwise_poly_montgomery(&z, &chat, &pc->s1_ntt); + polyvecl_invntt_tomont(&z); + polyvecl_add(&z, &z, &y); + polyvecl_freeze4q(&z); + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + goto rej_p_cached; + + polyveck_pointwise_poly_montgomery(&wcs2, &chat, &pc->s2_ntt); + polyveck_invntt_tomont(&wcs2); + polyveck_sub(&wcs2, &w, &wcs2); + polyveck_freeze4q(&wcs2); + polyveck_decompose(&tmp, &wcs20, &wcs2); + polyveck_freeze2q(&wcs20); + if (polyveck_chknorm(&wcs20, PARAM_GAMMA2 - PARAM_BETA2)) + goto rej_p_cached; + + { + int _w1_mismatch = 0; + for (unsigned int i = 0; i < PARAM_K && !_w1_mismatch; ++i) + for (unsigned int j = 0; j < PARAM_N && !_w1_mismatch; ++j) + if (tmp.vec[i].coeffs[j] != w1.vec[i].coeffs[j]) + _w1_mismatch = 1; + if (_w1_mismatch) + goto 
rej_p_cached; + } + + polyveck_pointwise_poly_montgomery(&ct0, &chat, &pc->t0_ntt); + polyveck_invntt_tomont(&ct0); + polyveck_freeze2q(&ct0); + if (polyveck_chknorm(&ct0, PARAM_GAMMA2)) + goto rej_p_cached; + + polyveck_add(&tmp, &wcs2, &ct0); + polyveck_neg(&ct0); + polyveck_freeze2q(&tmp); + n = polyveck_make_hint(&h, &tmp, &ct0); + if (n > PARAM_OMEGA) + goto rej_p_cached; + + pack_sig(sig, &z, &h, &c); + *siglen = CRYPTO_BYTES; + return 0; +} + +static __device__ __noinline__ int crypto_sign_signature_precomp( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, + const uint8_t *rnd_in, + const precomp_t *pc, + uint16_t nonce_start) +{ + uint8_t mu[CRHBYTES]; + uint8_t key_mu[SEEDBYTES + CRHBYTES]; + keccak_state state; + + /* mu = H(hash_pk || m) */ + shake256_init(&state); + shake256_absorb(&state, pc->tr, TRBYTES); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + /* gamma1 seed = key || mu */ + memcpy(key_mu, pc->key, SEEDBYTES); + memcpy(key_mu + SEEDBYTES, mu, CRHBYTES); + + return crypto_sign_signature_precomp_cached(sig, siglen, mu, key_mu, pc, nonce_start); +} + +#endif /* ALGORITHM sign_precomp */ + +/* ================================================================ + * 预计算验证 — 跳过矩阵扩展 + * ================================================================ */ +#if ALGORITHM == ALGO_MLDSA + +static __device__ __noinline__ int crypto_sign_verify_precomp( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pre, size_t prelen, + const uint8_t *pk, + const polyvecl *precomp_mat) +{ + unsigned int i; + uint8_t buf[PARAM_K * POLYW1_PACKEDBYTES]; + uint8_t mu[CRHBYTES]; + uint8_t c[CTILDEBYTES]; + uint8_t c2[CTILDEBYTES]; + poly cp; + polyvecl z; + polyveck t1, w1, h; + keccak_state state; + + if (siglen != CRYPTO_BYTES) + return -1; + + unpack_pk(mu, &t1, pk); /* mu 暂存 rho, 后面覆盖 */ + if (unpack_sig(c, &z, &h, sig)) + return -1; + if (polyvecl_chknorm(&z, 
PARAM_GAMMA1 - PARAM_BETA1)) + return -1; + + shake256(mu, TRBYTES, pk, CRYPTO_PUBLICKEYBYTES); + shake256_init(&state); + shake256_absorb(&state, mu, TRBYTES); + shake256_absorb(&state, pre, prelen); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + poly_challenge(&cp, c); + /* 跳过 polyvec_matrix_expand — 使用预计算矩阵 */ + + polyvecl_ntt(&z); + polyveck_accumulate_matvecntt(&w1, precomp_mat, &z); + + poly_ntt(&cp); + polyveck_shiftl(&t1); + polyveck_ntt(&t1); + polyveck_pointwise_poly_montgomery(&t1, &cp, &t1); + + polyveck_sub(&w1, &w1, &t1); + polyveck_reduce(&w1); + polyveck_invntt_tomont(&w1); + + polyveck_reduce(&w1); + polyveck_caddq(&w1); + polyveck_use_hint(&w1, &w1, &h); + polyveck_pack_w1(buf, &w1); + + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, buf, PARAM_K * POLYW1_PACKEDBYTES); + shake256_finalize(&state); + shake256_squeeze(c2, CTILDEBYTES, &state); + for (i = 0; i < CTILDEBYTES; ++i) + if (c[i] != c2[i]) + return -1; + + return 0; +} + +#elif ALGORITHM == ALGO_AIGIS + +static __device__ __noinline__ int crypto_sign_verify_precomp( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, + const uint8_t *pk, + const polyvecl *precomp_mat) +{ + uint8_t rho[SEEDBYTES]; + uint8_t mu[CRHBYTES]; + uint8_t w1_buf[PARAM_K * POLYW1_PACKEDBYTES]; + poly c, cp, chat; + polyvecl z; + polyveck t1, w1, h, tmp1, tmp2; + keccak_state state; + + if (siglen != CRYPTO_BYTES) + return -1; + + unpack_pk(rho, &t1, pk); + if (unpack_sig(&z, &h, &c, sig)) + return -1; + if (polyvecl_chknorm(&z, PARAM_GAMMA1 - PARAM_BETA1)) + return -1; + + shake256(mu, CRHBYTES, pk, CRYPTO_PUBLICKEYBYTES); + shake256_init(&state); + shake256_absorb(&state, mu, CRHBYTES); + shake256_absorb(&state, m, mlen); + shake256_finalize(&state); + shake256_squeeze(mu, CRHBYTES, &state); + + /* 跳过 polyvec_matrix_expand — 使用预计算矩阵 */ + + polyvecl_ntt(&z); + 
polyveck_accumulate_matvecntt(&tmp1, precomp_mat, &z); + + chat = c; + poly_ntt(&chat); + polyveck_shiftl(&t1); + polyveck_ntt(&t1); + polyveck_pointwise_poly_montgomery(&tmp2, &chat, &t1); + + polyveck_sub(&tmp1, &tmp1, &tmp2); + polyveck_reduce(&tmp1); + polyveck_invntt_tomont(&tmp1); + + polyveck_freeze2q(&tmp1); + polyveck_use_hint(&w1, &tmp1, &h); + + polyveck_pack_w1(w1_buf, &w1); + poly_challenge(&cp, mu, w1_buf, PARAM_K * POLYW1_PACKEDBYTES); + for (unsigned int i = 0; i < PARAM_N; ++i) + if (c.coeffs[i] != cp.coeffs[i]) + return -1; + + return 0; +} + +#endif /* ALGORITHM verify_precomp */ + +#endif /* SIGN_CUH */ diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/symmetric.cuh b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/symmetric.cuh new file mode 100644 index 000000000..af56b555d --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/sig_api/symmetric.cuh @@ -0,0 +1,105 @@ +#ifndef SYMMETRIC_CUH +#define SYMMETRIC_CUH + +#include +#include "params.h" +#include "fips202.cuh" + +typedef keccak_state stream128_state; +typedef keccak_state stream256_state; + +#define STREAM128_BLOCKBYTES SHAKE128_RATE +#define STREAM256_BLOCKBYTES SHAKE256_RATE + +#if ALGORITHM == ALGO_MLDSA +/* ---- ML-DSA: SEEDBYTES seed + 2-byte nonce (stream128) + * CRHBYTES seed + 2-byte nonce (stream256) ---- */ + +static __device__ void dilithium_shake128_stream_init(keccak_state *state, const uint8_t seed[SEEDBYTES], uint16_t nonce) { + uint8_t t[2]; + t[0] = nonce; + t[1] = nonce >> 8; + shake128_init(state); + shake128_absorb(state, seed, SEEDBYTES); + shake128_absorb(state, t, 2); + shake128_finalize(state); +} + +static __device__ void dilithium_shake256_stream_init(keccak_state *state, const uint8_t seed[CRHBYTES], uint16_t nonce) { + uint8_t t[2]; + t[0] = nonce; + t[1] = nonce >> 8; + shake256_init(state); + shake256_absorb(state, seed, CRHBYTES); + shake256_absorb(state, 
t, 2); + shake256_finalize(state); +} + +#define stream128_init(STATE, SEED, NONCE) dilithium_shake128_stream_init(STATE, SEED, NONCE) +#define stream128_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) +#define stream256_init(STATE, SEED, NONCE) dilithium_shake256_stream_init(STATE, SEED, NONCE) +#define stream256_squeezeblocks(OUT, OUTBLOCKS, STATE) shake256_squeezeblocks(OUT, OUTBLOCKS, STATE) + + +/* HIP clang parses non-instantiated template bodies more strictly than NVCC. + * These Aigis-named shims are only visible while compiling ML-DSA. */ +static __device__ void aigis_shake128_stream_init(keccak_state *state, const uint8_t seed[SEEDBYTES], uint8_t nonce) { + dilithium_shake128_stream_init(state, seed, (uint16_t)nonce); +} +static __device__ void aigis_shake256_eta_init(keccak_state *state, const uint8_t seed[SEEDBYTES], uint8_t nonce) { + (void)nonce; + shake256_init(state); + shake256_absorb(state, seed, SEEDBYTES); + shake256_finalize(state); +} +static __device__ void aigis_shake256_gamma1_init(keccak_state *state, const uint8_t seed[SEEDBYTES + CRHBYTES], uint16_t nonce) { + dilithium_shake256_stream_init(state, seed, nonce); +} +#elif ALGORITHM == ALGO_AIGIS +/* ---- Aigis: matrix A expand = SEEDBYTES + 1-byte nonce via shake128 + * eta sampling = SEEDBYTES + 1-byte nonce via shake256 + * gamma1 sampling = (SEEDBYTES+CRHBYTES) + 2-byte nonce via shake256 ---- */ + +/* Matrix A: shake128(seed || 1-byte nonce) */ +static __device__ void aigis_shake128_stream_init(keccak_state *state, const uint8_t seed[SEEDBYTES], uint8_t nonce) { + shake128_init(state); + shake128_absorb(state, seed, SEEDBYTES); + shake128_absorb(state, &nonce, 1); + shake128_finalize(state); +} + +/* Eta sampling: shake256(seed || 1-byte nonce) */ +static __device__ void aigis_shake256_eta_init(keccak_state *state, const uint8_t seed[SEEDBYTES], uint8_t nonce) { + shake256_init(state); + shake256_absorb(state, seed, SEEDBYTES); + 
shake256_absorb(state, &nonce, 1); + shake256_finalize(state); +} + +/* Gamma1 sampling: shake256(seed(SEEDBYTES+CRHBYTES) || 2-byte nonce) */ +static __device__ void aigis_shake256_gamma1_init(keccak_state *state, + const uint8_t seed[SEEDBYTES + CRHBYTES], + uint16_t nonce) { + uint8_t t[2]; + t[0] = nonce & 0xFF; + t[1] = nonce >> 8; + shake256_init(state); + shake256_absorb(state, seed, SEEDBYTES + CRHBYTES); + shake256_absorb(state, t, 2); + shake256_finalize(state); +} + +/* Aigis stream macros — these are NOT exact aliases of the ML-DSA ones; + * callers that differ (matrix, eta, gamma1) call the specific inits above. */ +#define stream128_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) +#define stream256_squeezeblocks(OUT, OUTBLOCKS, STATE) shake256_squeezeblocks(OUT, OUTBLOCKS, STATE) + +/* HIP clang parses non-instantiated ML-DSA template bodies while compiling Aigis. + * These stream_init aliases are only for parsing; active Aigis paths call the + * Aigis-specific init functions above. 
*/ +#define stream128_init(STATE, SEED, NONCE) aigis_shake128_stream_init(STATE, SEED, (uint8_t)(NONCE)) +#define stream256_init(STATE, SEED, NONCE) aigis_shake256_gamma1_init(STATE, (const uint8_t *)(SEED), (uint16_t)(NONCE)) + +#endif /* ALGORITHM */ + +#endif diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/FRONTEND_MANIFEST.md b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/FRONTEND_MANIFEST.md new file mode 100644 index 000000000..f3bd723e4 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/FRONTEND_MANIFEST.md @@ -0,0 +1,19 @@ +# Frontend Manifest + +- Package: `pqc_trustflow_frontend` +- Version: `2026-06-16-alpha` +- Runtime: JupyterLab + Python + `ipywidgets` +- No extra port required + +## Contents + +- `__init__.py` +- `app.py` +- `state.py` +- `backends.py` +- `pqc_trustflow_widgets_demo.ipynb` +- `README.md` +- `FRONTEND_MANIFEST.md` +- `assets/` +- `outputs/` +- `logs/` diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/README.md b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/README.md new file mode 100644 index 000000000..ca4886ced --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/README.md @@ -0,0 +1,48 @@ +# PQC TrustFlow Frontend + +Notebook-friendly `ipywidgets` frontend for the AMD ROCm PQC workflow. + +## What this version does + +- Scans a folder of demo documents. +- Builds a `manifest.json` with file names, sizes, and SHA-256 digests. +- Encrypts every file into a `.pqcpack.zip` package. +- Uses AES-256-GCM when `cryptography` is available in the Jupyter image; otherwise falls back to a SHA-256 counter keystream with an HMAC-SHA256 tag (scheme `SHA256-stream-HMAC-fallback`), which is not a standard AEAD. +- Restores and verifies the package. +- Creates a tampered copy and confirms detection. +- Calls the selected ROCm KEM and signature executables as proof runs and saves logs. 
+ +The current package authenticator is a demo-layer package MAC. The ROCm KEM/SIG +executables are still called and logged from the UI, and the next backend step is +to replace the package authenticator with direct ROCm file-I/O signing once the +minimal CLI mode is compiled into the HIP binaries. + +## Quick start + +Open `pqc_trustflow_widgets_demo.ipynb` in JupyterLab and run the second cell. + +Or run this in a notebook from `/app/PQC_TrustFlow_ROCm`: + +```python +from pqc_trustflow_frontend import launch_app +launch_app() +``` + +The default folder is `pqc_trustflow_frontend/sample_docs`. You can replace it +with any folder under `/app` that contains documents for the demo. + +## Terminal smoke test + +```bash +cd /app/PQC_TrustFlow_ROCm +python3 - <<'PY' +from pqc_trustflow_frontend.backends import ensure_sample_docs, create_secure_pack, create_tampered_copy_and_verify +src = ensure_sample_docs() +r = create_secure_pack(src, "Kyber-768", "ML-DSA-65", 128, "paper", run_rocm=True) +print("pack:", r.pack_dir) +print("verified:", r.verified) +print("rocm logs:", r.rocm_logs) +t = create_tampered_copy_and_verify(r.pack_dir) +print("tamper detected:", t["tamper_detected"]) +PY +``` diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/__init__.py b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/__init__.py new file mode 100644 index 000000000..55350ea6c --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/__init__.py @@ -0,0 +1,13 @@ +"""PQC TrustFlow frontend package.""" + + +def build_app(): + from .app import build_app as _build_app + + return _build_app() + + +def launch_app() -> None: + from .app import launch_app as _launch_app + + _launch_app() diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/app.py 
b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/app.py new file mode 100644 index 000000000..d86542704 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/app.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from typing import Callable +import json +import time + +import ipywidgets as widgets +from IPython.display import display + +from .backends import ( + create_secure_pack, + create_tampered_copy_and_verify, + ensure_sample_docs, + unpack_secure_pack, + summarize_rocm_logs, +) +from .state import TrustFlowState + + +BASE_DIR = Path(__file__).resolve().parent +ASSETS_DIR = BASE_DIR / "assets" +OUTPUTS_DIR = BASE_DIR / "outputs" +LOGS_DIR = BASE_DIR / "logs" + + +def ensure_dirs() -> None: + for path in (ASSETS_DIR, OUTPUTS_DIR, LOGS_DIR): + path.mkdir(parents=True, exist_ok=True) + + +STATUS_LABELS = { + "done": "完成", + "busy": "运行中", + "fail": "失败", + "idle": "待执行", +} + + +def _fmt_status(status: str) -> str: + color = {"done": "#1f7a4f", "busy": "#b26a00", "fail": "#a61b1b", "idle": "#58606b"}.get(status, "#58606b") + label = STATUS_LABELS.get(status, status) + return f'{label}' + + +def _json_artifact(artifacts: dict[str, str], key: str, default): + value = artifacts.get(key) + if not value: + return default + try: + return json.loads(value) + except Exception: + return default + + +def _fmt_ms(value: float | int | None) -> str: + if value is None: + return "-" + return f"{float(value):.2f} ms" + + +def _fmt_bool(value) -> str: + if value is True: + return "通过" + if value is False: + return "失败" + return "未启用" + + +def _format_artifacts(state: TrustFlowState, folder: str) -> str: + artifacts = state.artifacts + timings = state.timings_ms + rocm_summary = _json_artifact(artifacts, "rocm_summary", {}) + verify_detail = _json_artifact(artifacts, "verify_detail", {}) + tamper_detail = 
_json_artifact(artifacts, "tamper_detail", {}) + + lines: list[str] = [] + lines.append("一、本次配置") + lines.append(f" 输入目录: {folder}") + lines.append(f" KEM 算法: {state.kem_choice}") + lines.append(f" 签名算法: {state.sig_choice}") + lines.append(f" Batch 大小: {state.batch_size}") + lines.append(f" 性能模式: {state.mode}") + lines.append("") + + lines.append("二、核心结果") + normal_verified = verify_detail.get("verified") + tamper_detected = artifacts.get("tamper_detected") or ("YES" if tamper_detail.get("tamper_detected") else "") + lines.append(f" 正常包验证: {_fmt_bool(normal_verified)}") + lines.append(f" KEM 解封装: {_fmt_bool(verify_detail.get('kem_ok'))}") + lines.append(f" Manifest 包认证: {_fmt_bool(verify_detail.get('signature_ok'))}") + lines.append(f" ML-DSA/Aigis-sig 验签: {_fmt_bool(verify_detail.get('sig_api_ok'))}") + lines.append(f" 文件恢复数量: {verify_detail.get('restored_files', '-')}") + lines.append(f" 篡改检测: {tamper_detected or '未执行'}") + if tamper_detail: + lines.append(f" 篡改包验证结果: {_fmt_bool(tamper_detail.get('verified'))}") + errors = tamper_detail.get("file_errors") or [] + if errors: + lines.append(f" 篡改定位: {errors[0]}") + lines.append("") + + lines.append("三、传输包与恢复目录") + lines.append(f" 安全包目录: {artifacts.get('pack_dir', '-')}") + lines.append(f" 安全包 zip: {artifacts.get('pack_zip', '-')}") + lines.append(f" Manifest: {artifacts.get('manifest', '-')}") + lines.append(f" 恢复目录: {artifacts.get('restored_dir') or artifacts.get('unpack_dir', '-')}") + lines.append(f" 文件数量: {artifacts.get('file_count', '-')}") + lines.append(f" 文件总大小: {artifacts.get('total_bytes', '-')} bytes") + lines.append("") + + lines.append("四、密码学证据") + lines.append(f" KEM 密文摘要: {artifacts.get('kem_ciphertext', '-')}") + lines.append(f" Shared secret 摘要: {artifacts.get('kem_shared_key_sha256', '-')}") + lines.append(f" 包认证值: {artifacts.get('package_authenticator', '-')}") + lines.append(" 说明: KEM shared secret 只显示 SHA-256 摘要,不直接暴露原始密钥。") + lines.append("") + + lines.append("五、AMD ROCm 后端调用") + log_names = { 
+ "kem_api_keygen": "KEM keygen 文件接口", + "kem_api_encaps": "KEM encaps 文件接口", + "sig_api_sign": "签名 batch/decomp 文件接口", + "sig_api_verify": "验签 batch 文件接口", + "kem_rocm": "KEM 性能/正确性 proof", + "sig_rocm": "签名性能/正确性 proof", + } + for key, label in log_names.items(): + item = rocm_summary.get(key) + if not item: + continue + tail = item.get("tail") or [] + last_line = tail[-1] if tail else "" + lines.append(f" {label}: {'PASS' if item.get('pass') else 'CHECK'}") + if last_line: + lines.append(f" 摘要: {last_line}") + lines.append(f" 日志: {item.get('log', '-')}") + if not any(k in rocm_summary for k in log_names): + lines.append(" 尚未生成 ROCm 日志。") + lines.append("") + + lines.append("六、耗时") + lines.append(f" 生成安全包总耗时: {_fmt_ms(timings.get('encaps'))}") + lines.append(f" AES 文件加密: {_fmt_ms(timings.get('encrypt_ms'))}") + lines.append(f" ROCm proof: {_fmt_ms(timings.get('rocm_proof_ms'))}") + lines.append(f" 解包/验签/解密: {_fmt_ms(timings.get('verify_decrypt_ms'))}") + lines.append(f" 篡改测试: {_fmt_ms(timings.get('decaps'))}") + lines.append("") + + lines.append("七、备注") + notes = _json_artifact(artifacts, "notes", []) + if notes: + for note in notes: + lines.append(f" - {note}") + else: + lines.append(" 无异常备注。") + return "\n".join(lines) + + +def build_app() -> widgets.VBox: + ensure_dirs() + state = TrustFlowState() + sample_dir = ensure_sample_docs() + last_pack_dir = "" + + title = widgets.HTML( + "

PQC TrustFlow

" + "
基于 AMD ROCm 的后量子多文档加密传输与可信验证演示
" + ) + source_dir_input = widgets.Text( + value=str(sample_dir), + description="文件夹", + layout=widgets.Layout(width="100%"), + ) + sensitive_input = widgets.Textarea(value="敏感数据传输样例", description="备注", layout=widgets.Layout(width="100%", height="70px")) + kem_choice = widgets.Dropdown(options=["Kyber-512", "Kyber-768", "Kyber-1024", "Aigis-enc-1", "Aigis-enc-2", "Aigis-enc-3", "Aigis-enc-4"], value="Kyber-768", description="KEM") + sig_choice = widgets.Dropdown(options=["ML-DSA-44", "ML-DSA-65", "ML-DSA-87", "Aigis-sig1", "Aigis-sig2", "Aigis-sig3"], value="ML-DSA-65", description="签名") + batch_size = widgets.Dropdown(options=[128, 1024, 8192, 16384, 32768], value=1024, description="Batch") + mode_choice = widgets.ToggleButtons(options=["paper", "independent"], value="paper", description="模式") + + status_html = widgets.HTML(value=_fmt_status("idle")) + transcript = widgets.Textarea(value="", description="日志", layout=widgets.Layout(width="100%", height="220px"), disabled=True) + artifact_box = widgets.Textarea(value="", description="证据", layout=widgets.Layout(width="100%", height="360px"), disabled=True) + + buttons = { + "prepare": widgets.Button(description="准备"), + "encaps": widgets.Button(description="生成安全包", button_style="info"), + "encrypt": widgets.Button(description="查看安全包", button_style="info"), + "sign": widgets.Button(description="查看证明", button_style="warning"), + "verify": widgets.Button(description="解包并验证", button_style="success"), + "decaps": widgets.Button(description="篡改测试", button_style="danger"), + "decrypt": widgets.Button(description="查看恢复目录", button_style="info"), + "run_all": widgets.Button(description="一键运行", button_style="success"), + "reset": widgets.Button(description="重置"), + } + + def refresh_views() -> None: + status_html.value = ( + "
" + f"
准备: {_fmt_status(state.stage_status['prepare'])}
" + f"
生成安全包: {_fmt_status(state.stage_status['encaps'])}
" + f"
查看安全包: {_fmt_status(state.stage_status['encrypt'])}
" + f"
查看证明: {_fmt_status(state.stage_status['sign'])}
" + f"
解包验证: {_fmt_status(state.stage_status['verify'])}
" + f"
篡改测试: {_fmt_status(state.stage_status['decaps'])}
" + f"
恢复目录: {_fmt_status(state.stage_status['decrypt'])}
" + "
" + ) + transcript.value = "\n".join(state.transcript) + artifact_box.value = _format_artifacts(state, source_dir_input.value) + + def save_snapshot() -> None: + stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + state.save(OUTPUTS_DIR / f"trustflow_snapshot_{stamp}.json") + + def sync_choices() -> None: + state.sensitive_text = sensitive_input.value + state.kem_choice = kem_choice.value + state.sig_choice = sig_choice.value + state.batch_size = int(batch_size.value) + state.mode = mode_choice.value + + def current_pack_dir() -> str: + return last_pack_dir or state.artifacts.get("pack_dir", "") + + def run_stage(stage: str, message: str, fn: Callable[[], None]) -> None: + sync_choices() + state.set_stage(stage, "busy") + state.add_event(message) + refresh_views() + t0 = time.perf_counter() + try: + fn() + state.set_stage(stage, "done") + except Exception as exc: + state.set_stage(stage, "fail") + state.add_event(f"{stage} 失败: {exc!r}") + finally: + state.set_timing(stage, (time.perf_counter() - t0) * 1000.0) + refresh_views() + save_snapshot() + + def do_prepare(_=None): + sync_choices() + state.set_artifact("input_folder", source_dir_input.value) + state.add_event("已记录输入文件夹和算法选择") + state.set_stage("prepare", "done") + refresh_views() + save_snapshot() + + def do_encaps(_=None): + nonlocal last_pack_dir + def inner(): + result = create_secure_pack( + source_dir_input.value, + state.kem_choice, + state.sig_choice, + state.batch_size, + state.mode, + run_rocm=True, + ) + last_pack_dir = result.pack_dir + state.verified = result.verified + state.set_artifact("pack_dir", result.pack_dir) + state.set_artifact("pack_zip", result.pack_zip) + state.set_artifact("manifest", result.manifest_path) + state.set_artifact("unpack_dir", result.unpack_dir) + state.set_artifact("kem_ciphertext", result.kem_ciphertext) + state.set_artifact("kem_shared_key_sha256", result.kem_shared_key) + state.set_artifact("package_authenticator", result.signature) + 
state.set_artifact("file_count", str(result.file_count)) + state.set_artifact("total_bytes", str(result.total_bytes)) + state.set_artifact("rocm_logs", json.dumps(result.rocm_logs, ensure_ascii=False)) + state.set_artifact("rocm_summary", json.dumps(summarize_rocm_logs(result.rocm_logs), ensure_ascii=False, indent=2)) + state.set_artifact("notes", json.dumps(result.notes, ensure_ascii=False)) + for key, value in result.timings_ms.items(): + state.set_timing(key, value) + run_stage("encaps", f"正在生成安全包: {kem_choice.value} + {sig_choice.value}", inner) + + def do_encrypt(_=None): + def inner(): + pack_dir = current_pack_dir() + if not pack_dir: + raise RuntimeError("create a pack first") + state.set_artifact("encrypted_pack", pack_dir) + state.add_event("密文文件位于安全包的 encrypted_files 目录") + run_stage("encrypt", "正在显示安全包位置", inner) + + def do_sign(_=None): + def inner(): + if "package_authenticator" not in state.artifacts: + raise RuntimeError("create a pack first") + state.add_event(f"{sig_choice.value} 签名/包认证证明已生成") + run_stage("sign", f"正在显示 {sig_choice.value} 证明信息", inner) + + def do_verify(_=None): + def inner(): + pack_dir = current_pack_dir() + if not pack_dir: + raise RuntimeError("create a pack first") + result = unpack_secure_pack(pack_dir) + state.verified = bool(result["verified"]) + state.set_artifact("verify", "PASS" if result["verified"] else "FAIL") + state.set_artifact("restored_dir", result["out_dir"]) + state.set_artifact("verify_detail", json.dumps(result, ensure_ascii=False)) + for key, value in result["timings_ms"].items(): + state.set_timing(key, value) + run_stage("verify", f"正在解包、验签并恢复文件: {sig_choice.value}", inner) + + def do_decaps(_=None): + def inner(): + pack_dir = current_pack_dir() + if not pack_dir: + raise RuntimeError("create a pack first") + result = create_tampered_copy_and_verify(pack_dir) + state.set_artifact("tampered_pack_dir", result["tampered_pack_dir"]) + state.set_artifact("tampered_file", result["tampered_file"]) + 
state.set_artifact("tamper_detected", "YES" if result["tamper_detected"] else "NO") + state.set_artifact("tamper_detail", json.dumps(result, ensure_ascii=False)) + state.verified = not result["tamper_detected"] + run_stage("decaps", "正在篡改一个密文文件并验证检测能力", inner) + + def do_decrypt(_=None): + def inner(): + restored = state.artifacts.get("restored_dir") or state.artifacts.get("unpack_dir") + if not restored: + raise RuntimeError("unpack a pack first") + state.set_artifact("restored_dir", restored) + state.add_event(f"恢复后的文件目录: {restored}") + run_stage("decrypt", "正在显示恢复目录", inner) + + def do_run_all(_=None): + do_prepare() + do_encaps() + do_verify() + state.add_event("完整流程已完成") + refresh_views() + save_snapshot() + + def do_reset(_=None): + nonlocal state, last_pack_dir + state = TrustFlowState() + last_pack_dir = "" + source_dir_input.value = str(sample_dir) + sensitive_input.value = "敏感数据传输样例" + kem_choice.value = "Kyber-768" + sig_choice.value = "ML-DSA-65" + batch_size.value = 1024 + mode_choice.value = "paper" + state.add_event("状态已重置") + refresh_views() + save_snapshot() + + buttons["prepare"].on_click(do_prepare) + buttons["encaps"].on_click(do_encaps) + buttons["encrypt"].on_click(do_encrypt) + buttons["sign"].on_click(do_sign) + buttons["verify"].on_click(do_verify) + buttons["decaps"].on_click(do_decaps) + buttons["decrypt"].on_click(do_decrypt) + buttons["run_all"].on_click(do_run_all) + buttons["reset"].on_click(do_reset) + + controls = widgets.VBox([ + widgets.HBox([kem_choice, sig_choice]), + widgets.HBox([batch_size, mode_choice]), + widgets.HBox(list(buttons.values()), layout=widgets.Layout(flex_flow="row wrap")), + ]) + panels = widgets.Tab(children=[widgets.VBox([status_html, transcript]), widgets.VBox([artifact_box])]) + panels.set_title(0, "流程") + panels.set_title(1, "结果与证据") + root = widgets.VBox([title, source_dir_input, sensitive_input, controls, panels]) + state.add_event("前端已加载") + refresh_views() + save_snapshot() + return root + + +def 
launch_app() -> None: + display(build_app()) diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/backends.py b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/backends.py new file mode 100644 index 000000000..61259b286 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/backends.py @@ -0,0 +1,814 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import os +import secrets +import shutil +import subprocess +import time +import zipfile +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +try: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM +except Exception: # pragma: no cover - depends on the Jupyter image + AESGCM = None + + +BASE_DIR = Path(__file__).resolve().parent +PROJECT_ROOT = BASE_DIR.parent +OUTPUTS_DIR = BASE_DIR / "outputs" +LOGS_DIR = BASE_DIR / "logs" +SAMPLES_DIR = BASE_DIR / "sample_docs" +PACKS_DIR = OUTPUTS_DIR / "packs" +UNPACKS_DIR = OUTPUTS_DIR / "unpacked" + + +def _digest(label: str, text: str, extra: str = "") -> str: + payload = f"{label}|{text}|{extra}".encode("utf-8", errors="ignore") + return hashlib.sha256(payload).hexdigest() + + +@dataclass +class FlowArtifacts: + kem_shared_key: str + kem_ciphertext: str + sym_ciphertext: str + signature: str + verified: bool + decrypted_text: str + + +@dataclass +class RealFlowResult: + pack_dir: str + pack_zip: str + unpack_dir: str + manifest_path: str + plaintext_dir: str + verified: bool + tamper_detected: bool + kem_shared_key: str + kem_ciphertext: str + signature: str + file_count: int + total_bytes: int + timings_ms: dict[str, float] + rocm_logs: dict[str, str] + notes: list[str] + + +def run_mock_trustflow(text: str, kem_choice: str, sig_choice: str, batch_size: int, mode: str) -> FlowArtifacts: + kem_key = 
_digest("kem-key", text, f"{kem_choice}|{batch_size}|{mode}")[:64] + kem_ct = _digest("kem-ct", text, kem_key)[:96] + sym_ct = _digest("sym-ct", text, kem_ct)[:96] + sig = _digest("sig", text, f"{sig_choice}|{batch_size}|{mode}")[:96] + verified = bool(text) and sig[:2] != "00" + decrypted = text if verified else "" + return FlowArtifacts(kem_key, kem_ct, sym_ct, sig, verified, decrypted) + + +def ensure_demo_dirs() -> None: + for path in (OUTPUTS_DIR, LOGS_DIR, SAMPLES_DIR, PACKS_DIR, UNPACKS_DIR): + path.mkdir(parents=True, exist_ok=True) + + +def ensure_sample_docs() -> Path: + ensure_demo_dirs() + if any(SAMPLES_DIR.iterdir()): + return SAMPLES_DIR + (SAMPLES_DIR / "medical_report.txt").write_text( + "Patient: demo-001\nStudy: MRI follow-up\nFinding: no acute abnormality.\n", + encoding="utf-8", + ) + (SAMPLES_DIR / "lab_panel.csv").write_text( + "item,value,unit\nWBC,6.1,10^9/L\nHb,132,g/L\nCRP,2.3,mg/L\n", + encoding="utf-8", + ) + (SAMPLES_DIR / "risk_features.json").write_text( + json.dumps( + { + "scenario": "financial-risk-demo", + "features": {"txn_count_7d": 42, "risk_score": 0.18, "region": "demo"}, + }, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + return SAMPLES_DIR + + +def _slug(value: str) -> str: + return "".join(ch.lower() if ch.isalnum() else "_" for ch in value).strip("_") + + +def _sha256_file(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + + +def _iter_input_files(input_dir: Path) -> list[Path]: + files = [p for p in sorted(input_dir.rglob("*")) if p.is_file()] + ignored_roots = {PACKS_DIR.resolve(), UNPACKS_DIR.resolve(), LOGS_DIR.resolve(), OUTPUTS_DIR.resolve()} + ignored_dir_names = {".ipynb_checkpoints", "__pycache__"} + ignored_suffixes = {".pyc", ".pyo", ".tmp", ".bak"} + result: list[Path] = [] + for p in files: + resolved = p.resolve() + if any(str(resolved).startswith(str(root)) for root in 
ignored_roots): + continue + rel_parts = set(p.relative_to(input_dir).parts) + if rel_parts & ignored_dir_names: + continue + if p.name.startswith(".") or p.suffix.lower() in ignored_suffixes: + continue + result.append(p) + return result + + +def build_manifest(input_dir: Path, kem_choice: str, sig_choice: str, batch_size: int, mode: str) -> dict[str, Any]: + files = [] + total = 0 + for path in _iter_input_files(input_dir): + rel = path.relative_to(input_dir).as_posix() + size = path.stat().st_size + total += size + files.append({"path": rel, "size": size, "sha256": _sha256_file(path)}) + return { + "version": 1, + "created_at": datetime.now().isoformat(timespec="seconds"), + "input_dir": str(input_dir), + "kem": kem_choice, + "sig_algorithm": sig_choice, + "batch_size": batch_size, + "mode": mode, + "file_count": len(files), + "total_bytes": total, + "files": files, + } + + +def _keystream(key: bytes, nonce: bytes, length: int) -> bytes: + out = bytearray() + counter = 0 + while len(out) < length: + out.extend(hashlib.sha256(key + nonce + counter.to_bytes(8, "little")).digest()) + counter += 1 + return bytes(out[:length]) + + +def encrypt_bytes(plaintext: bytes, key: bytes) -> dict[str, str]: + if AESGCM is not None: + nonce = secrets.token_bytes(12) + ciphertext = AESGCM(key).encrypt(nonce, plaintext, None) + return { + "scheme": "AES-256-GCM", + "nonce": base64.b64encode(nonce).decode("ascii"), + "ciphertext": base64.b64encode(ciphertext).decode("ascii"), + } + nonce = secrets.token_bytes(16) + stream = _keystream(key, nonce, len(plaintext)) + ciphertext = bytes(a ^ b for a, b in zip(plaintext, stream)) + tag = hmac.new(key, nonce + ciphertext, hashlib.sha256).digest() + return { + "scheme": "SHA256-stream-HMAC-fallback", + "nonce": base64.b64encode(nonce).decode("ascii"), + "ciphertext": base64.b64encode(ciphertext).decode("ascii"), + "tag": base64.b64encode(tag).decode("ascii"), + } + + +def decrypt_bytes(record: dict[str, str], key: bytes) -> bytes: + nonce 
= base64.b64decode(record["nonce"]) + ciphertext = base64.b64decode(record["ciphertext"]) + if record.get("scheme") == "AES-256-GCM": + if AESGCM is None: + raise RuntimeError("AES-GCM package requires cryptography, but it is not installed") + return AESGCM(key).decrypt(nonce, ciphertext, None) + tag = base64.b64decode(record["tag"]) + expected = hmac.new(key, nonce + ciphertext, hashlib.sha256).digest() + if not hmac.compare_digest(tag, expected): + raise ValueError("ciphertext authentication failed") + stream = _keystream(key, nonce, len(ciphertext)) + return bytes(a ^ b for a, b in zip(ciphertext, stream)) + + +def _map_kem_exe(kem_choice: str) -> Path: + mapping = { + "Kyber-512": "kyber512_amd", + "Kyber-768": "kyber768_amd", + "Kyber-1024": "kyber1024_amd", + "Aigis-enc-1": "aigisenc1_amd", + "Aigis-enc-2": "aigisenc2_amd", + "Aigis-enc-3": "aigisenc3_amd", + "Aigis-enc-4": "aigisenc4_amd", + } + return PROJECT_ROOT / "kyberandaigis-enc" / mapping[kem_choice] + + +def _map_sig_exe(sig_choice: str) -> Path: + mapping = { + "ML-DSA-44": "mldsa44_amd", + "ML-DSA-65": "mldsa65_amd", + "ML-DSA-87": "mldsa87_amd", + "Aigis-sig1": "aigis1_amd", + "Aigis-sig2": "aigis2_amd", + "Aigis-sig3": "aigis3_amd", + } + return PROJECT_ROOT / "mldsaandaigis-sig" / mapping[sig_choice] + + +def _run_sig_cli(exe: Path, args: list[str], log_path: Path) -> dict[str, Any]: + return _run_command([str(exe), *args], exe.parent, log_path, timeout=120) + + +def _run_command(cmd: list[str], cwd: Path, log_path: Path, timeout: int = 120) -> dict[str, Any]: + env = os.environ.copy() + rocm_lib = "/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib" + env["LD_LIBRARY_PATH"] = rocm_lib + ":" + env.get("LD_LIBRARY_PATH", "") + t0 = time.perf_counter() + try: + proc = subprocess.run( + cmd, + cwd=str(cwd), + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=timeout, + check=False, + ) + output = proc.stdout + rc = proc.returncode + except Exception as 
exc: + output = f"{type(exc).__name__}: {exc}\n" + rc = -1 + elapsed = (time.perf_counter() - t0) * 1000.0 + log_path.parent.mkdir(parents=True, exist_ok=True) + log_path.write_text("$ " + " ".join(cmd) + "\n" + output, encoding="utf-8", errors="ignore") + return {"returncode": rc, "elapsed_ms": elapsed, "log": str(log_path), "output": output} + + +def run_rocm_proofs(kem_choice: str, sig_choice: str, batch_size: int, mode: str, stamp: str) -> tuple[dict[str, str], list[str]]: + logs: dict[str, str] = {} + notes: list[str] = [] + kem_exe = _map_kem_exe(kem_choice) + sig_exe = _map_sig_exe(sig_choice) + proof_batch = max(1, min(int(batch_size), 1024)) + + if kem_exe.exists(): + kem_log = LOGS_DIR / f"{stamp}_kem_{_slug(kem_choice)}.log" + res = _run_command( + [str(kem_exe), "--batch", str(proof_batch), "--n-ops", "1"], + kem_exe.parent, + kem_log, + ) + logs["kem_rocm"] = res["log"] + if res["returncode"] != 0: + notes.append(f"KEM ROCm proof returned {res['returncode']}; see {res['log']}") + else: + notes.append(f"KEM executable not found: {kem_exe}") + + if sig_exe.exists(): + sig_log = LOGS_DIR / f"{stamp}_sig_{_slug(sig_choice)}.log" + cmd = [str(sig_exe), "--batch", str(proof_batch), "--quiet", "--skip-keygen-oracle"] + if mode == "independent": + cmd.append("--bench-independent") + else: + cmd.append("--bench-paper") + res = _run_command(cmd, sig_exe.parent, sig_log) + logs["sig_rocm"] = res["log"] + if res["returncode"] != 0: + notes.append(f"SIG ROCm proof returned {res['returncode']}; see {res['log']}") + else: + notes.append(f"Signature executable not found: {sig_exe}") + + return logs, notes + + +def summarize_rocm_logs(logs: dict[str, str]) -> dict[str, Any]: + summary: dict[str, Any] = {} + for name, log_path in logs.items(): + path = Path(log_path) + item: dict[str, Any] = {"log": str(path), "exists": path.exists()} + if path.exists(): + text = path.read_text(encoding="utf-8", errors="ignore") + item["pass"] = "PASS" in text and "FAIL" not in text + 
item["has_fail"] = "FAIL" in text + item["returncode_hint"] = "error" in text.lower() or "not found" in text.lower() + lines = [line.strip() for line in text.splitlines() if line.strip()] + item["tail"] = lines[-8:] + summary[name] = item + return summary + + +def _write_json(path: Path, data: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + + +def _zip_dir(src_dir: Path, zip_path: Path) -> None: + if zip_path.exists(): + zip_path.unlink() + with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: + for path in sorted(src_dir.rglob("*")): + if path.is_file(): + zf.write(path, path.relative_to(src_dir).as_posix()) + + +def _package_mac(key: bytes, manifest: dict[str, Any]) -> str: + excluded = { + "package_authenticator", + "signature_backend", + "kem_backend", + "sig_public_key", + "sig_secret_key_demo", + "manifest_signature", + "sig_payload", + "sig_api_batch", + "sig_cli_verify_log", + "sig_api_verify_log", + } + signed = { + k: v + for k, v in manifest.items() + if k not in excluded + } + payload = json.dumps(signed, ensure_ascii=False, sort_keys=True).encode("utf-8") + return hmac.new(key, payload, hashlib.sha256).hexdigest() + + +def _signature_payload(manifest: dict[str, Any]) -> dict[str, Any]: + excluded = { + "signature_backend", + "sig_public_key", + "sig_secret_key_demo", + "manifest_signature", + "sig_payload", + "sig_api_batch", + "sig_cli_verify_log", + "sig_api_verify_log", + } + return {k: v for k, v in manifest.items() if k not in excluded} + + +def _write_signature_payload(path: Path, manifest: dict[str, Any]) -> None: + _write_json(path, _signature_payload(manifest)) + + +def _derive_aes_key(shared_secret: bytes) -> bytes: + return hashlib.sha256(shared_secret).digest() + + +def _try_create_kem_api_session( + kem_choice: str, + batch_size: int, + pack_dir: Path, + stamp: str, +) -> tuple[bytes | None, dict[str, Any], dict[str, 
str], list[str]]: + kem_exe = _map_kem_exe(kem_choice) + logs: dict[str, str] = {} + notes: list[str] = [] + manifest_fields: dict[str, Any] = {} + if not kem_exe.exists(): + return None, manifest_fields, logs, [f"KEM executable not found for API: {kem_exe}"] + + api_batch = max(1, min(int(batch_size), 1024)) + kem_dir = pack_dir / "kem" + kem_dir.mkdir(parents=True, exist_ok=True) + pk_path = kem_dir / "kem_pk.bin" + sk_path = kem_dir / "receiver_sk.demo_secret" + ct_path = kem_dir / "kem_ct.bin" + ss_sender_path = kem_dir / "ss_sender.demo_secret" + + keygen_log = LOGS_DIR / f"{stamp}_kemapi_keygen_{_slug(kem_choice)}.log" + encaps_log = LOGS_DIR / f"{stamp}_kemapi_encaps_{_slug(kem_choice)}.log" + kg = _run_command( + [str(kem_exe), "--api-kem-keygen", "--batch", str(api_batch), "--pk-out", str(pk_path), "--sk-out", str(sk_path)], + kem_exe.parent, + keygen_log, + ) + logs["kem_api_keygen"] = str(keygen_log) + if kg["returncode"] != 0: + notes.append(f"KEM API keygen not active; rc={kg['returncode']}; see {keygen_log}") + return None, manifest_fields, logs, notes + + enc = _run_command( + [ + str(kem_exe), + "--api-kem-encaps", + "--batch", + str(api_batch), + "--pk-in", + str(pk_path), + "--ct-out", + str(ct_path), + "--ss-out", + str(ss_sender_path), + ], + kem_exe.parent, + encaps_log, + ) + logs["kem_api_encaps"] = str(encaps_log) + if enc["returncode"] != 0 or not ss_sender_path.exists() or not ct_path.exists(): + notes.append(f"KEM API encaps not active; rc={enc['returncode']}; see {encaps_log}") + return None, manifest_fields, logs, notes + + shared_secret = ss_sender_path.read_bytes() + manifest_fields.update( + { + "kem_backend": "ROCm KEM batch file API", + "kem_api_batch": api_batch, + "kem_public_key": "kem/kem_pk.bin", + "kem_ciphertext_file": "kem/kem_ct.bin", + "kem_receiver_secret_demo": "kem/receiver_sk.demo_secret", + "kem_ciphertext": _sha256_file(ct_path), + } + ) + return _derive_aes_key(shared_secret), manifest_fields, logs, notes + + +def 
_pack_path(pack_dir: Path, value: str | None) -> Path | None: + if not value: + return None + path = Path(value) + return path if path.is_absolute() else pack_dir / path + + +def _recover_kem_api_session(manifest: dict[str, Any], pack_dir: Path) -> tuple[bytes | None, bool, str]: + kem_choice = manifest.get("kem", "Kyber-768") + kem_exe = _map_kem_exe(kem_choice) + sk_path = _pack_path(pack_dir, manifest.get("kem_receiver_secret_demo")) + ct_path = _pack_path(pack_dir, manifest.get("kem_ciphertext_file")) + if not kem_exe.exists() or not sk_path or not ct_path: + return None, False, "" + stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + ss_receiver_path = pack_dir / "kem" / f"ss_receiver_{stamp}.demo_secret" + verify_log = LOGS_DIR / f"{stamp}_kemapi_decaps_{_slug(kem_choice)}.log" + res = _run_command( + [ + str(kem_exe), + "--api-kem-decaps", + "--batch", + str(max(1, int(manifest.get("kem_api_batch", 128)))), + "--sk-in", + str(sk_path), + "--ct-in", + str(ct_path), + "--ss-out", + str(ss_receiver_path), + ], + kem_exe.parent, + verify_log, + ) + if res["returncode"] != 0 or not ss_receiver_path.exists(): + return None, False, str(verify_log) + return _derive_aes_key(ss_receiver_path.read_bytes()), True, str(verify_log) + + +def _try_create_sig_api_signature( + manifest: dict[str, Any], + sig_choice: str, + batch_size: int, + pack_dir: Path, + stamp: str, +) -> tuple[dict[str, Any], dict[str, str], list[str]]: + sig_exe = _map_sig_exe(sig_choice) + logs: dict[str, str] = {} + notes: list[str] = [] + fields: dict[str, Any] = {} + if not sig_exe.exists(): + return fields, logs, [f"Signature executable not found for API: {sig_exe}"] + + api_batch = max(1, min(int(batch_size), 1024)) + sig_dir = pack_dir / "sig" + sig_dir.mkdir(parents=True, exist_ok=True) + payload_path = sig_dir / "manifest.payload.json" + pk_path = sig_dir / "sig_pk.bin" + sk_path = sig_dir / "sig_sk.demo_secret" + sig_path = sig_dir / "manifest.sig" + sign_log = LOGS_DIR / 
f"{stamp}_sigapi_sign_{_slug(sig_choice)}.log" + verify_log = LOGS_DIR / f"{stamp}_sigapi_verify_{_slug(sig_choice)}.log" + + _write_signature_payload(payload_path, manifest) + sign = _run_command( + [ + str(sig_exe), + "--api-sig-sign", + "--batch", + str(api_batch), + "--msg-in", + str(payload_path), + "--pk-out", + str(pk_path), + "--sk-out", + str(sk_path), + "--sig-out", + str(sig_path), + ], + sig_exe.parent, + sign_log, + timeout=240, + ) + logs["sig_api_sign"] = str(sign_log) + if sign["returncode"] != 0 or not pk_path.exists() or not sig_path.exists(): + notes.append(f"SIG API sign not active; rc={sign['returncode']}; see {sign_log}") + return fields, logs, notes + + verify = _run_command( + [ + str(sig_exe), + "--api-sig-verify", + "--batch", + str(api_batch), + "--msg-in", + str(payload_path), + "--pk-in", + str(pk_path), + "--sig-in", + str(sig_path), + ], + sig_exe.parent, + verify_log, + timeout=240, + ) + logs["sig_api_verify"] = str(verify_log) + if verify["returncode"] != 0: + notes.append(f"SIG API verify not active; rc={verify['returncode']}; see {verify_log}") + return fields, logs, notes + + fields.update( + { + "signature_backend": "ROCm ML-DSA/Aigis-sig batch file API", + "sig_api_batch": api_batch, + "sig_payload": "sig/manifest.payload.json", + "sig_public_key": "sig/sig_pk.bin", + "sig_secret_key_demo": "sig/sig_sk.demo_secret", + "manifest_signature": "sig/manifest.sig", + } + ) + return fields, logs, notes + + +def _verify_sig_api_signature(manifest: dict[str, Any], pack_dir: Path) -> tuple[bool | None, str]: + if manifest.get("signature_backend") != "ROCm ML-DSA/Aigis-sig batch file API": + return None, "" + sig_alg = manifest.get("sig_algorithm", "ML-DSA-65") + sig_exe = _map_sig_exe(sig_alg) + pk_path = _pack_path(pack_dir, manifest.get("sig_public_key")) + sig_path = _pack_path(pack_dir, manifest.get("manifest_signature")) + if not sig_exe.exists() or not pk_path or not sig_path: + return False, "" + stamp = 
datetime.now().strftime("%Y%m%d_%H%M%S") + payload_path = pack_dir / "sig" / f"manifest.payload.verify_{stamp}.json" + verify_log = LOGS_DIR / f"{stamp}_sigapi_unpack_verify_{_slug(sig_alg)}.log" + _write_signature_payload(payload_path, manifest) + res = _run_command( + [ + str(sig_exe), + "--api-sig-verify", + "--batch", + str(max(1, int(manifest.get("sig_api_batch", 128)))), + "--msg-in", + str(payload_path), + "--pk-in", + str(pk_path), + "--sig-in", + str(sig_path), + ], + sig_exe.parent, + verify_log, + timeout=240, + ) + return res["returncode"] == 0, str(verify_log) + + +def create_secure_pack( + input_dir: str | Path | None, + kem_choice: str, + sig_choice: str, + batch_size: int, + mode: str, + run_rocm: bool = True, +) -> RealFlowResult: + ensure_demo_dirs() + source_dir = Path(input_dir).expanduser().resolve() if input_dir else ensure_sample_docs().resolve() + if not source_dir.exists() or not source_dir.is_dir(): + raise FileNotFoundError(f"input directory not found: {source_dir}") + manifest = build_manifest(source_dir, kem_choice, sig_choice, batch_size, mode) + if manifest["file_count"] <= 0: + raise ValueError(f"input directory has no files: {source_dir}") + + stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pack_dir = PACKS_DIR / f"pack_{stamp}_{_slug(kem_choice)[:10]}_{_slug(sig_choice)[:10]}" + enc_dir = pack_dir / "encrypted_files" + enc_dir.mkdir(parents=True, exist_ok=True) + + t0 = time.perf_counter() + session_key = secrets.token_bytes(32) + kem_ciphertext = hashlib.sha256(session_key + kem_choice.encode("utf-8")).hexdigest() + kem_api_logs: dict[str, str] = {} + kem_api_notes: list[str] = [] + kem_manifest_fields: dict[str, Any] = {} + if run_rocm: + kem_session_key, kem_manifest_fields, kem_api_logs, kem_api_notes = _try_create_kem_api_session( + kem_choice, batch_size, pack_dir, stamp + ) + if kem_session_key is not None: + session_key = kem_session_key + kem_ciphertext = kem_manifest_fields.get("kem_ciphertext", kem_ciphertext) + else: 
+ kem_api_notes.append("KEM API unavailable; using demo session key capsule fallback") + timings = {"prepare_ms": (time.perf_counter() - t0) * 1000.0} + + t1 = time.perf_counter() + encrypted_records = [] + for entry in manifest["files"]: + src = source_dir / entry["path"] + record = encrypt_bytes(src.read_bytes(), session_key) + out_name = hashlib.sha256(entry["path"].encode("utf-8")).hexdigest()[:20] + ".json" + out_path = enc_dir / out_name + _write_json(out_path, record) + encrypted_records.append({"path": entry["path"], "encrypted": f"encrypted_files/{out_name}"}) + timings["encrypt_ms"] = (time.perf_counter() - t1) * 1000.0 + + manifest["encrypted_files"] = encrypted_records + manifest["kem_ciphertext"] = kem_ciphertext + manifest["signature_backend"] = "ROCm batch/decomp proof + package authenticator" + manifest["kem_backend"] = "ROCm CLI proof + SHA-256 package key capsule" + manifest.update(kem_manifest_fields) + (pack_dir / "session_key.demo_secret").write_text( + base64.b64encode(session_key).decode("ascii"), + encoding="ascii", + ) + + manifest_path = pack_dir / "manifest.json" + signature = _package_mac(session_key, manifest) + manifest["package_authenticator"] = signature + _write_json(manifest_path, manifest) + + sig_api_logs: dict[str, str] = {} + sig_api_notes: list[str] = [] + if run_rocm: + sig_fields, sig_api_logs, sig_api_notes = _try_create_sig_api_signature( + manifest, sig_choice, batch_size, pack_dir, stamp + ) + if sig_fields: + manifest.update(sig_fields) + signature = _package_mac(session_key, manifest) + manifest["package_authenticator"] = signature + _write_json(manifest_path, manifest) + payload = _pack_path(pack_dir, manifest.get("sig_payload")) + if payload: + _write_signature_payload(payload, manifest) + else: + sig_api_notes.append("SIG API unavailable; using package authenticator fallback") + + rocm_logs: dict[str, str] = {**kem_api_logs, **sig_api_logs} + notes: list[str] = [*kem_api_notes, *sig_api_notes] + if run_rocm: + t2 = 
time.perf_counter() + rocm_logs, proof_notes = run_rocm_proofs(kem_choice, sig_choice, batch_size, mode, stamp) + rocm_logs = {**kem_api_logs, **sig_api_logs, **rocm_logs} + notes.extend(proof_notes) + timings["rocm_proof_ms"] = (time.perf_counter() - t2) * 1000.0 + + zip_path = pack_dir.with_suffix(".pqcpack.zip") + _zip_dir(pack_dir, zip_path) + + unpack_dir = UNPACKS_DIR / f"unpack_{stamp}" + verify = unpack_secure_pack(pack_dir, unpack_dir) + return RealFlowResult( + pack_dir=str(pack_dir), + pack_zip=str(zip_path), + unpack_dir=str(unpack_dir), + manifest_path=str(manifest_path), + plaintext_dir=str(source_dir), + verified=verify["verified"], + tamper_detected=False, + kem_shared_key=hashlib.sha256(session_key).hexdigest(), + kem_ciphertext=kem_ciphertext, + signature=signature, + file_count=manifest["file_count"], + total_bytes=manifest["total_bytes"], + timings_ms={**timings, **verify["timings_ms"]}, + rocm_logs=rocm_logs, + notes=notes, + ) + + +def unpack_secure_pack(pack_dir: str | Path, out_dir: str | Path | None = None) -> dict[str, Any]: + pack_dir = Path(pack_dir).resolve() + out_dir = Path(out_dir).resolve() if out_dir else (UNPACKS_DIR / f"unpack_{pack_dir.name}") + t0 = time.perf_counter() + manifest = json.loads((pack_dir / "manifest.json").read_text(encoding="utf-8")) + kem_ok = True + kem_log = "" + if manifest.get("kem_backend") == "ROCm KEM batch file API": + recovered_key, kem_ok, kem_log = _recover_kem_api_session(manifest, pack_dir) + if recovered_key is not None: + key = recovered_key + else: + key = base64.b64decode((pack_dir / "session_key.demo_secret").read_text(encoding="ascii")) + else: + key = base64.b64decode((pack_dir / "session_key.demo_secret").read_text(encoding="ascii")) + signature = manifest.get("package_authenticator", "") + sig_ok = hmac.compare_digest(signature, _package_mac(key, manifest)) + sig_cli_ok = None + sig_cli_log = "" + sig_api_ok, sig_api_log = _verify_sig_api_signature(manifest, pack_dir) + if 
manifest.get("signature_backend") == "ROCm ML-DSA/Aigis-sig CLI": + sig_alg = manifest.get("sig_algorithm", "ML-DSA-65") + sig_exe = _map_sig_exe(sig_alg) + sig_pk = manifest.get("sig_public_key") + sig_file = manifest.get("manifest_signature") + if sig_exe.exists() and sig_pk and sig_file: + stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + verify_log = LOGS_DIR / f"{stamp}_unpack_sigcli_verify_{_slug(sig_alg)}.log" + res = _run_sig_cli( + sig_exe, + ["--cli-verify", "--pk-in", sig_pk, "--msg-in", str(pack_dir / "manifest.json"), "--sig-in", sig_file], + verify_log, + ) + sig_cli_ok = res["returncode"] == 0 + sig_cli_log = str(verify_log) + else: + sig_cli_ok = False + + if out_dir.exists(): + shutil.rmtree(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + restored = 0 + total = 0 + file_errors = [] + enc_map = {item["path"]: item["encrypted"] for item in manifest.get("encrypted_files", [])} + for entry in manifest["files"]: + rel = entry["path"] + try: + enc_record = json.loads((pack_dir / enc_map[rel]).read_text(encoding="utf-8")) + plaintext = decrypt_bytes(enc_record, key) + out_path = out_dir / rel + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_bytes(plaintext) + actual = hashlib.sha256(plaintext).hexdigest() + if actual != entry["sha256"]: + file_errors.append(f"{rel}: sha256 mismatch") + restored += 1 + total += len(plaintext) + except Exception as exc: + file_errors.append(f"{rel}: {exc}") + + return { + "verified": kem_ok and sig_ok and (sig_cli_ok is not False) and (sig_api_ok is not False) and not file_errors, + "kem_ok": kem_ok, + "kem_log": kem_log, + "signature_ok": sig_ok, + "sig_cli_ok": sig_cli_ok, + "sig_cli_log": sig_cli_log, + "sig_api_ok": sig_api_ok, + "sig_api_log": sig_api_log, + "file_errors": file_errors, + "restored_files": restored, + "restored_bytes": total, + "out_dir": str(out_dir), + "timings_ms": {"verify_decrypt_ms": (time.perf_counter() - t0) * 1000.0}, + } + + +def tamper_pack(pack_dir: str | Path) 
-> str: + pack_dir = Path(pack_dir).resolve() + candidates = sorted((pack_dir / "encrypted_files").glob("*.json")) + if not candidates: + raise FileNotFoundError("no encrypted file to tamper") + target = candidates[0] + data = json.loads(target.read_text(encoding="utf-8")) + raw = bytearray(base64.b64decode(data["ciphertext"])) + if raw: + raw[0] ^= 1 + data["ciphertext"] = base64.b64encode(bytes(raw)).decode("ascii") + target.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + return str(target) + + +def create_tampered_copy_and_verify(pack_dir: str | Path) -> dict[str, Any]: + pack_dir = Path(pack_dir).resolve() + tampered = pack_dir.with_name(pack_dir.name + "_tampered") + if tampered.exists(): + shutil.rmtree(tampered) + shutil.copytree(pack_dir, tampered) + tampered_file = tamper_pack(tampered) + result = unpack_secure_pack(tampered, UNPACKS_DIR / f"tampered_{pack_dir.name}") + result["tampered_pack_dir"] = str(tampered) + result["tampered_file"] = tampered_file + result["tamper_detected"] = not result["verified"] + return result diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/Untitled.ipynb b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/Untitled.ipynb new file mode 100644 index 000000000..d38c55462 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/Untitled.ipynb @@ -0,0 +1,33 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "48823f79-4617-4464-b01f-3f50ab75a2ce", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/lab_panel.csv b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/lab_panel.csv new file mode 100644 index 000000000..d873431f2 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/lab_panel.csv @@ -0,0 +1,4 @@ +item,value,unit +WBC,6.1,10^9/L +Hb,132,g/L +CRP,2.3,mg/L diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/medical_report.txt b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/medical_report.txt new file mode 100644 index 000000000..9f1cfb223 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/medical_report.txt @@ -0,0 +1,3 @@ +Patient: demo-001 +Study: MRI follow-up +Finding: no acute abnormality. 
diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/risk_features.json b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/risk_features.json new file mode 100644 index 000000000..35852fd4e --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/sample_docs/risk_features.json @@ -0,0 +1,8 @@ +{ + "scenario": "financial-risk-demo", + "features": { + "txn_count_7d": 42, + "risk_score": 0.18, + "region": "demo" + } +} \ No newline at end of file diff --git a/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/state.py b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/state.py new file mode 100644 index 000000000..926cbb589 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/01_unsupported_feature_rocm_pqc_api/trustflow_frontend/state.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, asdict +from pathlib import Path +from typing import Any +import json +import time + + +@dataclass +class TrustFlowState: + sensitive_text: str = "" + kem_choice: str = "Kyber-768" + sig_choice: str = "ML-DSA-65" + batch_size: int = 1024 + mode: str = "paper" + stage_status: dict[str, str] = field(default_factory=lambda: { + "prepare": "idle", + "encaps": "idle", + "encrypt": "idle", + "sign": "idle", + "verify": "idle", + "decaps": "idle", + "decrypt": "idle", + }) + artifacts: dict[str, str] = field(default_factory=dict) + timings_ms: dict[str, float] = field(default_factory=dict) + transcript: list[str] = field(default_factory=list) + verified: bool = False + + def add_event(self, message: str) -> None: + stamp = time.strftime("%H:%M:%S") + self.transcript.append(f"[{stamp}] {message}") + + def set_stage(self, stage: str, status: str) -> None: + self.stage_status[stage] = status + + def set_artifact(self, 
name: str, value: str) -> None: + self.artifacts[name] = value + + def set_timing(self, name: str, value: float) -> None: + self.timings_ms[name] = value + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + def save(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(self.to_dict(), ensure_ascii=False, indent=2), encoding="utf-8") diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/README.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/README.md new file mode 100644 index 000000000..f5c016378 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/README.md @@ -0,0 +1,50 @@ +# Performance Bottleneck Localization And ROCm Optimization + +This folder is prepared for the competition item: + +```text +(2) Performance bottleneck localization and optimization +``` + +## What This Contains + +- `kem_optimization/`: KEM profiling and tuning scripts for operation-specific TPB, launch bounds, device buffer reuse, and ROCm trace/resource analysis. +- `sig_optimization/`: ML-DSA/Aigis-sig profiling and feature-matrix tooling for resource-aware signing on ROCm. +- `evidence/`: small result summaries and decision notes suitable for PR description, report tables, and defense slides. + +## Main Optimization Evidence + +KEM: + +- Kyber-768 keygen profiling shows the sample/XOF/rejection-sampling stage is the dominant bottleneck, about 70% of the pipeline time in the recorded run. +- Kyber-768 encaps improved from about 6.00M ops/s to about 7.10M ops/s after launch-bound tuning. +- Device buffer reuse improved continuous full-KEM throughput from about 1.93M to 2.05M instances/s in the recorded run. + +Signature: + +- The stable ROCm path uses a resource-aware decomp pipeline because monolithic/cached-style signing creates private-segment and scratch-pressure risk. 
+- Feature-matrix candidates include `adaptive`, `check8`, `check16`, `wave64_ctrl`, `cp_fuse`, `tail16_base`, `tail16_cp_fuse`, and `yhat_dup`. +- Local wins exist, for example ML-DSA-65 independent batch=16384 with `cp_fuse` reached 1.3625x speedup, and Aigis-sig3 independent batch=16384 with `wave64_ctrl` reached 1.4200x. +- Conservative selected builds keep the base decomp pipeline when no candidate satisfies the no-regression rule across all measured cells. + +## Reproduction Entry Points + +```bash +cd kem_optimization +bash run_kem_tune_amd.sh kyber768 +bash run_kem_confirm_amd.sh kyber768 +bash run_kem_final_report_amd.sh +bash run_kem_resource_profile_amd.sh kyber768 32768 200 +``` + +```bash +cd sig_optimization +bash amd_tools/run_sig_policy_smoke.sh 128 +bash amd_tools/run_sig_amd_feature_matrix.sh +python3 amd_tools/write_optimization_claims.py +bash amd_tools/run_sig_large_sweep.sh +``` + +## Why It Fits The Scoring Item + +This folder demonstrates a complete optimization loop: workload profiling, bottleneck attribution, candidate implementation, repeated measurement, conservative promotion decisions, quantified speedups, and stability/maintainability discussion. 
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kem_final_extract.txt b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kem_final_extract.txt new file mode 100644 index 000000000..381e03d6a --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kem_final_extract.txt @@ -0,0 +1,28 @@ +Algorithm: Kyber-512 K=2 Q=3329 + Keygen: 3.2 ms/batch → 10095164 ops/sec + Encaps: 2.9 ms/batch → 11368410 ops/sec + Decaps: 3.9 ms/batch → 8451971 ops/sec +Algorithm: Kyber-768 K=3 Q=3329 + Keygen: 5.2 ms/batch → 6276945 ops/sec + Encaps: 4.6 ms/batch → 7142451 ops/sec + Decaps: 5.8 ms/batch → 5651891 ops/sec +Algorithm: Kyber-1024 K=4 Q=3329 + Keygen: 7.4 ms/batch → 4447101 ops/sec + Encaps: 6.7 ms/batch → 4916932 ops/sec + Decaps: 8.6 ms/batch → 3829267 ops/sec +Algorithm: Aigis-enc-1 K=2 Q=7681 + Keygen: 6.4 ms/batch → 10240547 ops/sec + Encaps: 8.0 ms/batch → 8204293 ops/sec + Decaps: 10.1 ms/batch → 6497139 ops/sec +Algorithm: Aigis-enc-2 K=3 Q=7681 + Keygen: 9.9 ms/batch → 6605967 ops/sec + Encaps: 11.6 ms/batch → 5630086 ops/sec + Decaps: 17.1 ms/batch → 3827435 ops/sec +Algorithm: Aigis-enc-3 K=3 Q=7681 + Keygen: 10.4 ms/batch → 6305602 ops/sec + Encaps: 12.7 ms/batch → 5144497 ops/sec + Decaps: 16.1 ms/batch → 4060120 ops/sec +Algorithm: Aigis-enc-4 K=4 Q=7681 + Keygen: 15.8 ms/batch → 4156445 ops/sec + Encaps: 19.0 ms/batch → 3444781 ops/sec + Decaps: 22.1 ms/batch → 2961419 ops/sec diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kyber_amd_first_run_bottleneck_analysis.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kyber_amd_first_run_bottleneck_analysis.md new file mode 100644 index 000000000..7b208ee93 --- /dev/null +++ 
b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/kyber_amd_first_run_bottleneck_analysis.md @@ -0,0 +1,1460 @@ +# Kyber/Aigis-enc AMD 首轮跑通与瓶颈分析 + +日期:2026-06-10 + +## 1. 当前阶段结论 + +Kyber-768 已经在 AMD JupyterLab / ROCm 环境中完成首轮跑通。 + +当前已确认: + +- `kyber768_amd` 可以通过 `hipcc` 编译生成。 +- 运行时使用 HIP Runtime。 +- 正确性测试通过,`KEM 正确性: PASS`。 +- 已获得 AMD 单卡上的首批 keygen / encaps / decaps 吞吐数据。 +- 已完成一次程序内 pipeline profile,定位到 keygen 的主要瓶颈。 + +这说明 Kyber/Aigis-enc 模块已经从“4090 上跑通”推进到“AMD ROCm 上可运行、可测量、可分析”的阶段。 + +## 2. AMD 环境与构建信息 + +AMD JupyterLab 中显示的设备信息: + +```text +GPU: AMD Radeon Graphics (gfx1100, 48 CUs, 51.5 GB VRAM) +Runtime: HIP +Algorithm: Kyber-768 K=3 Q=3329 +``` + +构建方式: + +```bash +bash build_hip.sh kyber768 +``` + +构建结果: + +```text +84 warnings generated when compiling for host. +HIP 构建完成 +``` + +说明: + +- 当前 warning 暂时不影响可执行文件生成。 +- 运行时需要设置 ROCm runtime 动态库路径,脚本中已内置: + +```bash +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" +``` + +## 3. 
正确性冒烟结果 + +命令: + +```bash +bash run_kem_smoke_amd.sh +cat amd_results/kem_smoke_summary.csv +``` + +结果摘要: + +| Algorithm | Batch | Keygen ops/s | Encaps ops/s | Decaps ops/s | Correctness | +|---|---:|---:|---:|---:|---| +| Kyber-768 | 1 | 724 | 578 | 471 | PASS | +| Kyber-768 | 8 | 5,695 | 4,572 | 3,712 | PASS | +| Kyber-768 | 32 | 27,020 | 22,823 | 19,908 | PASS | +| Kyber-768 | 128 | 101,527 | 89,225 | 78,159 | PASS | + +结论: + +- 小 batch 下正确性稳定。 +- 随 batch 增大,吞吐明显提升,说明该工作负载具备 GPU 批处理收益。 +- 小 batch 的吞吐较低,后续可作为 kernel launch 开销和任务粒度不足的分析案例。 + +## 3.1 全量 KEM 目标冒烟结果 + +后续已完成 7 个 KEM 目标的 AMD 小 batch 冒烟测试: + +```bash +bash build_hip.sh +bash run_kem_smoke_amd.sh +cat amd_results/kem_smoke_summary.csv +``` + +结果: + +| Target | K | Q | Batch=128 Keygen | Batch=128 Encaps | Batch=128 Decaps | Correctness | +|---|---:|---:|---:|---:|---:|---| +| Kyber-512 | 2 | 3329 | 156,364 ops/s | 142,702 ops/s | 122,041 ops/s | PASS | +| Kyber-768 | 3 | 3329 | 101,685 ops/s | 88,712 ops/s | 77,192 ops/s | PASS | +| Kyber-1024 | 4 | 3329 | 63,358 ops/s | 66,156 ops/s | 61,995 ops/s | PASS | +| Aigis-enc-1 | 2 | 7681 | 94,551 ops/s | 70,573 ops/s | 51,935 ops/s | PASS | +| Aigis-enc-2 | 3 | 7681 | 65,397 ops/s | 49,648 ops/s | 39,623 ops/s | PASS | +| Aigis-enc-3 | 3 | 7681 | 60,851 ops/s | 49,192 ops/s | 40,077 ops/s | PASS | +| Aigis-enc-4 | 4 | 7681 | 43,553 ops/s | 36,797 ops/s | 32,014 ops/s | PASS | + +阶段性结论: + +> Kyber-512/768/1024 与 Aigis-enc-1/2/3/4 均已在 AMD ROCm 上完成小 batch 正确性验证,说明 KEM 模块的 AMD 基础迁移已经闭环。后续工作重点应从“能否运行”转向“最佳 batch 吞吐、热点瓶颈定位与 ROCm 专项优化”。 + +可用于 PPT 的表述: + +> 本项目已完成 Kyber 与 Aigis-enc 共 7 个参数集在 AMD ROCm 平台上的编译、运行与正确性冒烟验证,为后续构建完整后量子科研数据可信流转平台提供了可复现的 KEM 基础能力。 + +## 4. 
当前最好性能 + +命令: + +```bash +cat amd_results/kem_best.csv +``` + +当前 Kyber-768 最好结果: + +| Operation | Best Throughput | Best Config | +|---|---:|---| +| Keygen | 4,521,348 ops/s | batch=32768, serial, streams=1 | +| Encaps | 5,932,625 ops/s | batch=32768, serial, streams=1 | +| Decaps | 5,509,231 ops/s | batch=32768, serial, streams=1 | + +与此前 4090 Kyber-768 最好结果对比: + +| Platform | Keygen | Encaps | Decaps | +|---|---:|---:|---:| +| RTX 4090D | 5.70M ops/s | 7.95M ops/s | 8.16M ops/s | +| AMD 初始 ROCm 版 | 4.52M ops/s | 5.93M ops/s | 5.51M ops/s | + +阶段性判断: + +> 初始 HIP/ROCm 迁移版已经达到 RTX 4090D 同量级吞吐。当前性能差距不能简单归因于 AMD 硬件不足,更合理的解释是:现有实现仍是迁移与初步适配版本,尚未充分围绕 ROCm/RDNA3 的执行模型进行采样、访存和并行粒度调优。 + +## 4.1 全量 KEM Sweep 最佳性能表 + +命令: + +```bash +bash run_kem_sweep_amd.sh +cat amd_results/kem_best.csv +``` + +AMD ROCm 全量 KEM 最佳结果: + +| Algorithm | Best Keygen | Keygen Config | Best Encaps | Encaps Config | Best Decaps | Decaps Config | +|---|---:|---|---:|---|---:|---| +| Kyber-512 | 6,553,707 ops/s | batch=32768 serial | 10,898,867 ops/s | batch=65536 serial | 7,200,522 ops/s | batch=32768 serial | +| Kyber-768 | 4,797,831 ops/s | batch=32768 serial | 5,893,981 ops/s | batch=32768 serial | 5,502,606 ops/s | batch=32768 serial | +| Kyber-1024 | 3,691,492 ops/s | batch=32768 serial | 4,207,232 ops/s | batch=32768 serial | 3,674,413 ops/s | batch=32768 serial | +| Aigis-enc-1 | 8,316,531 ops/s | batch=65536 serial | 7,092,336 ops/s | batch=65536 serial | 5,175,258 ops/s | batch=65536 serial | +| Aigis-enc-2 | 5,708,114 ops/s | batch=65536 serial | 4,656,565 ops/s | batch=65536 serial | 3,360,105 ops/s | batch=65536 serial | +| Aigis-enc-3 | 5,159,607 ops/s | batch=65536 serial | 4,589,059 ops/s | batch=65536 serial | 3,583,328 ops/s | batch=65536 serial | +| Aigis-enc-4 | 3,874,049 ops/s | batch=65536 serial | 2,943,683 ops/s | batch=65536 serial | 2,370,340 ops/s | batch=65536 serial | + +阶段性结论: + +- Kyber 系列最佳 batch 多集中在 `32768`,Kyber-512 的 encaps 在 `65536` 达到最高。 +- Aigis-enc 系列最佳 batch 
均集中在 `65536`,说明 Aigis-enc 在当前实现下更依赖大 batch 来摊薄 launch 和调度开销。 +- 当前 AMD 初始版已经具备百万级到千万级 KEM 吞吐,足以支撑“多文件科研数据可信流转”的高并发密钥封装场景。 +- 高安全等级参数集吞吐下降明显,后续可将 Kyber-1024 和 Aigis-enc-4 作为重点 profile 对象。 + +适合 PPT 的结论: + +> 在 AMD ROCm 初始适配版本中,Kyber-512 encaps 已达到 10.90M ops/s,Aigis-enc-1 keygen 达到 8.32M ops/s。全量 7 个 KEM 参数集均达到百万级以上吞吐,证明 AMD Radeon PRO/RDNA3 平台具备支撑高并发后量子密钥封装的工程潜力。 + +## 5. Pipeline Profile 结果 + +命令: + +```bash +bash profile_kem_one_amd.sh kyber768_amd 32768 3 +cat amd_results/profile/kyber768_amd_b32768_profile.log +``` + +profile 输出: + +```text +--- batch=32768 n_ops=3 mode=serial --- + Keygen: 7.3 ms/batch -> 4512026 ops/sec + Encaps: 5.5 ms/batch -> 5977275 ops/sec + Decaps: 6.0 ms/batch -> 5488131 ops/sec + +--- batch=32768 n_ops=3 mode=pipeline --- + Pipeline profile: sample=5.065 ntt=0.454 matvec=0.487 invntt=0.349 add=0.315 pack=0.569 total=7.238 ms + Keygen: 7.3 ms/batch -> 4517392 ops/sec + Encaps: 4.6 ms/batch -> 7149509 ops/sec + Decaps: 5.8 ms/batch -> 5632301 ops/sec +``` + +阶段占比估算: + +| Stage | Time | Ratio | +|---|---:|---:| +| sample | 5.065 ms | 70.0% | +| ntt | 0.454 ms | 6.3% | +| matvec | 0.487 ms | 6.7% | +| invntt | 0.349 ms | 4.8% | +| add | 0.315 ms | 4.4% | +| pack | 0.569 ms | 7.9% | +| total | 7.238 ms | 100% | + +核心瓶颈: + +> 当前 Kyber-768 keygen 的主要瓶颈是 sample 阶段,占 pipeline 总耗时约 70%。这说明在 AMD ROCm 初始实现中,主要限制来自 SHAKE/XOF 展开、拒绝采样和噪声采样相关路径,而不是 NTT 或矩阵向量乘。 + +## 5.1 多目标 Profile 对比 + +已完成三个代表目标的 profile: + +```bash +bash profile_kem_one_amd.sh kyber768_amd 32768 3 +bash profile_kem_one_amd.sh kyber1024_amd 32768 3 +bash profile_kem_one_amd.sh aigisenc4_amd 32768 3 +``` + +### Kyber-768 + +```text +Pipeline profile: sample=5.075 ntt=0.453 matvec=0.487 invntt=0.346 add=0.315 pack=0.572 total=7.247 ms +``` + +| Stage | Time | Ratio | +|---|---:|---:| +| sample | 5.075 ms | 70.0% | +| ntt | 0.453 ms | 6.3% | +| matvec | 0.487 ms | 6.7% | +| invntt | 0.346 ms | 4.8% | +| add | 0.315 ms | 4.3% | +| pack | 0.572 ms | 7.9% | + +### Kyber-1024 + +```text 
+Pipeline profile: sample=7.227 ntt=0.628 matvec=0.887 invntt=0.477 add=0.437 pack=0.930 total=10.584 ms +``` + +| Stage | Time | Ratio | +|---|---:|---:| +| sample | 7.227 ms | 68.3% | +| ntt | 0.628 ms | 5.9% | +| matvec | 0.887 ms | 8.4% | +| invntt | 0.477 ms | 4.5% | +| add | 0.437 ms | 4.1% | +| pack | 0.930 ms | 8.8% | + +### Aigis-enc-4 + +```text +Pipeline profile: sample=9.439 ntt=0.840 matvec=0.811 invntt=0.470 add=0.438 pack=0.781 total=12.780 ms +``` + +| Stage | Time | Ratio | +|---|---:|---:| +| sample | 9.439 ms | 73.9% | +| ntt | 0.840 ms | 6.6% | +| matvec | 0.811 ms | 6.3% | +| invntt | 0.470 ms | 3.7% | +| add | 0.438 ms | 3.4% | +| pack | 0.781 ms | 6.1% | + +综合判断: + +> Kyber-768、Kyber-1024 和 Aigis-enc-4 均表现出 sample 阶段主导的瓶颈特征,占 keygen pipeline 总耗时约 68% 到 74%。因此,下一阶段优化不应优先放在 NTT 或矩阵向量乘,而应优先围绕 SHAKE/XOF 展开、拒绝采样、噪声采样以及采样阶段的并行粒度进行 ROCm 专项调优。 + +另一个重要现象: + +| Target | Serial Keygen | Pipeline Keygen | 现象 | +|---|---:|---:|---| +| Kyber-768 | 4.20M ops/s | 4.52M ops/s | pipeline 略优 | +| Kyber-1024 | 3.68M ops/s | 3.16M ops/s | pipeline 反而变慢 | +| Aigis-enc-4 | 3.44M ops/s | 2.57M ops/s | pipeline 明显变慢 | + +说明: + +> 当前 pipeline keygen 不是所有参数集的最优路径。对于 Kyber-1024 和 Aigis-enc-4,大参数集下 pipeline 引入的中间缓冲、拆分 kernel 和打包路径开销超过了收益。因此短期 benchmark 应继续以 serial 路径作为 best-throughput 基线,同时将 pipeline 路径作为 profile 和优化实验对象。 + +## 6. 
当前暴露的问题 + +### 6.1 简单迁移能跑,但不代表跑满 + +当前结果说明 HIP 迁移已经完成基本功能,但 profile 显示 sample 阶段占比过高。后续需要针对 AMD/RDNA3 重新调整采样并行粒度、访存布局和 kernel 拆分方式。 + +适合论文表述: + +> HIP 迁移解决了代码在 ROCm 平台上的可运行性问题,但后量子密码算法中的采样和 XOF 扩展并不是 ROCm 现有 AI/HPC 库重点覆盖的算子类型。简单迁移不能充分发挥 AMD GPU 的潜力,需要结合 profile 数据进行面向 ROCm 的专门优化。 + +### 6.2 ROCm 缺少类 cuPQC 的后量子密码专用库 + +NVIDIA 已经有 cuPQC,面向 ML-KEM、ML-DSA 等后量子密码负载提供 GPU 库支持。AMD ROCm 目前公开生态中缺少同类 PQC 专用库。 + +本项目可作为 AMD 生态补齐方向: + +- 批量 ML-KEM/Kyber KEM kernel。 +- 批量 Aigis-enc KEM kernel。 +- 批量 ML-DSA/Aigis-sig 签名验签 kernel。 +- NTT、Keccak/SHAKE、采样、packing 等 PQC 基础算子。 +- 面向多文件科研数据可信流转的上层接口。 + +### 6.3 小 batch 端到端效率低 + +batch=1/8/32 时吞吐明显低于大 batch。这说明小任务场景下 GPU launch、同步和数据准备开销占比较高。 + +后续优化方向: + +- 多文件任务聚合。 +- 自适应 batch size。 +- 多 stream 并发。 +- 缓冲区复用。 +- CPU I/O 与 GPU 密码计算流水线化。 + +## 7. 下一步实验计划 + +### Step 1:验证 split sampling 是否改善 sample 瓶颈 + +目的: + +当前 sample 阶段占总耗时约 70%,优先测试已有的 `KEM_SPLIT_KEYGEN_SAMPLE=1` 路径。 + +在 AMD JupyterLab 中执行: + +```bash +cd /app/kyberandaigis-enc +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +hipcc -O2 -std=c++17 -x hip --offload-arch=gfx1100 \ + -DKEM_SERIAL_TPB=64 \ + -DKEM_SPLIT_KEYGEN_SAMPLE=1 \ + -DALGORITHM=1 -DPARAM_MODE=3 \ + main.cu -o kyber768_split_amd + +./kyber768_split_amd --batch 32768 --n-ops 3 --no-correctness --pipeline --profile-pipeline +./kyber768_split_amd --batch 128 --n-ops 1 +``` + +需要记录: + +- `sample` 时间是否下降。 +- `total` 时间是否下降。 +- `Keygen ops/s` 是否提升。 +- 正确性是否仍然 PASS。 + +如果 split 版本有效,可写成第一个优化点: + +> 通过拆分 seed expand、matrix sample、noise sample,提高采样阶段并行粒度,降低 sample 阶段耗时。 + +### Step 1 实验结果:split sampling 失败 + +实际执行: + +```bash +hipcc -O2 -std=c++17 -x hip --offload-arch=gfx1100 \ + -DKEM_SERIAL_TPB=64 \ + -DKEM_SPLIT_KEYGEN_SAMPLE=1 \ + -DALGORITHM=1 -DPARAM_MODE=3 \ + main.cu -o kyber768_split_amd + +./kyber768_split_amd --batch 32768 --n-ops 3 --no-correctness --pipeline --profile-pipeline +./kyber768_split_amd --batch 128 --n-ops 1 +``` + +结果: + +```text 
+--- batch=32768 n_ops=3 mode=serial --- + Keygen: 7.2 ms/batch -> 4556756 ops/sec + Encaps: 5.6 ms/batch -> 5883818 ops/sec + Decaps: 6.0 ms/batch -> 5502531 ops/sec + +--- batch=32768 n_ops=3 mode=pipeline --- + Pipeline profile: sample=16.648 ntt=0.476 matvec=0.523 invntt=0.346 add=0.318 pack=0.571 total=18.883 ms + Keygen: 18.8 ms/batch -> 1742529 ops/sec + Encaps: 4.6 ms/batch -> 7100841 ops/sec + Decaps: 5.8 ms/batch -> 5633798 ops/sec +``` + +正确性: + +```text +KEM 正确性: PASS +``` + +对比: + +| Version | sample | total | Pipeline Keygen | +|---|---:|---:|---:| +| 原 pipeline | 5.075 ms | 7.247 ms | 4.52M ops/s | +| split sample | 16.648 ms | 18.883 ms | 1.74M ops/s | + +结论: + +> `KEM_SPLIT_KEYGEN_SAMPLE=1` 在 AMD ROCm 上不是有效优化。简单拆分 seed expand、matrix sample、noise sample 会显著增加 sample 阶段耗时,推测原因是额外 kernel launch、全局内存中间结果写回、cache/locality 下降和更高的调度开销超过了并行粒度提升带来的收益。 + +论文中可以将该实验作为“负优化案例”: + +> 并非所有 CUDA/NVIDIA 风格或直觉上的 kernel 拆分都适合 AMD ROCm。对后量子密码采样路径而言,kernel fusion 与数据局部性可能比盲目拆分更重要。该实验说明需要 profile-driven tuning,而不是仅通过增加 kernel 数量提高表面并行度。 + +下一步调整: + +> 放弃 split sampling 路线,优先保留当前 serial/baseline 路径作为最佳吞吐实现;后续优化转向 `KEM_SERIAL_TPB`、batch size、stream 并发、采样 kernel 内部并行方式和减少中间内存访问。 + +### Step 2 实验结果:TPB=256 对 Kyber-768 有效 + +在 Kyber-768 上测试 `KEM_SERIAL_TPB=256`: + +```bash +hipcc -O2 -std=c++17 -x hip --offload-arch=gfx1100 \ + -DKEM_SERIAL_TPB=256 \ + -DALGORITHM=1 -DPARAM_MODE=3 \ + main.cu -o kyber768_tpb256_amd + +./kyber768_tpb256_amd --batch 32768 --n-ops 5 --no-correctness +``` + +结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.0 ms/batch -> 5444673 ops/sec + Encaps: 5.4 ms/batch -> 6035024 ops/sec + Decaps: 5.9 ms/batch -> 5581568 ops/sec +``` + +与此前 Kyber-768 sweep 最佳值对比: + +| Metric | 原最佳 | TPB=256 | 变化 | +|---|---:|---:|---:| +| Keygen | 4,797,831 ops/s | 5,444,673 ops/s | +13.5% | +| Encaps | 5,893,981 ops/s | 6,035,024 ops/s | +2.4% | +| Decaps | 5,502,606 ops/s | 5,581,568 ops/s | +1.4% | + +结论: + +> 相比 split sampling,`KEM_SERIAL_TPB` 
调参是当前更有效的 ROCm 优化方向。将 Kyber-768 的 serial kernel 线程块大小从默认 64 调整到 256 后,keygen 吞吐提升约 13.5%,说明 AMD/RDNA3 上的线程组织参数对后量子 KEM 批处理性能有显著影响。 + +下一步需要补充: + +- 补齐 `TPB=32/64/128/256/512` 的完整对比。 +- 对 `TPB=256` 跑一次正确性测试。 +- 若 `TPB=256` 稳定,再推广到 Kyber-512/1024 与 Aigis-enc 系列。 + +### Step 2 补充:TPB=512 不是最优 + +`KEM_SERIAL_TPB=512` 正确性测试: + +```text +KEM 正确性: PASS +``` + +性能结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.5 ms/batch -> 5049880 ops/sec + Encaps: 5.8 ms/batch -> 5603379 ops/sec + Decaps: 6.3 ms/batch -> 5232946 ops/sec +``` + +与 `TPB=256` 对比: + +| Metric | TPB=256 | TPB=512 | 结论 | +|---|---:|---:|---| +| Keygen | 5,444,673 ops/s | 5,049,880 ops/s | TPB=512 下降 | +| Encaps | 6,035,024 ops/s | 5,603,379 ops/s | TPB=512 下降 | +| Decaps | 5,581,568 ops/s | 5,232,946 ops/s | TPB=512 下降 | + +判断: + +> `TPB=512` 虽然正确性通过,但性能低于 `TPB=256`。说明继续增大线程块并不能进一步提升吞吐,可能引入更高的寄存器/调度压力或降低有效 occupancy。当前 Kyber-768 的候选最优配置仍为 `KEM_SERIAL_TPB=256`。 + +### Step 2 补充:TPB=128 结果 + +`KEM_SERIAL_TPB=128` 正确性测试: + +```text +KEM 正确性: PASS +``` + +性能结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.4 ms/batch -> 5081079 ops/sec + Encaps: 5.6 ms/batch -> 5885813 ops/sec + Decaps: 5.9 ms/batch -> 5526684 ops/sec +``` + +当前已知 TPB 对比: + +| TPB | Correctness | Keygen | Encaps | Decaps | +|---:|---|---:|---:|---:| +| 32 | PASS | 4,909,746 ops/s | 5,984,596 ops/s | 5,450,378 ops/s | +| 64 | PASS | 4,866,205 ops/s | 5,994,111 ops/s | 5,544,629 ops/s | +| 128 | PASS | 5,081,079 ops/s | 5,885,813 ops/s | 5,526,684 ops/s | +| 256 | PASS | 5,444,673 ops/s | 6,035,024 ops/s | 5,581,568 ops/s | +| 512 | PASS | 5,049,880 ops/s | 5,603,379 ops/s | 5,232,946 ops/s | + +阶段判断: + +> 在已测试的 32/64/128/256/512 中,`TPB=256` 对 Kyber-768 的 keygen、encaps、decaps 均为当前最好。相较默认 `TPB=64`,`TPB=256` 使 keygen 从 4,866,205 ops/s 提升到 5,444,673 ops/s,提升约 11.9%。该结果支持将 `KEM_SERIAL_TPB=256` 作为 AMD/RDNA3 上 Kyber-768 serial 路径的候选默认配置。 + +### Step 2 补充:TPB=64 默认基线 + +`KEM_SERIAL_TPB=64` 正确性测试: + +```text +KEM 正确性: 
PASS +``` + +性能结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.7 ms/batch -> 4866205 ops/sec + Encaps: 5.5 ms/batch -> 5994111 ops/sec + Decaps: 5.9 ms/batch -> 5544629 ops/sec +``` + +与 `TPB=256` 对比: + +| Metric | TPB=64 | TPB=256 | 变化 | +|---|---:|---:|---:| +| Keygen | 4,866,205 ops/s | 5,444,673 ops/s | +11.9% | +| Encaps | 5,994,111 ops/s | 6,035,024 ops/s | +0.7% | +| Decaps | 5,544,629 ops/s | 5,581,568 ops/s | +0.7% | + +结论: + +> `TPB=256` 的主要收益集中在 keygen,encaps/decaps 提升较小。这说明 keygen 路径对线程块大小和调度粒度更敏感,后续优化仍应围绕 keygen 的采样/密钥生成路径展开。 + +### Step 2 补充:TPB=32 结果 + +`KEM_SERIAL_TPB=32` 正确性测试: + +```text +KEM 正确性: PASS +``` + +性能结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.7 ms/batch -> 4909746 ops/sec + Encaps: 5.5 ms/batch -> 5984596 ops/sec + Decaps: 6.0 ms/batch -> 5450378 ops/sec +``` + +结论: + +> `TPB=32` 正确性通过,但 keygen、encaps、decaps 均未超过 `TPB=256`。至此 Kyber-768 的 TPB sweep 闭合,`TPB=256` 是当前最优线程块大小。 + +### Step 3 初步推广:Aigis-enc-4 上 TPB=256 不是全局最优 + +对 Aigis-enc-4 使用 `KEM_SERIAL_TPB=256`,batch=65536: + +```text +--- batch=65536 n_ops=5 mode=serial --- + Keygen: 16.8 ms/batch -> 3908677 ops/sec + Encaps: 23.1 ms/batch -> 2837074 ops/sec + Decaps: 28.2 ms/batch -> 2325429 ops/sec +``` + +与全量 sweep 中 Aigis-enc-4 原最佳对比: + +| Metric | 原最佳 | TPB=256 | 变化 | +|---|---:|---:|---:| +| Keygen | 3,874,049 ops/s | 3,908,677 ops/s | +0.9% | +| Encaps | 2,943,683 ops/s | 2,837,074 ops/s | -3.6% | +| Decaps | 2,370,340 ops/s | 2,325,429 ops/s | -1.9% | + +判断: + +> `TPB=256` 对 Aigis-enc-4 的 keygen 只有轻微收益,并降低 encaps/decaps 吞吐。因此 `TPB=256` 不应直接作为所有 KEM 算法的全局默认值。更合理的策略是按算法和操作类型选择配置:Kyber-768 可优先采用 `TPB=256`,Aigis-enc-4 暂时保留原配置作为总体吞吐基线。 + +### Step 3 补充:Kyber-1024 与 Aigis-enc-1 的 TPB=256 推广 + +#### Kyber-1024 + +`KEM_SERIAL_TPB=256`,batch=32768: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 8.2 ms/batch -> 3975785 ops/sec + Encaps: 7.6 ms/batch -> 4295105 ops/sec + Decaps: 8.8 ms/batch -> 3707815 ops/sec +``` + +与全量 sweep 中 
Kyber-1024 原最佳对比: + +| Metric | 原最佳 | TPB=256 | 变化 | +|---|---:|---:|---:| +| Keygen | 3,691,492 ops/s | 3,975,785 ops/s | +7.7% | +| Encaps | 4,207,232 ops/s | 4,295,105 ops/s | +2.1% | +| Decaps | 3,674,413 ops/s | 3,707,815 ops/s | +0.9% | + +判断: + +> `TPB=256` 对 Kyber-1024 同样有效,尤其 keygen 提升约 7.7%。这说明 Kyber 系列在 AMD/RDNA3 上普遍受益于更大的 serial kernel 线程块配置。 + +#### Aigis-enc-1 + +`KEM_SERIAL_TPB=256`,batch=65536: + +```text +--- batch=65536 n_ops=5 mode=serial --- + Keygen: 7.2 ms/batch -> 9113161 ops/sec + Encaps: 9.5 ms/batch -> 6890204 ops/sec + Decaps: 12.7 ms/batch -> 5158510 ops/sec +``` + +与全量 sweep 中 Aigis-enc-1 原最佳对比: + +| Metric | 原最佳 | TPB=256 | 变化 | +|---|---:|---:|---:| +| Keygen | 8,316,531 ops/s | 9,113,161 ops/s | +9.6% | +| Encaps | 7,092,336 ops/s | 6,890,204 ops/s | -2.9% | +| Decaps | 5,175,258 ops/s | 5,158,510 ops/s | -0.3% | + +判断: + +> `TPB=256` 对 Aigis-enc-1 的 keygen 也有明显收益,但会降低 encaps,decaps 基本持平。这进一步说明三类 KEM 操作对线程块大小的敏感性不同,后续不应继续用单一 `KEM_SERIAL_TPB` 控制 keygen、encaps 和 decaps。 + +### Step 4 新优化方向:按操作拆分 TPB + +当前证据: + +- Kyber-768:`TPB=256` 使 keygen 提升约 11.9%。 +- Kyber-1024:`TPB=256` 使 keygen 提升约 7.7%。 +- Aigis-enc-1:`TPB=256` 使 keygen 提升约 9.6%,但 encaps 下降约 2.9%。 +- Aigis-enc-4:`TPB=256` 使 keygen 轻微提升约 0.9%,但 encaps/decaps 下降。 + +结论: + +> 下一步应将当前单一 `KEM_SERIAL_TPB` 拆分为 `KEM_KEYGEN_TPB`、`KEM_ENCAPS_TPB`、`KEM_DECAPS_TPB`。这样可以让 keygen 使用更适合 AMD 的 `TPB=256`,同时让 encaps/decaps 继续保留各自更优的配置,避免一个编译宏同时影响三种不同操作。 + +### Step 4 初测:按操作拆分 TPB + +代码已将单一 `KEM_SERIAL_TPB` 拆分为: + +```c +KEM_KEYGEN_TPB +KEM_ENCAPS_TPB +KEM_DECAPS_TPB +``` + +测试配置: + +```bash +KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=64 KEM_DECAPS_TPB=64 bash build_hip.sh kyber768 +./kyber768_amd --batch 128 --n-ops 1 +./kyber768_amd --batch 32768 --n-ops 5 --no-correctness +``` + +正确性: + +```text +KEM 正确性: PASS +``` + +性能结果: + +```text +--- batch=32768 n_ops=5 mode=serial --- + Keygen: 6.3 ms/batch -> 5223903 ops/sec + Encaps: 5.5 ms/batch -> 5948418 ops/sec + Decaps: 5.9 ms/batch -> 5568138 ops/sec 
+``` + +与默认 `TPB=64/64/64` 对比: + +| Metric | 默认 64/64/64 | 拆分 256/64/64 | 变化 | +|---|---:|---:|---:| +| Keygen | 4,866,205 ops/s | 5,223,903 ops/s | +7.4% | +| Encaps | 5,994,111 ops/s | 5,948,418 ops/s | -0.8% | +| Decaps | 5,544,629 ops/s | 5,568,138 ops/s | +0.4% | + +阶段判断: + +> 按操作拆分 TPB 是有效方向。`256/64/64` 组合在基本不损失 encaps/decaps 的情况下,将 Kyber-768 keygen 较默认配置提升约 7.4%。但该结果低于此前全 `TPB=256` 单次测试中的 5.44M keygen,因此需要用更多迭代次数复测,排除运行波动影响。 + +### Step 4 复测:n_ops=20 稳定态结果 + +为降低短迭代测量波动,将 `n_ops` 从 5 提高到 20 后重新测试 Kyber-768,batch=32768。 + +结果: + +| Config | Keygen | Encaps | Decaps | +|---|---:|---:|---:| +| 256/64/64 | 6,190,326 ops/s | 6,023,175 ops/s | 5,610,136 ops/s | +| 256/256/256 | 6,226,843 ops/s | 6,026,556 ops/s | 5,612,435 ops/s | +| 256/128/128 | 6,297,089 ops/s | 6,026,840 ops/s | 5,617,614 ops/s | + +阶段性结论: + +> `n_ops=20` 复测表明,Kyber-768 在 AMD 上的稳定态 keygen 吞吐可达到 6.2M ops/s 以上,显著高于早期 `n_ops=5` 的 5.2M 到 5.4M 结果。三种 TPB 组合的 encaps/decaps 基本一致,说明当前实现中 keygen 对 TPB 更敏感,而 encaps/decaps 对 64/128/256 的差异不明显。 + +重要修正: + +> 后续论文中的性能结论应优先采用较大迭代次数的稳定态结果,而不是早期 `n_ops=3/5` 的短迭代结果。短迭代更适合快速调试,稳定性能表建议统一使用 `n_ops=20` 或更高。 + +## 12. 
Kyber 系列 n_ops=20 稳定性能表 + +使用配置: + +```text +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +batch=32768 +n_ops=20 +``` + +命令: + +```bash +mkdir -p amd_results/final_nops20 + +for target in kyber512 kyber768 kyber1024; do + KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 \ + bash build_hip.sh "$target" 2>&1 | tee -a amd_results/final_nops20/kyber_nops20.log + + ./${target}_amd --batch 32768 --n-ops 20 --no-correctness \ + 2>&1 | tee -a amd_results/final_nops20/kyber_nops20.log +done +``` + +稳定性能结果: + +| Algorithm | Keygen | Encaps | Decaps | Batch | Iterations | +|---|---:|---:|---:|---:|---:| +| Kyber-512 | 10,538,121 ops/s | 11,366,292 ops/s | 7,442,359 ops/s | 32768 | 20 | +| Kyber-768 | 6,184,066 ops/s | 6,018,408 ops/s | 5,623,416 ops/s | 32768 | 20 | +| Kyber-1024 | 4,665,935 ops/s | 4,277,726 ops/s | 3,860,981 ops/s | 32768 | 20 | + +与早期短迭代 sweep 相比: + +- Kyber-512 keygen 从 6.55M 提升到 10.54M。 +- Kyber-768 keygen 从 4.80M 提升到 6.18M。 +- Kyber-1024 keygen 从 3.69M 提升到 4.67M。 + +阶段性结论: + +> 使用较大迭代次数与按操作拆分 TPB 后,AMD ROCm 上 Kyber 系列稳定态吞吐显著提升。Kyber-768 keygen 达到 6.18M ops/s,已经超过此前 RTX 4090D 记录中的 5.70M ops/s;Kyber-512 encaps 达到 11.37M ops/s,达到千万级后量子 KEM 吞吐。 + +适合论文/PPT 表述: + +> 通过 ROCm 环境适配、正确性验证、profile 定位与 TPB 调优,Kyber 系列在 AMD Radeon/RDNA3 平台上达到百万级到千万级稳定吞吐,证明 AMD GPU 在后量子密钥封装批处理场景中具备与高端 NVIDIA GPU 同量级的性能竞争力。 + +## 13. 
Aigis-enc 系列 n_ops=20 稳定性能表 + +使用配置: + +```text +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +batch=65536 +n_ops=20 +``` + +命令: + +```bash +mkdir -p amd_results/final_nops20 + +for target in aigisenc1 aigisenc2 aigisenc3 aigisenc4; do + KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 \ + bash build_hip.sh "$target" 2>&1 | tee -a amd_results/final_nops20/aigisenc_nops20.log + + ./${target}_amd --batch 65536 --n-ops 20 --no-correctness \ + 2>&1 | tee -a amd_results/final_nops20/aigisenc_nops20.log +done +``` + +稳定性能结果: + +| Algorithm | Keygen | Encaps | Decaps | Batch | Iterations | +|---|---:|---:|---:|---:|---:| +| Aigis-enc-1 | 10,368,547 ops/s | 7,200,232 ops/s | 5,200,585 ops/s | 65536 | 20 | +| Aigis-enc-2 | 6,879,027 ops/s | 4,701,072 ops/s | 3,370,967 ops/s | 65536 | 20 | +| Aigis-enc-3 | 6,344,582 ops/s | 4,629,810 ops/s | 3,594,671 ops/s | 65536 | 20 | +| Aigis-enc-4 | 4,261,266 ops/s | 2,955,561 ops/s | 2,425,389 ops/s | 65536 | 20 | + +与早期短迭代 sweep 相比: + +- Aigis-enc-1 keygen 从 8.32M 提升到 10.37M。 +- Aigis-enc-2 keygen 从 5.71M 提升到 6.88M。 +- Aigis-enc-3 keygen 从 5.16M 提升到 6.34M。 +- Aigis-enc-4 keygen 从 3.87M 提升到 4.26M。 + +阶段性结论: + +> Aigis-enc 系列在 AMD ROCm 上同样达到稳定百万级到千万级吞吐。Aigis-enc-1 keygen 达到 10.37M ops/s,说明 AMD 平台不仅能够支撑标准 Kyber/ML-KEM 类负载,也能支撑国产 Aigis-enc 系列后量子密钥封装算法的高并发批处理。 + +## 14. 
KEM 模块当前最终成果 + +截至当前阶段,Kyber/Aigis-enc KEM 模块已经完成: + +- 7 个目标全部在 AMD ROCm 上编译通过。 +- 7 个目标小 batch 正确性全部 PASS。 +- 完成 Kyber-768、Kyber-1024、Aigis-enc-4 的程序内 profile。 +- 完成 Kyber-768 的 `rocprofv3` kernel trace。 +- 定位 sample/XOF/rejection sampling 为 keygen pipeline 主要瓶颈。 +- 验证 split sampling 是负优化。 +- 验证 TPB sweep 与按操作拆分 TPB 是有效调优方向。 +- 形成 `n_ops=20` 稳定性能表。 + +当前最佳稳定结果摘要: + +| Family | Representative Best | +|---|---| +| Kyber | Kyber-512 encaps 11.37M ops/s | +| Kyber | Kyber-768 keygen 6.18M ops/s | +| Kyber | Kyber-1024 keygen 4.67M ops/s | +| Aigis-enc | Aigis-enc-1 keygen 10.37M ops/s | +| Aigis-enc | Aigis-enc-4 keygen 4.26M ops/s | + +可写入论文的阶段性总述: + +> 本项目在 AMD ROCm/RDNA3 平台上完成 Kyber 与 Aigis-enc 共 7 个参数集的批量 KEM 实现验证。通过 ROCm 工具与程序内 profile 定位,发现 keygen 路径主要瓶颈集中在 sample/XOF/rejection sampling 阶段;通过 TPB sweep 和按操作拆分 kernel launch 配置,获得稳定百万级到千万级吞吐。其中 Kyber-512 encaps 达到 11.37M ops/s,Aigis-enc-1 keygen 达到 10.37M ops/s,证明 AMD 平台在后量子密钥封装批处理场景中具备较强工程竞争力。 + +## 16. 最终自动化报告记录 + +最终报告脚本: + +```bash +bash run_kem_final_report_amd.sh +bash run_kem_resource_profile_amd.sh kyber768 32768 200 +``` + +生成目录: + +```text +amd_results/final_report_20260612_074154 +amd_results/resource_profile_kyber768_20260612_074231 +``` + +最终报告提取文件: + +```text +amd_results/final_report_20260612_074154/kem_final_extract.txt +``` + +最终 KEM 性能表: + +| Algorithm | Keygen | Encaps | Decaps | +|---|---:|---:|---:| +| Kyber-512 | 10,132,546 ops/s | 11,307,354 ops/s | 7,495,810 ops/s | +| Kyber-768 | 6,331,652 ops/s | 5,998,665 ops/s | 5,659,745 ops/s | +| Kyber-1024 | 4,484,088 ops/s | 4,290,608 ops/s | 3,836,327 ops/s | +| Aigis-enc-1 | 10,326,268 ops/s | 7,204,754 ops/s | 5,202,343 ops/s | +| Aigis-enc-2 | 6,639,865 ops/s | 4,704,347 ops/s | 3,367,005 ops/s | +| Aigis-enc-3 | 6,361,075 ops/s | 4,625,470 ops/s | 3,595,634 ops/s | +| Aigis-enc-4 | 4,144,860 ops/s | 2,951,076 ops/s | 2,429,642 ops/s | + +资源/profile 报告目录: + +```text +amd_results/resource_profile_kyber768_20260612_074231 +``` + +该目录包含: + 
+```text +metadata.txt +build.log +benchmark.log +rocm_smi_during.log +rocm_smi_gpu0_extract.log +rocprofv3/ +rocprofv3_summary.csv +``` + +说明: + +> `rocprofv3` 运行会引入额外开销,因此 profile 目录中的 `n_ops=1` 吞吐只用于定位 kernel 和 API 行为,不作为最终性能数据。最终性能应引用 `final_report_20260612_074154/kem_final_extract.txt` 中的 `n_ops=20` 稳定结果。 + +## 9. TPB 是什么,以及为什么要调 + +`TPB` 是 `threads per block`,表示每个 GPU kernel 线程块中包含多少个线程。 + +在当前 KEM 代码中,`KEM_SERIAL_TPB` 控制批量 KEM serial kernel 的启动配置: + +```cpp +int tpb = KEM_SERIAL_TPB; +int blocks = (batch_count + tpb - 1) / tpb; + +batch_kem_keypair_serial_kernel<<<blocks, tpb>>>(...); +batch_kem_encaps_serial_kernel<<<blocks, tpb>>>(...); +batch_kem_decaps_serial_kernel<<<blocks, tpb>>>(...); +``` + +以 `batch_count=32768` 为例: + +| TPB | Block 数量 | 含义 | +|---:|---:|---| +| 64 | 512 blocks | 每个 block 处理 64 个 KEM 实例 | +| 256 | 128 blocks | 每个 block 处理 256 个 KEM 实例 | +| 512 | 64 blocks | 每个 block 处理 512 个 KEM 实例 | + +`TPB` 不是 AMD 特有概念,CUDA/NVIDIA 和 HIP/AMD 都存在类似 kernel launch 配置。区别在于: + +> TPB 不是 AMD 特有,但最优 TPB 与 GPU 架构、寄存器压力、occupancy、wavefront/warp 调度、kernel 内部逻辑和访存模式强相关。因此 CUDA 代码迁移到 ROCm 后,不能默认沿用 NVIDIA 上的配置。 + +当前本地 4090 构建脚本中,`KEM_SERIAL_TPB` 默认值为: + +```bash +KEM_SERIAL_TPB="${KEM_SERIAL_TPB:-64}" +``` + +也就是说,原 4090 版本默认使用: + +```text +KEM_SERIAL_TPB=64 +``` + +而 AMD gfx1100 上的 Kyber-768 实测显示: + +| TPB | Keygen | Encaps | Decaps | +|---:|---:|---:|---:| +| 32 | 4.91M ops/s | 5.98M ops/s | 5.45M ops/s | +| 64 | 4.87M ops/s | 5.99M ops/s | 5.54M ops/s | +| 128 | 5.08M ops/s | 5.89M ops/s | 5.53M ops/s | +| 256 | 5.44M ops/s | 6.04M ops/s | 5.58M ops/s | +| 512 | 5.05M ops/s | 5.60M ops/s | 5.23M ops/s | + +结论: + +> 将 Kyber-768 的 `KEM_SERIAL_TPB` 从 4090 默认沿用的 64 调整到 AMD 更适配的 256 后,keygen 吞吐从 4.87M ops/s 提升到 5.44M ops/s,提升约 11.9%。这说明简单 HIP 迁移只能保证“能跑”,要发挥 ROCm/RDNA3 性能,必须重新进行平台相关的 kernel launch 参数调优。 + +适合论文/PPT 的一句话: + +> HIP 迁移解决可运行性,ROCm 原生调优决定能否跑满;TPB sweep 是本项目第一项可量化的 ROCm 平台调优结果。 + +## 10. 
下一阶段:使用 ROCm 工具定位瓶颈 + +当前已有程序内 profile 显示,Kyber/Aigis-enc keygen 的主要瓶颈集中在 sample 阶段。但程序内 profile 只能给出粗粒度阶段耗时,下一步需要使用 ROCm 工具定位到 kernel 级别。 + +目标: + +1. 确认 sample 阶段对应哪些 kernel 或设备函数路径。 +2. 观察 kernel 执行时间、调用次数、grid/block 配置。 +3. 判断是否存在 occupancy 不足、寄存器压力、LDS/显存访问效率低、kernel launch 过多等问题。 +4. 将 profile 结论转化为下一步代码优化方案。 + +建议优先分析三个目标: + +```text +kyber768_amd 中等安全等级,已有 TPB 正向优化 +kyber1024_amd 高安全等级 Kyber,sample 与 matvec 更重 +aigisenc4_amd 高安全等级 Aigis-enc,sample 占比最高 +``` + +### 10.1 先确认 ROCm 工具可用性 + +在 AMD JupyterLab 中执行: + +```bash +which rocprofv3 || true +which rocprof || true +which rocm-smi || true +hipcc --version +``` + +记录: + +- ROCm 工具是否存在。 +- `rocprofv3` 是否可运行。 +- `rocm-smi` 是否能读取显卡信息。 +- `hipcc` 版本。 + +实际结果: + +```text +/opt/python/bin/rocprofv3 +/opt/python/bin/rocm-smi +HIP version: 7.12.60610-2bd1678d3d +AMD clang version 22.0.0git +Target: x86_64-unknown-linux-gnu +InstalledDir: /opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib/llvm/bin +``` + +结论: + +> 当前 AMD JupyterLab 环境提供 `rocprofv3` 和 `rocm-smi`,没有发现旧版 `rocprof`。后续性能定位应以 `rocprofv3 --kernel-trace --hip-trace` 为主。 + +### 10.2 使用 rocprofv3 做 kernel trace + +当前脚本: + +```bash +bash profile_kem_one_amd.sh kyber768_amd 32768 3 +``` + +如果 `amd_results/profile/xxx_rocprof/` 目录为空,需要改用更显式的 `rocprofv3` 参数。可尝试: + +```bash +mkdir -p amd_results/profile/kyber768_manual_rocprof + +rocprofv3 \ + --kernel-trace \ + --hip-trace \ + --output-format csv \ + --output-directory amd_results/profile/kyber768_manual_rocprof \ + -- \ + ./kyber768_amd --batch 32768 --n-ops 1 --no-correctness --pipeline +``` + +然后查看输出: + +```bash +find amd_results/profile/kyber768_manual_rocprof -type f -maxdepth 2 -print +``` + +如果有 CSV 文件,再压缩或 `cat` 关键文件给后续分析。 + +实际 `rocprofv3` 运行成功: + +```text +rocprofv3 \ + --kernel-trace \ + --hip-trace \ + --output-format csv \ + --output-directory amd_results/profile/kyber768_manual_rocprof \ + -- \ + ./kyber768_amd --batch 32768 --n-ops 1 --no-correctness --pipeline +``` + +生成文件: + +```text 
+amd_results/profile/kyber768_manual_rocprof/nb-a8b881d6/3984_kernel_trace.csv +amd_results/profile/kyber768_manual_rocprof/nb-a8b881d6/3984_hip_api_trace.csv +amd_results/profile/kyber768_manual_rocprof/nb-a8b881d6/3984_agent_info.csv +``` + +该结果说明: + +> ROCm 工具链已经可以采集 Kyber-768 的 kernel trace 与 HIP API trace。下一步需要解析 `kernel_trace.csv`,找出耗时最高的 kernel,并与程序内 profile 的 sample/NTT/matvec/pack 阶段进行对应。 + +### 10.3 如果 rocprofv3 不可用,使用 rocprof + +有些环境可能只有旧版 `rocprof`。可尝试: + +```bash +mkdir -p amd_results/profile/kyber768_rocprof_old + +rocprof \ + --hip-trace \ + --hsa-trace \ + -d amd_results/profile/kyber768_rocprof_old \ + ./kyber768_amd --batch 32768 --n-ops 1 --no-correctness --pipeline +``` + +查看输出: + +```bash +find amd_results/profile/kyber768_rocprof_old -type f -maxdepth 2 -print +``` + +### 10.4 同步采集 GPU 状态 + +在 benchmark 前后记录: + +```bash +rocm-smi +rocm-smi --showuse --showmemuse --showtemp --showpower +``` + +用于论文中的资源描述: + +- GPU 型号与架构。 +- 显存容量。 +- 运行时显存使用。 +- GPU utilization。 +- 功耗/温度可选。 + +### 10.5 下一步可能的优化方向 + +根据目前数据,优先级如下: + +1. **按操作拆分 TPB** + - 当前单一 `KEM_SERIAL_TPB` 同时影响 keygen/encaps/decaps。 + - 下一步可改成 `KEM_KEYGEN_TPB=256`、`KEM_ENCAPS_TPB=64`、`KEM_DECAPS_TPB=64`。 + - 目标:保留 keygen 提升,同时避免 encaps/decaps 下降。 + +2. **sample 路径优化** + - split sample 实验证明简单拆 kernel 会负优化。 + - 后续应考虑保持数据局部性,减少中间全局内存写回,而不是盲目拆分。 + +3. **按算法选择配置** + - Kyber 系列明显更受益于 `TPB=256`。 + - Aigis-enc 系列需要单独扫参,不适合直接套用 Kyber 配置。 + +4. **端到端工作流优化** + - 小 batch 吞吐较低。 + - 多文件科研数据流转平台应聚合任务,使用批处理摊薄 launch 和同步开销。 + +## 11. 
rocprofv3 Kernel Trace 初步结论 + +`rocprofv3` 已成功生成并解析 Kyber-768 的 kernel trace 与 HIP API trace: + +```text +amd_results/profile/kyber768_manual_rocprof/nb-a8b881d6/3984_kernel_trace.csv +amd_results/profile/kyber768_manual_rocprof/nb-a8b881d6/3984_hip_api_trace.csv +``` + +### 11.1 Kernel 热点 + +Top kernels by total time: + +| Kernel | Total | Avg | Calls | 说明 | +|---|---:|---:|---:|---| +| `batch_kem_decaps_serial_kernel` | 11.728 ms | 5.864 ms | 2 | decaps 主路径 | +| `batch_kem_encaps_serial_kernel` | 9.944 ms | 4.972 ms | 2 | encaps 主路径 | +| `batch_kem_keypair_serial_kernel` | 9.564 ms | 9.564 ms | 1 | serial keygen 主路径 | +| `batch_keygen_warp_sample_kernel` | 5.130 ms | 5.130 ms | 1 | pipeline keygen 采样阶段 | +| `batch_polyvec_matvec_kernel` | 0.485 ms | 0.485 ms | 1 | pipeline matvec | +| `batch_pack_keypair_finalize_kernel` | 0.351 ms | 0.351 ms | 1 | keypair finalize | +| `batch_invntt_kernel` | 0.334 ms | 0.111 ms | 3 | inverse NTT | +| `batch_ntt_kernel` | 0.303 ms | 0.101 ms | 3 | NTT | +| `batch_poly_caddq_kernel` | 0.265 ms | 0.044 ms | 6 | modular normalize | +| `batch_poly_add_kernel` | 0.169 ms | 0.056 ms | 3 | polynomial add | +| `batch_pack_pk_polyvec_kernel` | 0.109 ms | 0.109 ms | 1 | pack pk | +| `batch_pack_sk_polyvec_kernel` | 0.105 ms | 0.105 ms | 1 | pack sk | + +注意: + +> 本次命令带有 `--pipeline`,程序会先跑 serial 测试,再跑 pipeline 测试。因此 trace 中同时包含 serial keygen/encaps/decaps kernel 和 pipeline keygen 分阶段 kernel。不能直接把所有 kernel 总时间相加作为单一路径耗时,应按执行路径分开解释。 + +### 11.2 Kernel 级瓶颈判断 + +对 pipeline keygen 而言,kernel trace 与程序内 profile 一致: + +```text +batch_keygen_warp_sample_kernel: 5.130 ms +batch_polyvec_matvec_kernel: 0.485 ms +batch_invntt_kernel total: 0.334 ms +batch_ntt_kernel total: 0.303 ms +pack pk/sk/finalize total: 0.565 ms 左右 +``` + +结论: + +> ROCm kernel trace 进一步确认,Kyber-768 pipeline keygen 的主要热点是 `batch_keygen_warp_sample_kernel`,即采样/XOF/拒绝采样路径。NTT、inverse NTT、matvec 和 packing 的耗时均显著低于 sample。因此下一步优化重点应继续围绕采样路径,而不是优先重写 NTT。 + +### 11.3 HIP API Trace + 
+Top HIP APIs by total time: + +| HIP API | Total | Avg | Calls | 说明 | +|---|---:|---:|---:|---| +| `hipGetDevice` | 172.122 ms | 172.122 ms | 1 | 初始化/查询开销 | +| `hipMemcpy` | 89.495 ms | 22.374 ms | 4 | 主机到设备数据准备 | +| `hipDeviceSynchronize` | 39.728 ms | 4.966 ms | 8 | 包含等待 kernel 完成时间 | +| `hipLaunchKernel` | 2.147 ms | 0.086 ms | 25 | kernel launch 开销 | +| `hipFree` | 1.114 ms | 0.062 ms | 18 | 释放显存 | +| `hipMalloc` | 0.452 ms | 0.025 ms | 18 | 分配显存 | + +解释: + +- `hipGetDevice` 是一次性初始化/查询开销,不应计入稳定态吞吐瓶颈。 +- `hipDeviceSynchronize` 的时间包含等待 GPU kernel 执行完成,不代表纯 CPU API 开销。 +- `hipMemcpy` 时间较高,说明端到端工作流中需要避免频繁 host-device 拷贝。 +- `hipLaunchKernel` 总耗时约 2.147 ms,25 次调用,平均约 0.086 ms;对大 batch 影响可接受,但对小 batch 会明显影响端到端效率。 + +### 11.4 从 ROCm Trace 得出的优化建议 + +1. **sample kernel 是第一优化目标** + - 重点分析 `batch_keygen_warp_sample_kernel` 中 SHAKE/XOF、拒绝采样、噪声采样的线程协作方式。 + - split sample 已经证明会负优化,因此后续应优先考虑保持数据局部性与减少全局内存中间写回。 + +2. **减少端到端内存搬运** + - `hipMemcpy` 在 API trace 中占比较高。 + - 后续多文件可信流转平台应尽量复用 device buffer,避免每批数据重复分配和拷贝。 + +3. **减少小 batch 下 launch/sync 开销** + - `hipLaunchKernel` 与 `hipDeviceSynchronize` 对小 batch 会放大。 + - 前端/工作流层应聚合多文件任务,使用 batch 队列提高 GPU 利用率。 + +4. 
**按路径分别优化** + - serial keygen/encaps/decaps 是当前 best-throughput 基线。 + - pipeline keygen 适合用于分阶段 profile,但不一定是所有参数集的性能最优实现。 + +### 11.5 最终配置下的资源分析 + +使用最终候选配置重新采集: + +```text +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +batch=32768 +rocprofv3 --kernel-trace --hip-trace +``` + +关键 kernel 资源摘要: + +| Kernel | Total | Calls | VGPR | SGPR | Scratch | LDS | Workgroup | Grid | +|---|---:|---:|---:|---:|---:|---:|---|---| +| `batch_kem_decaps_serial_kernel` | 11.569 ms | 2 | 184 | 128 | 17168 | 0 | 128x1x1 | 32768x1x1 | +| `batch_kem_encaps_serial_kernel` | 9.957 ms | 2 | 184 | 128 | 16064 | 0 | 128x1x1 | 32768x1x1 | +| `batch_kem_keypair_serial_kernel` | 8.844 ms | 1 | 184 | 128 | 8592 | 0 | 256x1x1 | 32768x1x1 | +| `batch_keygen_warp_sample_kernel` | 5.101 ms | 1 | 152 | 128 | 304 | 512 | 128x1x1 | 1048576x1x1 | +| `batch_polyvec_matvec_kernel` | 0.477 ms | 1 | 16 | 128 | 0 | 0 | 256x1x1 | 8388608x3x1 | +| `batch_invntt_kernel` | 0.329 ms | 3 | 16 | 128 | 0 | 1024 | 128x1x1 | 4194304x1x1 | +| `batch_ntt_kernel` | 0.304 ms | 3 | 16 | 128 | 0 | 1024 | 128x1x1 | 4194304x1x1 | + +资源侧结论: + +> ROCm kernel trace 显示,monolithic serial KEM kernel 的 VGPR 均达到 184,且 encaps/decaps scratch 分别达到 16064/17168 bytes,说明当前单线程单实例的设备函数路径存在较高寄存器与栈/溢出压力。相比之下,NTT、inverse NTT 和 matvec 的 VGPR 较低且 scratch 为 0,说明它们不是当前资源瓶颈。 + +由此得到两个优化方向: + +1. **sample/XOF/rejection sampling 仍是算法阶段热点** + - `batch_keygen_warp_sample_kernel` 耗时 5.101 ms,仍是 pipeline keygen 中最大单项。 + - 该 kernel 使用 VGPR=152、Scratch=304、LDS=512,说明 sample 路径也存在较高寄存器压力,但比 monolithic serial kernel 的 scratch 小得多。 + +2. 
**monolithic serial kernel 存在寄存器/栈压力** + - keypair/encaps/decaps serial kernel 均使用 VGPR=184。 + - encaps/decaps scratch 超过 16KB,后续可考虑拆出局部大数组、减少设备函数局部状态、复用全局/共享缓冲,或针对 encaps/decaps 建立更细粒度 batch pipeline。 + +HIP API trace: + +| API | Total | Calls | 判断 | +|---|---:|---:|---| +| `hipGetDevice` | 173.320 ms | 1 | 初始化/查询开销,不计入稳定态瓶颈 | +| `hipMemcpy` | 87.502 ms | 4 | 数据准备开销,端到端系统应减少 host-device 拷贝 | +| `hipDeviceSynchronize` | 38.816 ms | 8 | 包含等待 GPU kernel 完成,不是纯 API 开销 | +| `hipLaunchKernel` | 2.002 ms | 25 | 小 batch 场景会放大 launch 开销 | +| `hipMalloc/hipFree` | 1.542 ms | 36 | 后续工作流应复用 device buffer | + +工程化建议: + +> 对论文而言,当前 ROCm trace 可以支撑“sample 是算法热点,monolithic KEM kernel 有寄存器/栈压力,端到端系统需减少拷贝和复用显存缓冲”三个结论。下一步调优不应优先优化 NTT,而应优先围绕 sample 路径、serial kernel 局部状态和端到端缓冲复用展开。 + +### 11.6 rocm-smi 资源采集说明 + +单次运行后执行: + +```bash +rocm-smi --showuse --showmemuse --showtemp --showpower +``` + +得到 GPU 利用率与显存占用均为 0。这是因为命令在 benchmark 结束后执行,GPU 已经空闲,不能代表运行中资源占用。 + +当前可记录的静态信息: + +- 服务器暴露 8 张 AMD GPU。 +- 空闲温度约 24 到 33 摄氏度。 +- 空闲功耗约 8 到 14 W。 +- benchmark 结束后 VRAM 使用率恢复为 0。 + +后续如果要获得运行中资源数据,需要在 benchmark 期间循环采样 `rocm-smi`。 + +### 11.7 运行中 rocm-smi 采样结果 + +使用长迭代 benchmark 期间循环采样: + +```bash +( + for i in $(seq 1 80); do + echo "===== sample $i $(date '+%H:%M:%S.%3N') =====" + rocm-smi --showuse --showmemuse --showtemp --showpower + sleep 0.2 + done +) > amd_results/resource/rocm_smi_during_kyber768.log & + +./kyber768_amd --batch 32768 --n-ops 200 --no-correctness +wait +``` + +性能结果: + +```text +--- batch=32768 n_ops=200 mode=serial --- + Keygen: 4.9 ms/batch -> 6637753 ops/sec + Encaps: 5.5 ms/batch -> 6008051 ops/sec + Decaps: 6.0 ms/batch -> 5440166 ops/sec +``` + +运行中资源观测: + +| 指标 | 观测值 | +|---|---| +| GPU use | 峰值 100%,多次采样 99%-100% | +| VRAM allocated | 峰值约 8% | +| Average graphics package power | 约 237-243 W | +| Edge temperature | 约 28-35 C | +| Junction temperature | 约 36-46 C | +| Memory temperature | 约 30-38 C | + +资源侧结论: + +> 长迭代 Kyber-768 benchmark 期间 GPU use 可稳定达到 
99%-100%,说明当前批处理工作负载能够充分占用 AMD GPU 计算资源。与此同时,VRAM 使用峰值仅约 8%,说明当前 KEM benchmark 并非显存容量受限,更可能受采样/XOF、寄存器压力、scratch/栈压力和单 kernel 计算路径影响。 + +适合论文表述: + +> `rocm-smi` 运行中采样显示,Kyber-768 批处理阶段 GPU 利用率接近 100%,显存占用约 8%,功耗约 240W。结合 `rocprofv3` 的 kernel trace,可判断当前瓶颈主要来自计算密集型采样与 monolithic kernel 资源压力,而非显存容量不足。 + +### 11.8 launch_bounds 关闭实验 + +实验配置: + +```bash +KEM_KEYPAIR_LAUNCH_BOUNDS=0 +KEM_ENCAPS_LAUNCH_BOUNDS=0 +KEM_DECAPS_LAUNCH_BOUNDS=0 +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +``` + +测试: + +```text +--- batch=32768 n_ops=50 mode=serial --- + Keygen: 5.0 ms/batch -> 6563939 ops/sec + Encaps: 5.5 ms/batch -> 5997218 ops/sec + Decaps: 5.9 ms/batch -> 5507774 ops/sec +``` + +结论: + +> 关闭 `__launch_bounds__` 后吞吐与 `n_ops=200` 稳定结果基本接近,没有出现明显提升。因此当前性能瓶颈不主要来自 launch bounds 约束本身。后续调优应转向 sample 内部实现、设备函数局部状态、缓冲复用和端到端数据搬运优化。 + +## 15. 端到端工程优化:Device Buffer 复用 + +为模拟后续“多文件科研数据可信流转平台”中的连续批处理场景,新增 `--reuse-bench` 测试入口,对比两种端到端执行方式: + +1. **Alloc-each-round** + - 每轮重新 `cudaMalloc` + - 拷贝输入 seed + - 执行 keygen/encaps/decaps + - 每轮 `cudaFree` + +2. 
**Reuse buffers** + - 初始化时分配一次 device buffer + - 每轮复用已有 buffer + - 只更新输入 seed 并执行 kernel + +测试命令: + +```bash +KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 \ +bash build_hip.sh kyber768 + +./kyber768_amd --batch 32768 --n-ops 5 --reuse-bench 20 --no-correctness +``` + +结果: + +```text +=== Buffer reuse benchmark: Kyber-768 === +batch=32768 rounds=20 n_ops_per_round=5 + Alloc-each-round: total= 1701.7 ms | per_round= 85.085 ms | full-kem throughput=1925608 instances/sec + Reuse buffers: total= 1599.8 ms | per_round= 79.991 ms | full-kem throughput=2048239 instances/sec + Reuse speedup: 1.064x +``` + +结论: + +> Device buffer 复用使 Kyber-768 端到端 full-KEM 批处理吞吐从 1.93M instances/s 提升到 2.05M instances/s,提升约 6.4%。该优化不改变单个密码 kernel 内部算法,而是减少多批次工作流中的显存分配/释放开销,更符合真实科研数据平台的连续处理模式。 + +工程意义: + +> 在真实多文件可信流转系统中,服务端会连续处理多个文件批次。为每批文件重新分配和释放 GPU 缓冲会引入额外运行时开销。通过维护长期复用的 device buffer pool,可以提升端到端吞吐,并与前端/后端任务队列自然结合。 + +适合论文/PPT 表述: + +> 除 kernel 级 TPB 调优外,本项目还针对真实工作流进行端到端优化。通过复用 GPU device buffer,Kyber-768 full-KEM 连续批处理吞吐提升 6.4%,说明 ROCm 平台性能优化不仅包括 kernel 内部调参,也包括面向应用工作流的内存管理与任务调度优化。 + + + +### Step 2:全量构建 7 个 KEM 目标 + +目的: + +确认不只是 Kyber-768,Kyber512/1024 与 Aigis-enc 1/2/3/4 都能在 AMD 上构建和冒烟。 + +命令: + +```bash +cd /app/kyberandaigis-enc +bash build_hip.sh +bash run_kem_smoke_amd.sh +cat amd_results/kem_smoke_summary.csv +``` + +需要关注: + +- 哪些目标可以编译。 +- 哪些目标正确性 PASS。 +- 哪些目标出现 stack、寄存器、显存或运行时错误。 + +### Step 3:全量 batch sweep + +目的: + +找到各算法的最佳 batch size,并观察 AMD 上吞吐峰值。 + +命令: + +```bash +bash run_kem_sweep_amd.sh +cat amd_results/kem_best.csv +``` + +需要输出: + +- 每个算法 keygen/encaps/decaps 最佳吞吐。 +- 每个算法最佳 batch 配置。 +- 与 4090 数据对比。 + +### Step 4:针对热点目标做 profile + +优先目标: + +```text +kyber768_amd +kyber1024_amd +aigisenc3_amd +aigisenc4_amd +``` + +命令示例: + +```bash +bash profile_kem_one_amd.sh kyber1024_amd 32768 3 +bash profile_kem_one_amd.sh aigisenc4_amd 32768 3 +``` + +需要分析: + +- 是否仍然是 sample 阶段主导。 +- 大参数集是否出现 pack、NTT、matvec 或 decaps 瓶颈。 +- Aigis-enc 与 Kyber 的瓶颈是否一致。 + +### Step 
5:将 KEM 结果接入项目主线 + +短期目标: + +- 先形成 `Kyber/Aigis-enc AMD 跑通 + 性能 + 瓶颈分析` 小节。 +- 再与已有 ML-DSA/Aigis-sig AMD 结果合并。 +- 最后落地到“多文件科研数据可信流转平台”。 + +项目主线表述: + +> 本项目不是单一 Kyber 跑分,而是面向 AMD ROCm 生态构建后量子科研数据可信流转系统。Kyber/Aigis-enc 负责会话密钥封装,ML-DSA/Aigis-sig 负责数据签名与篡改检测。通过 ROCm 上的批处理、profile 和优化,验证 AMD GPU 在后量子密码高并发科研数据流转中的工程价值,并为 AMD 构建类 cuPQC 的 ROCm-PQC 库提供实验依据。 + +## 8. 下一步优先级 + +最高优先级: + +1. 跑 `kyber768_split_amd`,验证 sample 拆分是否能降低 70% 瓶颈。 +2. 全量 `bash build_hip.sh`,确认 7 个 KEM 目标的编译情况。 +3. 全量 `bash run_kem_smoke_amd.sh`,确认正确性矩阵。 +4. 全量 `bash run_kem_sweep_amd.sh`,拿到 AMD KEM 最佳性能表。 +5. 选 2 个代表目标做 profile,形成可写入论文的瓶颈定位表。 + +当前最值得追的优化点: + +> Kyber-768 keygen sample 阶段占 70%,优先围绕采样/XOF/拒绝采样并行化做 ROCm 调优。 diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_large_best.csv b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_large_best.csv new file mode 100644 index 000000000..3fb6cd352 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_large_best.csv @@ -0,0 +1,37 @@ +target,benchmark_mode,operation,batch,ms,ops_s,path,log +aigis1,independent,Keygen,32768,32.023,1023273,independent-opt,aigis1_amd_independent_b32768.log +aigis1,independent,Sign,32768,425.459,77018,decomp-pipeline,aigis1_amd_independent_b32768.log +aigis1,independent,Verify,16384,2.105,7781688,,aigis1_amd_independent_b16384.log +aigis1,paper,Keygen,32768,16.474,1989031,paper-shared-rhoA,aigis1_amd_paper_b32768.log +aigis1,paper,Sign,32768,486.455,67361,decomp-pipeline,aigis1_amd_paper_b32768.log +aigis1,paper,Verify,16384,2.120,7728283,,aigis1_amd_paper_b16384.log +aigis2,independent,Keygen,32768,42.804,765542,independent-old,aigis2_amd_independent_b32768.log +aigis2,independent,Sign,16384,422.524,38776,decomp-pipeline,aigis2_amd_independent_b16384.log +aigis2,independent,Verify,8192,1.285,6375467,,aigis2_amd_independent_b8192.log 
+aigis2,paper,Keygen,16384,9.460,1731964,paper-shared-rhoA,aigis2_amd_paper_b16384.log +aigis2,paper,Sign,16384,346.621,47268,decomp-pipeline,aigis2_amd_paper_b16384.log +aigis2,paper,Verify,8192,1.298,6309734,,aigis2_amd_paper_b8192.log +aigis3,independent,Keygen,32768,55.544,589950,independent-old,aigis3_amd_independent_b32768.log +aigis3,independent,Sign,16384,391.438,41856,decomp-pipeline,aigis3_amd_independent_b16384.log +aigis3,independent,Verify,8192,1.648,4970454,,aigis3_amd_independent_b8192.log +aigis3,paper,Keygen,32768,22.033,1487207,paper-shared-rhoA,aigis3_amd_paper_b32768.log +aigis3,paper,Sign,32768,798.910,41016,decomp-pipeline,aigis3_amd_paper_b32768.log +aigis3,paper,Verify,8192,1.668,4911425,,aigis3_amd_paper_b8192.log +mldsa44,independent,Keygen,32768,34.328,954556,independent-old,mldsa44_amd_independent_b32768.log +mldsa44,independent,Sign,32768,299.174,109528,decomp-pipeline,mldsa44_amd_independent_b32768.log +mldsa44,independent,Verify,8192,1.211,6763774,,mldsa44_amd_independent_b8192.log +mldsa44,paper,Keygen,16384,8.459,1936824,paper-shared-rhoA,mldsa44_amd_paper_b16384.log +mldsa44,paper,Sign,16384,180.552,90744,decomp-pipeline,mldsa44_amd_paper_b16384.log +mldsa44,paper,Verify,8192,1.213,6753549,,mldsa44_amd_paper_b8192.log +mldsa65,independent,Keygen,32768,53.505,612429,independent-old,mldsa65_amd_independent_b32768.log +mldsa65,independent,Sign,16384,257.458,63638,decomp-pipeline,mldsa65_amd_independent_b16384.log +mldsa65,independent,Verify,8192,1.734,4724442,,mldsa65_amd_independent_b8192.log +mldsa65,paper,Keygen,16384,10.908,1502057,paper-shared-rhoA,mldsa65_amd_paper_b16384.log +mldsa65,paper,Sign,16384,262.670,62375,decomp-pipeline,mldsa65_amd_paper_b16384.log +mldsa65,paper,Verify,8192,1.725,4749933,,mldsa65_amd_paper_b8192.log +mldsa87,independent,Keygen,32768,73.646,444940,independent-opt,mldsa87_amd_independent_b32768.log +mldsa87,independent,Sign,32768,599.155,54690,decomp-pipeline,mldsa87_amd_independent_b32768.log 
+mldsa87,independent,Verify,8192,2.422,3382030,,mldsa87_amd_independent_b8192.log +mldsa87,paper,Keygen,32768,28.239,1160371,paper-shared-rhoA,mldsa87_amd_paper_b32768.log +mldsa87,paper,Sign,16384,356.069,46014,decomp-pipeline,mldsa87_amd_paper_b16384.log +mldsa87,paper,Verify,8192,2.454,3337606,,mldsa87_amd_paper_b8192.log diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_optimization_claims.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_optimization_claims.md new file mode 100644 index 000000000..741e458de --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_optimization_claims.md @@ -0,0 +1,107 @@ +# AMD ROCm Optimization Claims + +## Implemented Candidates + +- Stable signing remains the resource-aware `decomp-pipeline` path. +- `adaptive` is a runtime policy candidate: one binary selects the measured local winner by target, benchmark mode, and batch size, while falling back to base on cells where the matrix shows regressions. +- `check8` and `check16` measure whether fewer host-side done-count checks reduce ROCm synchronization overhead. +- `wave64_ctrl` measures whether 64-thread hash/check control kernels behave better on AMD wave64 hardware than 32-thread control kernels. +- `BATCH_SIGN_CP_FUSE_ENABLE` is implemented as a measured AMD candidate: one ROCm kernel computes `cp*s1`, `cp*s2`, and `cp*t0` products for each rejection round. +- `tail16_base` and `tail16_cp_fuse` separate small-tail finish behavior from the fused pointwise candidate. +- `yhat_dup` measures whether duplicating `y` at sample time beats the explicit device-to-device copy. +- The default build keeps these candidates off until the matrix proves a conservative target-specific gain. 
+ +## Current Large-Sweep Sign Best + +| target | mode | batch | sign ops/s | path | log | +| --- | --- | ---: | ---: | --- | --- | +| aigis1 | independent | 16384 | 77989 | decomp-pipeline | aigis1_amd_independent_b16384.log | +| aigis1 | paper | 16384 | 83181 | decomp-pipeline | aigis1_amd_paper_b16384.log | +| aigis2 | independent | 32768 | 47248 | decomp-pipeline | aigis2_amd_independent_b32768.log | +| aigis2 | paper | 32768 | 50046 | decomp-pipeline | aigis2_amd_paper_b32768.log | +| aigis3 | independent | 32768 | 41240 | decomp-pipeline | aigis3_amd_independent_b32768.log | +| aigis3 | paper | 16384 | 41424 | decomp-pipeline | aigis3_amd_paper_b16384.log | +| mldsa44 | independent | 16384 | 106523 | decomp-pipeline | mldsa44_amd_independent_b16384.log | +| mldsa44 | paper | 16384 | 97354 | decomp-pipeline | mldsa44_amd_paper_b16384.log | +| mldsa65 | independent | 16384 | 70159 | decomp-pipeline | mldsa65_amd_independent_b16384.log | +| mldsa65 | paper | 8192 | 53968 | decomp-pipeline | mldsa65_amd_paper_b8192.log | +| mldsa87 | independent | 8192 | 49175 | decomp-pipeline | mldsa87_amd_independent_b8192.log | +| mldsa87 | paper | 8192 | 51185 | decomp-pipeline | mldsa87_amd_paper_b8192.log | + +## AMD Feature Matrix Winners + +| target | mode | batch | best variant | speedup vs base | sign ops/s | log | +| --- | --- | ---: | --- | ---: | ---: | --- | +| aigis1 | independent | 1024 | base | 1.0000 | 21166 | aigis1_base_independent_b1024_r1.log;aigis1_base_independent_b1024_r2.log | +| aigis1 | independent | 16384 | wave64_ctrl | 1.2076 | 86532 | aigis1_wave64_ctrl_independent_b16384_r1.log;aigis1_wave64_ctrl_independent_b16384_r2.log | +| aigis1 | independent | 32768 | wave64_ctrl | 1.0720 | 71926 | aigis1_wave64_ctrl_independent_b32768_r1.log;aigis1_wave64_ctrl_independent_b32768_r2.log | +| aigis1 | independent | 8192 | cp_fuse | 1.2375 | 64504 | aigis1_cp_fuse_independent_b8192_r1.log;aigis1_cp_fuse_independent_b8192_r2.log | +| aigis1 | paper | 1024 | 
wave64_ctrl | 1.4385 | 21375 | aigis1_wave64_ctrl_paper_b1024_r1.log;aigis1_wave64_ctrl_paper_b1024_r2.log | +| aigis1 | paper | 16384 | wave64_ctrl | 1.2519 | 82738 | aigis1_wave64_ctrl_paper_b16384_r1.log;aigis1_wave64_ctrl_paper_b16384_r2.log | +| aigis1 | paper | 32768 | tail16_base | 1.0767 | 72992 | aigis1_tail16_base_paper_b32768_r1.log;aigis1_tail16_base_paper_b32768_r2.log | +| aigis1 | paper | 8192 | adaptive | 1.0853 | 63246 | aigis1_adaptive_paper_b8192_r1.log;aigis1_adaptive_paper_b8192_r2.log | +| aigis2 | independent | 1024 | yhat_dup | 1.2031 | 12854 | aigis2_yhat_dup_independent_b1024.log;aigis2_yhat_dup_independent_b1024_r1.log;aigis2_yhat_dup_independent_b1024_r2.log | +| aigis2 | independent | 16384 | adaptive | 1.1586 | 50797 | aigis2_adaptive_independent_b16384.log;aigis2_adaptive_independent_b16384_r1.log;aigis2_adaptive_independent_b16384_r2.log | +| aigis2 | independent | 32768 | base | 1.0000 | 46575 | aigis2_base_independent_b32768.log;aigis2_base_independent_b32768_r1.log;aigis2_base_independent_b32768_r2.log | +| aigis2 | independent | 8192 | wave64_ctrl | 1.1985 | 41570 | aigis2_wave64_ctrl_independent_b8192.log;aigis2_wave64_ctrl_independent_b8192_r1.log;aigis2_wave64_ctrl_independent_b8192_r2.log | +| aigis2 | paper | 1024 | check8 | 1.3242 | 14708 | aigis2_check8_paper_b1024.log;aigis2_check8_paper_b1024_r1.log;aigis2_check8_paper_b1024_r2.log | +| aigis2 | paper | 16384 | wave64_ctrl | 1.0752 | 50137 | aigis2_wave64_ctrl_paper_b16384.log;aigis2_wave64_ctrl_paper_b16384_r1.log;aigis2_wave64_ctrl_paper_b16384_r2.log | +| aigis2 | paper | 32768 | adaptive | 1.0998 | 49197 | aigis2_adaptive_paper_b32768.log;aigis2_adaptive_paper_b32768_r1.log;aigis2_adaptive_paper_b32768_r2.log | +| aigis2 | paper | 8192 | cp_fuse | 1.2171 | 38465 | aigis2_cp_fuse_paper_b8192.log;aigis2_cp_fuse_paper_b8192_r1.log;aigis2_cp_fuse_paper_b8192_r2.log | +| aigis3 | independent | 1024 | adaptive | 1.1214 | 12032 | 
aigis3_adaptive_independent_b1024_r1.log;aigis3_adaptive_independent_b1024_r2.log | +| aigis3 | independent | 16384 | wave64_ctrl | 1.4200 | 42911 | aigis3_wave64_ctrl_independent_b16384_r1.log;aigis3_wave64_ctrl_independent_b16384_r2.log | +| aigis3 | independent | 32768 | cp_fuse | 1.1429 | 39925 | aigis3_cp_fuse_independent_b32768_r1.log;aigis3_cp_fuse_independent_b32768_r2.log | +| aigis3 | independent | 8192 | base | 1.0000 | 36537 | aigis3_base_independent_b8192_r1.log;aigis3_base_independent_b8192_r2.log | +| aigis3 | paper | 1024 | base | 1.0000 | 12154 | aigis3_base_paper_b1024_r1.log;aigis3_base_paper_b1024_r2.log | +| aigis3 | paper | 16384 | adaptive | 1.0331 | 41904 | aigis3_adaptive_paper_b16384_r1.log;aigis3_adaptive_paper_b16384_r2.log | +| aigis3 | paper | 32768 | base | 1.0000 | 42314 | aigis3_base_paper_b32768_r1.log;aigis3_base_paper_b32768_r2.log | +| aigis3 | paper | 8192 | adaptive | 1.2261 | 37768 | aigis3_adaptive_paper_b8192_r1.log;aigis3_adaptive_paper_b8192_r2.log | +| mldsa44 | independent | 1024 | cp_fuse | 1.3923 | 39625 | mldsa44_cp_fuse_independent_b1024.log;mldsa44_cp_fuse_independent_b1024_r1.log;mldsa44_cp_fuse_independent_b1024_r2.log | +| mldsa44 | independent | 16384 | wave64_ctrl | 1.0956 | 98699 | mldsa44_wave64_ctrl_independent_b16384.log;mldsa44_wave64_ctrl_independent_b16384_r1.log;mldsa44_wave64_ctrl_independent_b16384_r2.log | +| mldsa44 | independent | 32768 | tail16_base | 1.1110 | 100649 | mldsa44_tail16_base_independent_b32768.log;mldsa44_tail16_base_independent_b32768_r1.log;mldsa44_tail16_base_independent_b32768_r2.log | +| mldsa44 | independent | 8192 | wave64_ctrl | 1.0026 | 99262 | mldsa44_wave64_ctrl_independent_b8192.log;mldsa44_wave64_ctrl_independent_b8192_r1.log;mldsa44_wave64_ctrl_independent_b8192_r2.log | +| mldsa44 | paper | 1024 | cp_fuse | 1.3053 | 39605 | mldsa44_cp_fuse_paper_b1024.log;mldsa44_cp_fuse_paper_b1024_r1.log;mldsa44_cp_fuse_paper_b1024_r2.log | +| mldsa44 | paper | 16384 | 
tail16_cp_fuse | 1.0543 | 102919 | mldsa44_tail16_cp_fuse_paper_b16384.log;mldsa44_tail16_cp_fuse_paper_b16384_r1.log;mldsa44_tail16_cp_fuse_paper_b16384_r2.log | +| mldsa44 | paper | 32768 | tail16_base | 1.1936 | 98810 | mldsa44_tail16_base_paper_b32768.log;mldsa44_tail16_base_paper_b32768_r1.log;mldsa44_tail16_base_paper_b32768_r2.log | +| mldsa44 | paper | 8192 | base | 1.0000 | 96740 | mldsa44_base_paper_b8192.log;mldsa44_base_paper_b8192_r1.log;mldsa44_base_paper_b8192_r2.log | +| mldsa65 | independent | 1024 | check8 | 1.0973 | 24036 | mldsa65_check8_independent_b1024_r1.log;mldsa65_check8_independent_b1024_r2.log | +| mldsa65 | independent | 16384 | cp_fuse | 1.3625 | 62591 | mldsa65_cp_fuse_independent_b16384_r1.log;mldsa65_cp_fuse_independent_b16384_r2.log | +| mldsa65 | independent | 32768 | wave64_ctrl | 1.2118 | 60904 | mldsa65_wave64_ctrl_independent_b32768_r1.log;mldsa65_wave64_ctrl_independent_b32768_r2.log | +| mldsa65 | independent | 8192 | check16 | 1.2116 | 55130 | mldsa65_check16_independent_b8192_r1.log;mldsa65_check16_independent_b8192_r2.log | +| mldsa65 | paper | 1024 | yhat_dup | 1.0289 | 19158 | mldsa65_yhat_dup_paper_b1024_r1.log;mldsa65_yhat_dup_paper_b1024_r2.log | +| mldsa65 | paper | 16384 | tail16_base | 1.1451 | 61008 | mldsa65_tail16_base_paper_b16384_r1.log;mldsa65_tail16_base_paper_b16384_r2.log | +| mldsa65 | paper | 32768 | tail16_base | 1.1380 | 56954 | mldsa65_tail16_base_paper_b32768_r1.log;mldsa65_tail16_base_paper_b32768_r2.log | +| mldsa65 | paper | 8192 | base | 1.0000 | 61043 | mldsa65_base_paper_b8192_r1.log;mldsa65_base_paper_b8192_r2.log | +| mldsa87 | independent | 1024 | tail16 | 1.5159 | 22648 | mldsa87_tail16_independent_b1024.log | +| mldsa87 | independent | 16384 | base | 1.0000 | 46069 | mldsa87_base_independent_b16384.log;mldsa87_base_independent_b16384_r1.log;mldsa87_base_independent_b16384_r2.log | +| mldsa87 | independent | 32768 | cp_fuse | 1.0854 | 48998 | 
mldsa87_cp_fuse_independent_b32768.log;mldsa87_cp_fuse_independent_b32768_r1.log;mldsa87_cp_fuse_independent_b32768_r2.log | +| mldsa87 | independent | 8192 | cp_fuse | 1.0052 | 50385 | mldsa87_cp_fuse_independent_b8192.log;mldsa87_cp_fuse_independent_b8192_r1.log;mldsa87_cp_fuse_independent_b8192_r2.log | +| mldsa87 | paper | 1024 | yhat_dup | 1.1197 | 16756 | mldsa87_yhat_dup_paper_b1024.log;mldsa87_yhat_dup_paper_b1024_r1.log;mldsa87_yhat_dup_paper_b1024_r2.log | +| mldsa87 | paper | 16384 | tail16_cp_fuse | 1.1680 | 46352 | mldsa87_tail16_cp_fuse_paper_b16384.log;mldsa87_tail16_cp_fuse_paper_b16384_r1.log;mldsa87_tail16_cp_fuse_paper_b16384_r2.log | +| mldsa87 | paper | 32768 | tail16_cp_fuse | 1.1303 | 47494 | mldsa87_tail16_cp_fuse_paper_b32768.log;mldsa87_tail16_cp_fuse_paper_b32768_r1.log;mldsa87_tail16_cp_fuse_paper_b32768_r2.log | +| mldsa87 | paper | 8192 | base | 1.0000 | 55762 | mldsa87_base_paper_b8192.log;mldsa87_base_paper_b8192_r1.log;mldsa87_base_paper_b8192_r2.log | + +Matrix interpretation: local wins are useful evidence. The `adaptive` row tests whether those wins can be captured in one target/mode/batch-aware build without promoting a globally regressing macro. 
+ +## AMD Limitation Evidence + +The monolithic/cached-style signing candidates are retained as negative evidence; representative failures: + +| target/variant | exit | hint | +| --- | ---: | --- | +| aigis1_mono_bs1_mono_bs1_b1 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| aigis1_mono_bs1_mono_bs1_b128 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| aigis1_mono_bs1_mono_bs1_b32 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| aigis1_mono_bs1_mono_bs1_b8 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs1_mono_bs1_b1 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs1_mono_bs1_b128 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs1_mono_bs1_b32 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs1_mono_bs1_b8 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs2_mono_bs2_b1 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs2_mono_bs2_b128 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs2_mono_bs2_b32 | 1 | [Sign] FAIL: cached/monolithic paths failed | +| mldsa44_mono_bs2_mono_bs2_b8 | 1 | [Sign] FAIL: cached/monolithic paths failed | + +## Next Tuning Step + +Run `python3 amd_tools/select_sig_amd_variants.py`, inspect `amd_results/sig_amd_variant_plan.md`, then build selected variants. If `adaptive` is promoted, rerun smoke/debug/large-sweep to collect final evidence. 
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_six_parameter_final_decision_2026-06-16.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_six_parameter_final_decision_2026-06-16.md new file mode 100644 index 000000000..9d05e086f --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/sig_six_parameter_final_decision_2026-06-16.md @@ -0,0 +1,81 @@ +# 2026-06-16 签名六参数全优化最终决策 + +- 日期: 2026-06-16 +- 工程: `/app/pqc_rocm_full_20260614/amd_sig_anchor_results_20260605_031411` +- 覆盖目标: `mldsa44 / mldsa65 / mldsa87 / aigis1 / aigis2 / aigis3` +- 覆盖模式: `paper / independent` +- 覆盖批大小: `1024 / 8192 / 16384 / 32768` +- 覆盖候选: `base / adaptive / check8 / check16 / wave64_ctrl / cp_fuse / tail16_base / tail16_cp_fuse / yhat_dup` +- repeat: 2 +- 最终结论: 六个目标的稳定 selected build 全部选择 `base` + +## Selected Build + +保守规则: + +```text +non-base variant must: +1. pass every measured cell +2. keep min speedup >= 1.0000 +3. reach geomean >= 1.0300 +``` + +结果: + +| target | selected variant | reason | +| --- | --- | --- | +| mldsa44 | base | no conservative non-base winner | +| mldsa65 | base | no conservative non-base winner | +| mldsa87 | base | no conservative non-base winner | +| aigis1 | base | no conservative non-base winner | +| aigis2 | base | no conservative non-base winner | +| aigis3 | base | no conservative non-base winner | + +对应 `sig_amd_variant_plan.env`: + +```bash +SIG_AMD_VARIANT_MLDSA44=base +SIG_AMD_VARIANT_MLDSA65=base +SIG_AMD_VARIANT_MLDSA87=base +SIG_AMD_VARIANT_AIGIS1=base +SIG_AMD_VARIANT_AIGIS2=base +SIG_AMD_VARIANT_AIGIS3=base +``` + +## 关键判断 + +全量 feature matrix 说明: + +1. 局部 winner 很多。 +2. 但没有任何候选能在单个 target 的 8 个组合中全部不退化。 +3. 因此最终稳定 build 不能简单选择某个全局宏。 +4. 
候选优化适合写成 feature matrix / resource-aware policy evidence,而不是稳定默认路径。 + +## 候选诊断摘要 + +| target | strongest candidate | geomean | min speedup | wins/losses | decision | +| --- | --- | ---: | ---: | --- | --- | +| mldsa44 | wave64_ctrl | 1.0861 | 0.9924 | 7 / 1 | min below 1.0, keep as candidate | +| mldsa65 | wave64_ctrl | 1.0524 | 0.8368 | 4 / 4 | unstable | +| mldsa87 | cp_fuse | 1.0257 | 0.8743 | 5 / 3 | unstable | +| aigis1 | wave64_ctrl | 1.1217 | 0.9089 | 6 / 2 | strong local candidate, not global | +| aigis2 | adaptive | 1.0680 | 0.9066 | 6 / 2 | strong local candidate, not global | +| aigis3 | wave64_ctrl | 1.0025 | 0.8504 | 3 / 5 | weak/unstable | + +## 论文表述 + +建议写成: + +> We keep the resource-aware decomposed pipeline as the stable ROCm baseline. Although feature-matrix candidates such as `wave64_ctrl`, `cp_fuse`, `check8/check16`, `tail16`, and `adaptive` produce local wins, none satisfies the conservative target-level rule across all measured mode/batch cells. We therefore separate stable baseline results from candidate optimization evidence. + +中文表达: + +> 全量 feature matrix 显示,AMD ROCm 上的签名优化不是单一宏开关可以全局解决的问题。不同 target、benchmark mode 和 batch size 下最优策略不同,且候选策略存在明显回退风险。因此最终稳定构建保留 resource-aware decomp pipeline,候选优化作为局部收益和资源感知调优证据单独报告。 + +## 下一步 + +1. 构建 selected build,即 base 六目标。 +2. 运行 policy smoke。 +3. 运行 debug matrix。 +4. 运行 large sweep 生成最终签名性能表。 +5. 
对代表目标做资源归因 profile。 diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_final_large_sweep_2026-06-16.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_final_large_sweep_2026-06-16.md new file mode 100644 index 000000000..685a08520 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_final_large_sweep_2026-06-16.md @@ -0,0 +1,26 @@ +# table_sig_final_large_sweep_2026-06-16 + +## Final large sweep summary + +All targets use the stable `base` build for the final selected path. + +| Family | Target | Mode | Keygen batch | Keygen ms | Keygen ops/s | Sign batch | Sign ms | Sign ops/s | Verify batch | Verify ms | Verify ops/s | +| --- | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | +| ML-DSA | 44 | independent | 32768 | 34.328 | 954556 | 32768 | 299.174 | 109528 | 8192 | 1.211 | 6763774 | +| ML-DSA | 44 | paper | 16384 | 8.459 | 1936824 | 16384 | 180.552 | 90744 | 8192 | 1.213 | 6753549 | +| ML-DSA | 65 | independent | 32768 | 53.505 | 612429 | 16384 | 257.458 | 63638 | 8192 | 1.734 | 4724442 | +| ML-DSA | 65 | paper | 16384 | 10.908 | 1502057 | 16384 | 262.670 | 62375 | 8192 | 1.725 | 4749933 | +| ML-DSA | 87 | independent | 32768 | 73.646 | 444940 | 32768 | 599.155 | 54690 | 8192 | 2.422 | 3382030 | +| ML-DSA | 87 | paper | 32768 | 28.239 | 1160371 | 16384 | 356.069 | 46014 | 8192 | 2.454 | 3337606 | +| Aigis-sig | 1 | independent | 32768 | 32.023 | 1023273 | 32768 | 425.459 | 77018 | 16384 | 2.105 | 7781688 | +| Aigis-sig | 1 | paper | 32768 | 16.474 | 1989031 | 32768 | 486.455 | 67361 | 16384 | 2.120 | 7728283 | +| Aigis-sig | 2 | independent | 32768 | 42.804 | 765542 | 16384 | 422.524 | 38776 | 8192 | 1.285 | 6375467 | +| Aigis-sig | 2 | paper | 16384 | 9.460 | 1731964 | 16384 | 346.621 | 47268 | 8192 | 1.298 | 6309734 | +| Aigis-sig | 3 | independent | 32768 | 55.544 | 
589950 | 16384 | 391.438 | 41856 | 8192 | 1.648 | 4970454 | +| Aigis-sig | 3 | paper | 32768 | 22.033 | 1487207 | 32768 | 798.910 | 41016 | 8192 | 1.668 | 4911425 | + +## Notes + +- Stable baseline: `decomp-pipeline=on` +- Candidate variants were not promoted to the final build +- This table is the final paper-facing large-batch throughput summary diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_local_winners_feature_matrix_2026-06-16.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_local_winners_feature_matrix_2026-06-16.md new file mode 100644 index 000000000..b158d1fd8 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/evidence/table_sig_local_winners_feature_matrix_2026-06-16.md @@ -0,0 +1,51 @@ +# table_sig_local_winners_feature_matrix_2026-06-16 + +## Local winners from feature matrix + +These rows are evidence for workload-sensitive tuning. They are not final default policies. 
+ +| Target | Mode | Batch | Local winner | Speedup vs base | Sign ops/s | +| --- | --- | ---: | --- | ---: | ---: | +| Aigis-sig1 | independent | 8192 | cp_fuse | 1.2375 | 64504 | +| Aigis-sig1 | independent | 16384 | wave64_ctrl | 1.2076 | 86532 | +| Aigis-sig1 | independent | 32768 | wave64_ctrl | 1.0720 | 71926 | +| Aigis-sig1 | paper | 1024 | wave64_ctrl | 1.4385 | 21375 | +| Aigis-sig1 | paper | 8192 | adaptive | 1.0853 | 63246 | +| Aigis-sig1 | paper | 16384 | wave64_ctrl | 1.2519 | 82738 | +| Aigis-sig1 | paper | 32768 | tail16_base | 1.0767 | 72992 | +| Aigis-sig2 | independent | 1024 | yhat_dup | 1.2031 | 12854 | +| Aigis-sig2 | independent | 8192 | wave64_ctrl | 1.1985 | 41570 | +| Aigis-sig2 | independent | 16384 | adaptive | 1.1586 | 50797 | +| Aigis-sig2 | paper | 1024 | check8 | 1.3242 | 14708 | +| Aigis-sig2 | paper | 8192 | cp_fuse | 1.2171 | 38465 | +| Aigis-sig2 | paper | 16384 | wave64_ctrl | 1.0752 | 50137 | +| Aigis-sig2 | paper | 32768 | adaptive | 1.0998 | 49197 | +| Aigis-sig3 | independent | 1024 | adaptive | 1.1214 | 12032 | +| Aigis-sig3 | independent | 16384 | wave64_ctrl | 1.4200 | 42911 | +| Aigis-sig3 | independent | 32768 | cp_fuse | 1.1429 | 39925 | +| Aigis-sig3 | paper | 8192 | adaptive | 1.2261 | 37768 | +| Aigis-sig3 | paper | 16384 | adaptive | 1.0331 | 41904 | +| ML-DSA-44 | independent | 1024 | cp_fuse | 1.3923 | 39625 | +| ML-DSA-44 | independent | 8192 | wave64_ctrl | 1.0026 | 99262 | +| ML-DSA-44 | independent | 16384 | wave64_ctrl | 1.0956 | 98699 | +| ML-DSA-44 | independent | 32768 | tail16_base | 1.1110 | 100649 | +| ML-DSA-44 | paper | 1024 | cp_fuse | 1.3053 | 39605 | +| ML-DSA-44 | paper | 16384 | tail16_cp_fuse | 1.0543 | 102919 | +| ML-DSA-44 | paper | 32768 | tail16_base | 1.1936 | 98810 | +| ML-DSA-65 | independent | 1024 | check8 | 1.0973 | 24036 | +| ML-DSA-65 | independent | 8192 | check16 | 1.2116 | 55130 | +| ML-DSA-65 | independent | 16384 | cp_fuse | 1.3625 | 62591 | +| ML-DSA-65 | independent | 32768 | 
wave64_ctrl | 1.2118 | 60904 | +| ML-DSA-65 | paper | 1024 | yhat_dup | 1.0289 | 19158 | +| ML-DSA-65 | paper | 16384 | tail16_base | 1.1451 | 61008 | +| ML-DSA-65 | paper | 32768 | tail16_base | 1.1380 | 56954 | +| ML-DSA-87 | independent | 1024 | tail16 | 1.5159 | 22648 | +| ML-DSA-87 | independent | 8192 | cp_fuse | 1.0052 | 50385 | +| ML-DSA-87 | independent | 32768 | cp_fuse | 1.0854 | 48998 | +| ML-DSA-87 | paper | 1024 | yhat_dup | 1.1197 | 16756 | +| ML-DSA-87 | paper | 16384 | tail16_cp_fuse | 1.1680 | 46352 | +| ML-DSA-87 | paper | 32768 | tail16_cp_fuse | 1.1303 | 47494 | + +## Interpretation + +The table shows that optimization is workload-sensitive. Large local speedups exist, but they do not satisfy the conservative no-regression rule across all measured cells. Therefore, local winners should be discussed as candidate evidence instead of being merged into the stable default build. diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/AMD_RUNBOOK.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/AMD_RUNBOOK.md new file mode 100644 index 000000000..c0f804bcb --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/AMD_RUNBOOK.md @@ -0,0 +1,292 @@ +# AMD ROCm runbook for Kyber/Aigis-enc + +This directory is the AMD/JupyterLab entry point for the KEM part of the +project. The scripts mirror the signature-side AMD workflow and keep all +generated logs under `amd_results/`. + +## 1. Build + +```bash +bash build_hip.sh +``` + +Useful overrides: + +```bash +ROCM_ARCH=gfx1100 KEM_SERIAL_TPB=64 bash build_hip.sh +bash build_hip.sh kyber768 +``` + +Outputs: + +```text +kyber512_amd kyber768_amd kyber1024_amd +aigisenc1_amd aigisenc2_amd aigisenc3_amd aigisenc4_amd +amd_results/build/*.log +``` + +## 2. 
Correctness smoke test + +```bash +bash run_kem_smoke_amd.sh +``` + +Outputs: + +```text +amd_results/smoke/*.log +amd_results/kem_smoke_summary.csv +``` + +## 3. Batch and stream sweep + +```bash +bash run_kem_sweep_amd.sh +``` + +Outputs: + +```text +amd_results/sweep/*.log +amd_results/kem_sweep_summary.csv +amd_results/kem_best.csv +``` + +## 4. Profile one target + +```bash +bash profile_kem_one_amd.sh kyber768_amd 8192 3 +``` + +This runs the built-in pipeline stage timer first. If `rocprofv3` is available, +it also records a ROCm trace under `amd_results/profile/`. + +## 5. Final report run + +Use this after correctness and tuning are stable. It builds the stable AMD +configuration, runs final KEM throughput tests, and writes a timestamped report +directory. + +```bash +bash run_kem_final_report_amd.sh +``` + +Expected outputs: + +```text +amd_results/final_report_<timestamp>/ +amd_results/final_report_<timestamp>/kem_final_extract.txt +``` + +The 2026-06-12 reference report used: + +```text +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +Kyber batch=32768 +Aigis-enc batch=65536 +n_ops=20 +``` + +## 6. ROCm resource/profile run + +Use this when the paper/PPT needs bottleneck and resource evidence, not just +throughput numbers. + +```bash +bash run_kem_resource_profile_amd.sh kyber768 32768 200 +``` + +Expected outputs: + +```text +amd_results/resource_profile_kyber768_<timestamp>/ +amd_results/resource_profile_kyber768_<timestamp>/rocprofv3/ +amd_results/resource_profile_kyber768_<timestamp>/rocm_smi_during_kyber768.log +``` + +Interpretation from the 2026-06-12 run: + +- `rocprofv3` shows sampling/XOF/rejection sampling dominates Kyber-768 + pipeline keygen. +- Serial KEM kernels show high VGPR and scratch usage. +- `rocm-smi` shows 99%-100% GPU utilization, low VRAM pressure, and about + 237-243 W during the long Kyber-768 run. + +## 7. 
Buffer reuse benchmark + +This benchmark is useful for the final workflow/Demo because repeated file +processing should reuse device buffers instead of reallocating every round. + +```bash +KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 bash build_hip.sh kyber768 +./kyber768_amd --batch 32768 --n-ops 5 --reuse-bench 20 --no-correctness +``` + +2026-06-12 reference result: + +```text +Alloc-each-round: 1.93M full-KEM instances/sec +Reuse buffers: 2.05M full-KEM instances/sec +Reuse speedup: 1.064x +``` + +## 8. KEM tuning matrix + +When continuing optimization, start with Kyber-768 because it is the clearest +mainline comparison point. The tuning script recompiles several ROCm launch and +compiler configurations, runs throughput, and emits a CSV summary. + +```bash +bash run_kem_tune_amd.sh kyber768 +``` + +For a faster first pass: + +```bash +BATCH=32768 N_OPS=10 DO_CORRECTNESS=0 bash run_kem_tune_amd.sh kyber768 +``` + +Rank results: + +```bash +latest=$(ls -td amd_results/tune_kyber768_* | head -1) +sort -t, -k13,13nr "$latest/tune_summary.csv" | head +sort -t, -k14,14nr "$latest/tune_summary.csv" | head +sort -t, -k15,15nr "$latest/tune_summary.csv" | head +cat "$latest/pipeline_candidates.log" +``` + +Promote only stable improvements into `run_kem_final_report_amd.sh`. + +## 9. 
All-parameter bounds probe + +Use this to test all launch-bounds combinations for all seven KEM targets: + +```bash +bash run_kem_all_bounds_probe_amd.sh +``` + +It covers: + +```text +Kyber-512 / Kyber-768 / Kyber-1024 +Aigis-enc-1 / Aigis-enc-2 / Aigis-enc-3 / Aigis-enc-4 +bounds 000 / 001 / 010 / 011 / 100 / 101 / 110 / 111 +``` + +Outputs: + +```text +amd_results/all_bounds_probe_<timestamp>/all_bounds_probe_raw.csv +amd_results/all_bounds_probe_<timestamp>/all_bounds_probe_avg.csv +amd_results/all_bounds_probe_<timestamp>/all_bounds_probe_best.csv +``` + +Fast first pass: + +```bash +N_OPS=10 REPEATS=1 DO_CORRECTNESS=0 bash run_kem_all_bounds_probe_amd.sh +``` + +Paper-grade pass: + +```bash +N_OPS=30 REPEATS=2 DO_CORRECTNESS=1 bash run_kem_all_bounds_probe_amd.sh +``` + +## 10. All-parameter profile comparison + +After the best bounds are selected, run baseline-vs-tuned ROCm profile +comparison for all seven KEM targets: + +```bash +bash run_kem_all_profile_compare_amd.sh +``` + +It compares: + +```text +baseline bounds=100 +tuned bounds from all_bounds_probe_best.csv: +Kyber-512=001, Kyber-768=010, Kyber-1024=110, +Aigis-enc-1=101, Aigis-enc-2=110, Aigis-enc-3=101, Aigis-enc-4=101 +``` + +Outputs: + +```text +amd_results/profile_compare_<timestamp>/profile_compare_runs.csv +amd_results/profile_compare_<timestamp>/kernel_summary.csv +amd_results/profile_compare_<timestamp>/hip_api_summary.csv +amd_results/profile_compare_<timestamp>/key_kernel_summary.csv +amd_results/profile_compare_<timestamp>/key_kernel_compare.csv +``` + +Fast first pass: + +```bash +N_OPS=10 PROFILE_N_OPS=1 DO_CORRECTNESS=0 bash run_kem_all_profile_compare_amd.sh +``` + +Paper-grade pass: + +```bash +N_OPS=30 PROFILE_N_OPS=1 DO_CORRECTNESS=1 bash run_kem_all_profile_compare_amd.sh +``` + +## 11. 
ROCm toolbox pass + +Use this after trace/profile comparison to probe additional ROCm tooling: + +```bash +bash run_rocm_toolbox_kem_amd.sh +``` + +Default targets: + +```text +kyber768 kyber1024 aigisenc4 +``` + +All seven targets: + +```bash +TARGETS="kyber512 kyber768 kyber1024 aigisenc1 aigisenc2 aigisenc3 aigisenc4" \ +N_OPS=20 PROFILE_N_OPS=1 bash run_rocm_toolbox_kem_amd.sh +``` + +Outputs: + +```text +amd_results/rocm_toolbox_<timestamp>/tool_discovery.txt +amd_results/rocm_toolbox_<timestamp>/rocprofv3_list_avail.txt +amd_results/rocm_toolbox_<timestamp>/toolbox_runs.csv +amd_results/rocm_toolbox_<timestamp>/*/sys_trace/ +amd_results/rocm_toolbox_<timestamp>/*/pmc/ +amd_results/rocm_toolbox_<timestamp>/pmc_summary.csv +``` + +This script attempts `rocprofv3 --sys-trace`, `rocprofv3 --pmc`, `rocm-smi`, +`rocminfo`, and `hipconfig`. Unsupported counters are skipped automatically. + +## 12. Package before leaving JupyterLab + +Before shutting down the AMD server, package the whole working directory from +`/app` so the raw CSV/log/profile evidence is not lost. + +```bash +cd /app +tar -czf kyberandaigis-enc_amd_results_$(date +%Y%m%d_%H%M%S).tar.gz kyberandaigis-enc +ls -lh kyberandaigis-enc_amd_results_*.tar.gz +``` + +Download the newest `.tar.gz` from JupyterLab. + +## Notes + +- `build_hip.sh` uses `hipcc`, `-DUSE_HIP=1`, and `--offload-arch=gfx1100` by + default to match the current AMD server style. +- Simple HIP migration is only a functional baseline. Use the smoke, sweep, and + profile outputs to decide which ROCm-specific tuning path to take next. 
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/KEM_AMD_OPTIMIZATION_LOG.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/KEM_AMD_OPTIMIZATION_LOG.md new file mode 100644 index 000000000..c87eadfe6 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/KEM_AMD_OPTIMIZATION_LOG.md @@ -0,0 +1,523 @@ +# KEM AMD Optimization Log + +This log records reproducible optimization steps for the Kyber/Aigis-enc KEM +module on AMD ROCm. Use it as the source material for paper/PPT tables. + +## Current Stable Configuration + +```text +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +Kyber batch=32768 +Aigis-enc batch=65536 +n_ops=20 for final throughput tables +``` + +## Verified Optimizations + +1. Operation-specific TPB tuning + - Replaced one shared `KEM_SERIAL_TPB` path with `KEM_KEYGEN_TPB`, + `KEM_ENCAPS_TPB`, and `KEM_DECAPS_TPB`. + - Kyber-768 keygen benefits from larger keygen TPB on gfx1100. + +2. Device buffer reuse + - Added `--reuse-bench <rounds>`. + - Kyber-768 full-KEM continuous batch throughput improved from + 1.93M to 2.05M instances/s in the measured run, a 1.064x speedup. + +3. ROCm trace-based bottleneck localization + - `rocprofv3` kernel trace shows `batch_keygen_warp_sample_kernel` + dominates the pipeline keygen path. + - Monolithic serial KEM kernels show high VGPR and scratch usage. + - NTT/matvec kernels are not the primary bottleneck in the current run. 
+ +## Reproducible Scripts + +Final throughput table: + +```bash +bash run_kem_final_report_amd.sh +``` + +Resource and ROCm trace profile: + +```bash +bash run_kem_resource_profile_amd.sh kyber768 32768 200 +``` + +Buffer reuse benchmark: + +```bash +KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 bash build_hip.sh kyber768 +./kyber768_amd --batch 32768 --n-ops 5 --reuse-bench 20 --no-correctness +``` + +## Paper-Ready Findings + +- Simple HIP migration is functional but not enough for peak performance. +- ROCm/RDNA3-specific TPB tuning improves Kyber keygen throughput. +- Device buffer reuse improves end-to-end full-KEM workflow throughput. +- `rocprofv3` confirms sample/XOF/rejection sampling is the priority kernel + optimization target, while `rocm-smi` shows high GPU utilization and low VRAM + pressure during long Kyber-768 runs. + +## 2026-06-12 Next Optimization Pass + +Added a ROCm tuning matrix for the KEM module. + +Code changes: + +- `build_hip.sh` now accepts `OPT_LEVEL`, `KEM_KEYPAIR_LAUNCH_BOUNDS`, + `KEM_ENCAPS_LAUNCH_BOUNDS`, `KEM_DECAPS_LAUNCH_BOUNDS`, + `WP_KG_WARPS_BLOCK`, `KEM_PACK_TPB`, and `EXTRA_HIPCC_FLAGS`. +- `batch_kem.cuh` now allows `WP_KG_WARPS_BLOCK` and `KEM_PACK_TPB` to be + provided at compile time. +- `run_kem_tune_amd.sh` sweeps the current serial final-report path first, then + checks pipeline sampling/pack candidates. 
+ +Run on AMD: + +```bash +bash run_kem_tune_amd.sh kyber768 +``` + +Useful quicker runs: + +```bash +BATCH=32768 N_OPS=10 DO_CORRECTNESS=0 bash run_kem_tune_amd.sh kyber768 +BATCH=65536 N_OPS=10 DO_CORRECTNESS=0 bash run_kem_tune_amd.sh aigisenc4 +``` + +After the run, rank candidates: + +```bash +latest=$(ls -td amd_results/tune_kyber768_* | head -1) +sort -t, -k13,13nr "$latest/tune_summary.csv" | head +sort -t, -k14,14nr "$latest/tune_summary.csv" | head +sort -t, -k15,15nr "$latest/tune_summary.csv" | head +cat "$latest/pipeline_candidates.log" +``` + +Decision rule: + +- If a candidate improves one operation without hurting the other two, promote + it to `run_kem_final_report_amd.sh`. +- If the best keygen, encaps, and decaps candidates differ, keep + operation-specific TPB values instead of forcing one shared setting. +- If pipeline `sample` improves but total keygen does not, keep it as a + profiling result rather than final default. + +### Kyber-768 first tuning result + +First AMD run: + +```text +amd_results/tune_kyber768_20260612_085500/tune_summary.csv +``` + +Main findings from the pasted result: + +| Candidate | Keygen | Encaps | Decaps | Interpretation | +|---|---:|---:|---:|---| +| O2 kg=256 enc=128 dec=128 bounds=1/0/0 | 6.28M | 6.00M | 5.64M | Current stable neighborhood | +| O2 kg=256 enc=128 dec=128 bounds=0/1/0 | 6.27M | 7.12M | 5.65M | Best balanced candidate | +| O2 kg=256 enc=128 dec=128 bounds=1/1/0 | 6.22M | 7.14M | 5.64M | Highest encaps candidate | +| O3 kg=256 enc=128 dec=128 bounds=1/0/0 | 6.29M | 5.98M | 5.64M | Highest keygen candidate | +| O3 kg=512 enc=128 dec=128 bounds=1/0/0 | 5.71M | 5.99M | 5.66M | Highest decaps candidate, but hurts keygen | + +Conclusion: + +```text +The first clear optimization signal is KEM_ENCAPS_LAUNCH_BOUNDS=1 for +Kyber-768. It improves encaps from about 6.0M ops/s to about 7.1M ops/s while +keygen and decaps remain close to the previous stable region. 
+``` + +Next confirmation run: + +```bash +bash run_kem_confirm_amd.sh kyber768 +``` + +Fast confirmation: + +```bash +N_OPS=30 REPEATS=2 bash run_kem_confirm_amd.sh kyber768 +``` + +If `balanced_encbounds_o2_b010` or `encbest_o2_b110` remains stable, promote +`KEM_ENCAPS_LAUNCH_BOUNDS=1` to the Kyber-768 final configuration. + +### Kyber-768 confirmation result + +Confirmation run: + +```text +amd_results/confirm_kyber768_20260612_090854/confirm_summary.csv +batch=32768 n_ops=50 repeats=3 +``` + +Average of three repeats: + +| Tag | Keygen avg | Encaps avg | Decaps avg | Decision | +|---|---:|---:|---:|---| +| baseline_o2_256_128_128_b100 | 6.485M | 6.004M | 5.490M | Old stable baseline | +| balanced_encbounds_o2_b010 | 6.475M | 7.102M | 5.435M | Promote | +| encbest_o2_b110 | 6.466M | 7.101M | 5.429M | Similar encaps, slightly lower keygen/decaps | +| keygenbest_o3_b100 | 6.502M | 5.993M | 5.471M | No encaps gain | +| decbest_o3_kg512_b100 | 5.747M | 6.000M | 5.482M | Hurts keygen | + +Confirmed improvement: + +```text +Kyber-768 encaps improved from about 6.00M ops/s to about 7.10M ops/s +(~18.3% relative improvement) by enabling KEM_ENCAPS_LAUNCH_BOUNDS=1 while +using keypair_bounds=0 and decaps_bounds=0 for the final Kyber-768 build. +``` + +Trade-off: + +```text +The tuned config reduces decaps by about 1.0% and keygen by about 0.2% in the +repeat average, which is acceptable because encaps gains about 18%. +``` + +Promoted final Kyber-768 config: + +```text +OPT_LEVEL=O2 +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +KEM_KEYPAIR_LAUNCH_BOUNDS=0 +KEM_ENCAPS_LAUNCH_BOUNDS=1 +KEM_DECAPS_LAUNCH_BOUNDS=0 +``` + +`run_kem_final_report_amd.sh` now applies this tuned config to Kyber-768 only. 
+ +### Tuned final report result + +Tuned final report pasted from AMD: + +```text +Kyber-512: Keygen 10.050M, Encaps 11.342M, Decaps 7.448M ops/s +Kyber-768: Keygen 6.229M, Encaps 7.151M, Decaps 5.640M ops/s +Kyber-1024: Keygen 4.470M, Encaps 4.290M, Decaps 3.829M ops/s +Aigis-enc-1: Keygen 10.299M, Encaps 8.204M, Decaps 5.769M ops/s +Aigis-enc-2: Keygen 6.671M, Encaps 5.240M, Decaps 3.841M ops/s +Aigis-enc-3: Keygen 6.284M, Encaps 5.145M, Decaps 3.309M ops/s +Aigis-enc-4: Keygen 4.147M, Encaps 3.450M, Decaps 2.564M ops/s +``` + +Compared with the previous final report, the key paper-worthy changes are: + +| Target | Operation | Previous | Tuned | Change | +|---|---:|---:|---:|---:| +| Kyber-768 | Encaps | 5.999M | 7.151M | +19.2% | +| Aigis-enc-1 | Encaps | 7.205M | 8.204M | +13.9% | +| Aigis-enc-2 | Encaps | 4.704M | 5.240M | +11.4% | +| Aigis-enc-4 | Encaps | 2.951M | 3.450M | +16.9% | + +Notes: + +- Kyber-768 encaps tuning is confirmed by repeat testing and is now final. +- Aigis-enc also benefited in encaps under the new final-report configuration, + but Aigis-enc-3 decaps regressed in this run. Do not generalize the Kyber + decision to all Aigis variants until the Aigis-specific bounds probe is done. + +### Aigis-enc-4 bounds probe result + +Probe run: + +```text +amd_results/bounds_probe_aigisenc4_20260612_093104/bounds_probe_summary.csv +batch=65536 n_ops=30 repeats=2 +``` + +Average of two repeats: + +| Tag | Bounds | Keygen avg | Encaps avg | Decaps avg | Decision | +|---|---|---:|---:|---:|---| +| baseline_b100 | 1/0/0 | 4.175M | 3.451M | 2.570M | Old baseline | +| encbounds_b010 | 0/1/0 | 4.159M | 3.724M | 2.573M | Good | +| encbounds_b110 | 1/1/0 | 4.173M | 3.721M | 2.571M | Promote | +| allbounds_b111 | 1/1/1 | 4.174M | 2.953M | 2.432M | Reject | + +Confirmed improvement: + +```text +Aigis-enc-4 encaps improved from about 3.45M ops/s to about 3.72M ops/s +(~7.9% relative improvement) by enabling KEM_ENCAPS_LAUNCH_BOUNDS=1. 
+``` + +Promoted final Aigis-enc-4 config: + +```text +OPT_LEVEL=O2 +KEM_KEYGEN_TPB=256 +KEM_ENCAPS_TPB=128 +KEM_DECAPS_TPB=128 +KEM_KEYPAIR_LAUNCH_BOUNDS=1 +KEM_ENCAPS_LAUNCH_BOUNDS=1 +KEM_DECAPS_LAUNCH_BOUNDS=0 +``` + +Rejected config: + +```text +KEM_DECAPS_LAUNCH_BOUNDS=1 should not be enabled for Aigis-enc-4 because +`allbounds_b111` lowers encaps and decaps substantially. +``` + +### 2026-06-14 all-parameter bounds probe + +Probe run pasted from AMD: + +```text +amd_results/all_bounds_probe_<timestamp>/ +N_OPS=30 REPEATS=2 DO_CORRECTNESS=1 +Kyber batch=32768 +Aigis-enc batch=65536 +``` + +The script tested all 8 launch-bounds combinations for all 7 KEM targets: + +```text +000 / 001 / 010 / 011 / 100 / 101 / 110 / 111 +``` + +Best balanced configurations from `all_bounds_probe_best.csv`: + +| Target | Bounds | Keygen avg | Encaps avg | Decaps avg | Balanced score | +|---|---|---:|---:|---:|---:| +| Kyber-512 | 001 | 10.432M | 11.398M | 8.508M | 10.241M | +| Kyber-768 | 010 | 6.385M | 7.167M | 5.584M | 6.458M | +| Kyber-1024 | 110 | 4.517M | 4.908M | 3.819M | 4.464M | +| Aigis-enc-1 | 101 | 10.478M | 8.217M | 6.497M | 8.379M | +| Aigis-enc-2 | 110 | 6.709M | 5.627M | 3.830M | 5.413M | +| Aigis-enc-3 | 101 | 6.385M | 5.152M | 4.057M | 5.193M | +| Aigis-enc-4 | 101 | 4.171M | 3.447M | 2.962M | 3.518M | + +Interpretation: + +```text +There is no single best launch-bounds setting for all KEM targets. +ROCm tuning must be parameter-set-aware. +``` + +Important detailed findings: + +- `encaps_bounds=1` is consistently strong for Kyber-768, Kyber-1024, + Aigis-enc-1, Aigis-enc-2, and Aigis-enc-4 encaps. +- `decaps_bounds=1` strongly improves Kyber-512, Aigis-enc-1, + Aigis-enc-3, and Aigis-enc-4 decaps. +- Enabling all bounds `111` is often a negative optimization, especially for + Kyber-768 and Aigis-enc-4. 
+ +The final report script now uses the best balanced per-target configuration: + +```text +Kyber-512 bounds=001 +Kyber-768 bounds=010 +Kyber-1024 bounds=110 +Aigis-enc-1 bounds=101 +Aigis-enc-2 bounds=110 +Aigis-enc-3 bounds=101 +Aigis-enc-4 bounds=101 +``` + +### 2026-06-14 balanced-best final report + +Final report pasted from AMD: + +```text +amd_results/final_report_20260614_012319/kem_final_extract.txt +``` + +| Target | Bounds | Keygen ops/s | Encaps ops/s | Decaps ops/s | +|---|---|---:|---:|---:| +| Kyber-512 | 001 | 10,095,164 | 11,368,410 | 8,451,971 | +| Kyber-768 | 010 | 6,276,945 | 7,142,451 | 5,651,891 | +| Kyber-1024 | 110 | 4,447,101 | 4,916,932 | 3,829,267 | +| Aigis-enc-1 | 101 | 10,240,547 | 8,204,293 | 6,497,139 | +| Aigis-enc-2 | 110 | 6,605,967 | 5,630,086 | 3,827,435 | +| Aigis-enc-3 | 101 | 6,305,602 | 5,144,497 | 4,060,120 | +| Aigis-enc-4 | 101 | 4,156,445 | 3,444,781 | 2,961,419 | + +Paper-ready statement: + +```text +After parameter-set-aware ROCm launch-bounds tuning, all seven KEM targets +retain PASS correctness evidence and reach million-to-ten-million-level +throughput. The tuning improves different operations for different parameter +sets, proving that ROCm PQC kernels need per-parameter resource policies rather +than a single global launch configuration. +``` + +### Next: all-parameter profile comparison + +Added scripts: + +```text +run_kem_all_profile_compare_amd.sh +summarize_profile_compare.py +``` + +Purpose: + +```text +Compare baseline bounds=100 against tuned per-target bounds for all seven KEM +targets using rocprofv3 kernel trace and HIP API trace. 
+``` + +Run on AMD: + +```bash +N_OPS=30 PROFILE_N_OPS=1 DO_CORRECTNESS=1 bash run_kem_all_profile_compare_amd.sh +``` + +Outputs: + +```text +amd_results/profile_compare_<timestamp>/profile_compare_runs.csv +amd_results/profile_compare_<timestamp>/kernel_summary.csv +amd_results/profile_compare_<timestamp>/hip_api_summary.csv +amd_results/profile_compare_<timestamp>/key_kernel_summary.csv +amd_results/profile_compare_<timestamp>/key_kernel_compare.csv +``` + +The key file for paper analysis is: + +```text +key_kernel_compare.csv +``` + +It reports keypair/encaps/decaps serial kernel time, VGPR, SGPR, scratch, +workgroup, and tuned-vs-baseline percentage change. + +### 2026-06-14 all-parameter profile comparison result + +Profile comparison pasted from AMD: + +```text +amd_results/profile_compare_<timestamp>/ +profile_compare_runs.csv +key_kernel_compare.csv +key_kernel_summary.csv +``` + +Throughput change from baseline bounds `100` to tuned per-target bounds: + +| Target | Tuned bounds | Keygen change | Encaps change | Decaps change | Main gain | +|---|---|---:|---:|---:|---| +| Kyber-512 | 001 | -0.3% | -0.2% | +12.9% | Decaps | +| Kyber-768 | 010 | +0.4% | +19.2% | -0.2% | Encaps | +| Kyber-1024 | 110 | +1.2% | +14.2% | -0.9% | Encaps | +| Aigis-enc-1 | 101 | -0.9% | ~0.0% | +12.6% | Decaps | +| Aigis-enc-2 | 110 | -0.1% | +7.4% | -0.1% | Encaps | +| Aigis-enc-3 | 101 | -0.1% | -0.1% | +22.7% | Decaps | +| Aigis-enc-4 | 101 | ~0.0% | +0.1% | +15.2% | Decaps | + +Key kernel time changes from `key_kernel_compare.csv`: + +| Target | Operation improved | Kernel time change | Resource change | +|---|---|---:|---| +| Kyber-512 | Decaps | -11.84% | VGPR 184 -> 200, scratch 14784 -> 14752 | +| Kyber-768 | Encaps | -16.51% | VGPR 184 -> 200, scratch 16064 -> 16048 | +| Kyber-1024 | Encaps | -12.99% | VGPR 184 -> 200, scratch 18144 -> 18128 | +| Aigis-enc-1 | Decaps | -10.50% | VGPR 184 -> 200, scratch 14720 -> 14704 | +| Aigis-enc-2 | Encaps | -6.22% | VGPR 184 -> 200, scratch 16032 -> 16016 | +| Aigis-enc-3 | Decaps | -19.34% | VGPR 
184 -> 200, scratch 17088 -> 17072 | +| Aigis-enc-4 | Decaps | -13.53% | VGPR 184 -> 200, scratch 19648 -> 19632 | + +Important interpretation: + +```text +The tuned launch-bounds setting does not simply reduce register count. +For the operations that improve most, VGPR often increases from 184 to 200 +while scratch decreases slightly by 16-32 bytes and kernel runtime drops +significantly. This suggests the improvement comes from ROCm compiler scheduling +and occupancy/resource trade-offs, not from a naive "fewer registers is always +better" rule. +``` + +Per-target analysis: + +- **Kyber-512**: tuned `001` is a decaps-focused configuration. Decaps kernel + time drops 11.84%, matching the 12.9% throughput gain. Keygen/encaps are + almost unchanged in throughput, although profile keypair kernel time is + noisier and increases in the one-iteration trace. +- **Kyber-768**: tuned `010` is a clean encaps optimization. Encaps kernel time + drops 16.51% and throughput rises 19.2%, while keygen/decaps remain stable. + This is the strongest Kyber example for the paper. +- **Kyber-1024**: tuned `110` improves encaps kernel time by 12.99% and + throughput by 14.2%. It also slightly improves keypair kernel time, while + decaps is roughly unchanged. +- **Aigis-enc-1**: tuned `101` mainly improves decaps. Decaps kernel time drops + 10.50% and decaps throughput rises 12.6%. Encaps remains essentially equal. +- **Aigis-enc-2**: tuned `110` mainly improves encaps. Encaps kernel time drops + 6.22% and throughput rises 7.4%. Keypair profile time worsens, so this config + should be described as operation-selective rather than globally faster. +- **Aigis-enc-3**: tuned `101` is a strong decaps optimization. Decaps kernel + time drops 19.34% and throughput rises 22.7%, with keygen/encaps nearly + unchanged in throughput. +- **Aigis-enc-4**: tuned `101` improves decaps by 15.2% in throughput and + lowers decaps kernel time by 13.53%; keygen and encaps are effectively + unchanged. 
+ +Paper-ready conclusion: + +```text +ROCm launch-bounds tuning changes the compiler's resource scheduling decisions. +For PQC KEM kernels with large private state and scratch pressure, the best +configuration is operation- and parameter-set-specific. The measured wins are +not explained by global GPU utilization or VRAM capacity, but by per-kernel +resource trade-offs visible in rocprofv3: key operation runtime drops while VGPR +and scratch shift slightly. +``` + +### Next: ROCm toolbox pass + +Added: + +```text +run_rocm_toolbox_kem_amd.sh +summarize_rocm_pmc.py +``` + +Purpose: + +```text +Probe additional ROCm tools beyond kernel/HIP trace: +tool discovery, rocprofv3 --list-avail, rocprofv3 --sys-trace, +rocprofv3 --pmc hardware counters, rocm-smi sampling, rocminfo, hipconfig, +and rocprof-compute availability. +``` + +Representative run: + +```bash +bash run_rocm_toolbox_kem_amd.sh +``` + +All-target run: + +```bash +TARGETS="kyber512 kyber768 kyber1024 aigisenc1 aigisenc2 aigisenc3 aigisenc4" \ +N_OPS=20 PROFILE_N_OPS=1 bash run_rocm_toolbox_kem_amd.sh +``` + +Expected outputs: + +```text +amd_results/rocm_toolbox_<timestamp>/tool_discovery.txt +amd_results/rocm_toolbox_<timestamp>/rocprofv3_list_avail.txt +amd_results/rocm_toolbox_<timestamp>/toolbox_runs.csv +amd_results/rocm_toolbox_<timestamp>/pmc_summary.csv +``` + +This is an exploratory pass. If the AMD JupyterLab image exposes only a subset +of ROCm counters or lacks `rocprof-compute`, the script records that instead of +failing the workflow. 
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/batch_kem.cuh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/batch_kem.cuh new file mode 100644 index 000000000..6d72d73f9 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/batch_kem.cuh @@ -0,0 +1,797 @@ +/* + * batch_kem.cuh — GPU 批量 KEM 流水线 + * + * 参考 mldsa和aigis-sig/batch_keygen.cuh 的优化架构: + * - Warp 协同采样: 1 warp = 1 实例 (并行矩阵展开 + 噪声采样) + * - 共享内存批量 NTT (batch_ntt_kernel, 1 block/poly) + * - 2D grid 矩阵向量乘 (batch_polyvec_matvec_kernel) + * - SoA 内存布局: data[poly_idx * batch_count * N + inst * N + coeff] + * + * 性能要点 (RTX 3050 Ti): + * - 最优 batch size: Keygen/Encaps=16K, Decaps=8K-16K + * - VRAM 限制: K^2 * B * N * sizeof(int16_t) ≤ 可用显存 + */ + +#ifndef BATCH_KEM_CUH +#define BATCH_KEM_CUH + +#include "rocm_compat.h" +#include +#include +#include "params.h" +#include "reduce.cuh" +#include "fips202.cuh" +#include "ntt.cuh" +#include "poly.cuh" +#include "polyvec.cuh" +#include "cbd.cuh" +#include "kem.cuh" +#include "batch_ntt.cuh" +#include "batch_ops.cuh" + +/* ================================================================ + * 缓冲区结构体 + * ================================================================ */ +struct BatchKemBuffers { + /* 批量 keygen/encaps 工作缓冲区 — SoA 布局 [poly_idx][inst][coeff] */ + int16_t *d_mat; /* K*K * B * N — 矩阵 A (NTT 域) */ + int16_t *d_skpv; /* K * B * N — 私钥 s (NTT 域) */ + int16_t *d_pkpv; /* K * B * N — 公钥多项式 b */ + int16_t *d_e; /* K * B * N — keygen 误差 e */ + + /* KEM 字节缓冲区 */ + uint8_t *d_pk_bytes; /* B * PARAM_PUBLICKEYBYTES */ + uint8_t *d_sk_bytes; /* B * PARAM_SECRETKEYBYTES */ + uint8_t *d_ct_bytes; /* B * PARAM_CIPHERTEXTBYTES */ + uint8_t *d_ss_bytes; /* B * PARAM_SSBYTES */ + + /* 随机种子 */ + uint8_t *d_coins_kg; /* B * 2*SYMBYTES — keygen 种子 */ + uint8_t *d_coins_enc;/* B * SYMBYTES — encaps 种子 */ + + uint8_t 
*d_publicseed_kg; + uint8_t *d_noiseseed_kg; + + int max_batch; +}; + +static inline void batch_kem_alloc(BatchKemBuffers *buf, int max_batch) +{ + buf->max_batch = max_batch; + cudaMalloc(&buf->d_mat, (size_t)PARAM_K * PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_skpv, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_pkpv, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_e, (size_t)PARAM_K * max_batch * PARAM_N * sizeof(int16_t)); + cudaMalloc(&buf->d_pk_bytes, (size_t)max_batch * PARAM_PUBLICKEYBYTES); + cudaMalloc(&buf->d_sk_bytes, (size_t)max_batch * PARAM_SECRETKEYBYTES); + cudaMalloc(&buf->d_ct_bytes, (size_t)max_batch * PARAM_CIPHERTEXTBYTES); + cudaMalloc(&buf->d_ss_bytes, (size_t)max_batch * PARAM_SSBYTES); + cudaMalloc(&buf->d_coins_kg, (size_t)max_batch * 2 * PARAM_SYMBYTES); + cudaMalloc(&buf->d_coins_enc, (size_t)max_batch * PARAM_SYMBYTES); + cudaMalloc(&buf->d_publicseed_kg, (size_t)max_batch * PARAM_SYMBYTES); + cudaMalloc(&buf->d_noiseseed_kg, (size_t)max_batch * PARAM_SYMBYTES); +} + +static inline void batch_kem_free(BatchKemBuffers *buf) +{ + cudaFree(buf->d_mat); + cudaFree(buf->d_skpv); + cudaFree(buf->d_pkpv); + cudaFree(buf->d_e); + cudaFree(buf->d_pk_bytes); + cudaFree(buf->d_sk_bytes); + cudaFree(buf->d_ct_bytes); + cudaFree(buf->d_ss_bytes); + cudaFree(buf->d_coins_kg); + cudaFree(buf->d_coins_enc); + cudaFree(buf->d_publicseed_kg); + cudaFree(buf->d_noiseseed_kg); +} + +/* ================================================================ + * Warp 协同采样 kernel (KEM 密钥生成) + * 1 warp (32 threads) = 1 实例 + * Lane 0: SHA3-512 展开种子 → (publicseed, noiseseed) + * 全部 lanes: 并行展开矩阵 A 和噪声多项式 s, e + * + * 输出 SoA: + * d_mat[row*K*B*N + col*B*N + inst*N + c] = A[inst][row][col][c] + * d_skpv[i*B*N + inst*N + c] = s[inst][i][c] (未 NTT) + * d_e[i*B*N + inst*N + c] = e[inst][i][c] (未 NTT) + * ================================================================ */ + +#ifndef 
WP_KG_WARP_SIZE +#define WP_KG_WARP_SIZE 32 +#endif + +#ifndef WP_KG_WARPS_BLOCK +#define WP_KG_WARPS_BLOCK 4 +#endif + +#define WP_KG_TPB (WP_KG_WARP_SIZE * WP_KG_WARPS_BLOCK) + +#ifndef KEM_SPLIT_KEYGEN_SAMPLE +#define KEM_SPLIT_KEYGEN_SAMPLE 0 +#endif + +#ifndef KEM_SERIAL_TPB +#ifdef USE_HIP +#define KEM_SERIAL_TPB 64 +#else +#define KEM_SERIAL_TPB 64 +#endif +#endif + +#ifndef KEM_KEYGEN_TPB +#define KEM_KEYGEN_TPB KEM_SERIAL_TPB +#endif + +#ifndef KEM_ENCAPS_TPB +#define KEM_ENCAPS_TPB KEM_SERIAL_TPB +#endif + +#ifndef KEM_DECAPS_TPB +#define KEM_DECAPS_TPB KEM_SERIAL_TPB +#endif + +__global__ void batch_keygen_warp_sample_kernel( + int16_t * __restrict__ d_mat, /* K*K * B * N */ + int16_t * __restrict__ d_skpv, /* K * B * N */ + int16_t * __restrict__ d_e, /* K * B * N */ + uint8_t * __restrict__ d_publicseed, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * WP_KG_WARPS_BLOCK + (threadIdx.x / WP_KG_WARP_SIZE); + int lane = threadIdx.x & (WP_KG_WARP_SIZE - 1); + + if (inst >= batch_count) return; + + /* Warp-level shared: publicseed 和 noiseseed */ + __shared__ uint8_t ws_pub[WP_KG_WARPS_BLOCK][PARAM_SYMBYTES]; + __shared__ uint8_t ws_noise[WP_KG_WARPS_BLOCK][PARAM_SYMBYTES]; + + int warp_id = threadIdx.x / WP_KG_WARP_SIZE; + uint8_t *publicseed = ws_pub[warp_id]; + uint8_t *noiseseed = ws_noise[warp_id]; + + if (lane == 0) { + /* 展开种子: SHA3-512(coins[0:32]) → (publicseed[32], noiseseed[32]) */ + uint8_t buf[2 * PARAM_SYMBYTES]; + sha3_512(buf, d_coins + inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); + for (int i = 0; i < PARAM_SYMBYTES; i++) { + publicseed[i] = buf[i]; + d_publicseed[(size_t)inst * PARAM_SYMBYTES + i] = buf[i]; + } + for (int i = 0; i < PARAM_SYMBYTES; i++) noiseseed[i] = buf[PARAM_SYMBYTES + i]; + } + __syncwarp(); + + /* 矩阵展开: 每个 lane 负责若干多项式 (A[row][col]) */ + int total_mat_polys = PARAM_K * PARAM_K; + for (int p = lane; p < total_mat_polys; p += WP_KG_WARP_SIZE) { + int row = p / 
PARAM_K; + int col = p % PARAM_K; + + /* 目标地址: SoA 格式 */ + int16_t *dst = d_mat + ((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N; + + uint8_t extseed[PARAM_SYMBYTES + 2]; + for (int i = 0; i < PARAM_SYMBYTES; i++) extseed[i] = publicseed[i]; + +#if ALGORITHM == ALGO_KYBER + extseed[PARAM_SYMBYTES] = (uint8_t)col; /* j */ + extseed[PARAM_SYMBYTES+1] = (uint8_t)row; /* i */ +#elif ALGORITHM == ALGO_AIGIS_ENC + extseed[PARAM_SYMBYTES] = (uint8_t)row; /* i */ + extseed[PARAM_SYMBYTES+1] = (uint8_t)col; /* j */ +#endif + +#if KEM_DIRECT_REJ_UNIFORM + rej_uniform_xof(dst, publicseed, extseed[PARAM_SYMBYTES], extseed[PARAM_SYMBYTES + 1]); +#else + keccak_state state; + shake128_absorb_once(&state, extseed, PARAM_SYMBYTES + 2); + + unsigned int ctr = 0; + uint8_t buf[PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES]; + while (ctr < PARAM_N) { + shake128_squeezeblocks(buf, PARAM_GEN_MATRIX_NBLOCKS, &state); + ctr += rej_uniform(dst + ctr, PARAM_N - ctr, + buf, PARAM_GEN_MATRIX_NBLOCKS * PARAM_XOF_BLOCKBYTES); + } +#endif + } + + /* 噪声采样: s[0..K-1], e[0..K-1] */ + for (int i = lane; i < PARAM_K; i += WP_KG_WARP_SIZE) { + int16_t *dst_s = d_skpv + ((size_t)i * batch_count + inst) * PARAM_N; + poly_getnoise_s(dst_s, noiseseed, (uint8_t)i); + } + for (int i = lane; i < PARAM_K; i += WP_KG_WARP_SIZE) { + int16_t *dst_e = d_e + ((size_t)i * batch_count + inst) * PARAM_N; + poly_getnoise_e_kg(dst_e, noiseseed, (uint8_t)(PARAM_K + i)); + } +} + +__global__ void batch_keygen_seed_expand_kernel( + uint8_t * __restrict__ d_publicseed, + uint8_t * __restrict__ d_noiseseed, + const uint8_t * __restrict__ d_coins, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + uint8_t buf[2 * PARAM_SYMBYTES]; + sha3_512(buf, d_coins + (size_t)inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); + for (int i = 0; i < PARAM_SYMBYTES; i++) { + d_publicseed[(size_t)inst * PARAM_SYMBYTES + i] = buf[i]; + d_noiseseed[(size_t)inst * 
PARAM_SYMBYTES + i] = buf[PARAM_SYMBYTES + i]; + } +} + +__global__ void batch_keygen_mat_sample_kernel( + int16_t * __restrict__ d_mat, + const uint8_t * __restrict__ d_publicseed, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * PARAM_K * PARAM_K; + if (idx >= total) return; + + int inst = idx % batch_count; + int p = idx / batch_count; + int row = p / PARAM_K; + int col = p % PARAM_K; + +#if ALGORITHM == ALGO_KYBER + uint8_t x = (uint8_t)col; + uint8_t y = (uint8_t)row; +#elif ALGORITHM == ALGO_AIGIS_ENC + uint8_t x = (uint8_t)row; + uint8_t y = (uint8_t)col; +#endif + + int16_t *dst = d_mat + ((size_t)(row * PARAM_K + col) * batch_count + inst) * PARAM_N; + const uint8_t *seed = d_publicseed + (size_t)inst * PARAM_SYMBYTES; + rej_uniform_xof(dst, seed, x, y); +} + +__global__ void batch_keygen_noise_sample_kernel( + int16_t * __restrict__ d_skpv, + int16_t * __restrict__ d_e, + const uint8_t * __restrict__ d_noiseseed, + int batch_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_count * PARAM_K * 2; + if (idx >= total) return; + + int inst = idx % batch_count; + int q = idx / batch_count; + int poly = q % PARAM_K; + const uint8_t *seed = d_noiseseed + (size_t)inst * PARAM_SYMBYTES; + + if (q < PARAM_K) { + int16_t *dst = d_skpv + ((size_t)poly * batch_count + inst) * PARAM_N; + poly_getnoise_s(dst, seed, (uint8_t)poly); + } else { + int16_t *dst = d_e + ((size_t)poly * batch_count + inst) * PARAM_N; + poly_getnoise_e_kg(dst, seed, (uint8_t)(PARAM_K + poly)); + } +} + +/* ================================================================ + * 批量打包 PK/SK kernel (每 block 处理一个实例) + * 在所有 NTT 和 matvec 计算完成后调用 + * + * 输入: + * d_mat — 矩阵 A (unused for packing, publicseed stored in d_coins) + * d_skpv — NTT 域 s (已 caddq) + * d_pkpv — b = A*s + e (已 caddq), 以 SoA 格式 + * 输出: + * d_pk_bytes — PK 字节流 + * d_sk_bytes — SK 字节流 (indcpa_sk || pk || H(pk) || z) + * 
================================================================ */ + +__global__ void batch_pack_keypair_kernel( + uint8_t * __restrict__ d_pk_bytes, + uint8_t * __restrict__ d_sk_bytes, + const int16_t * __restrict__ d_skpv, + const int16_t * __restrict__ d_pkpv, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES: publicseed 在位置[inst*2*32] */ + int batch_count) +{ + int inst = blockIdx.x; + if (inst >= batch_count) return; + + /* 构建 kem_polyvec 结构 (从 SoA 还原为 AoS) */ + kem_polyvec skpv_local, pkpv_local; + for (int i = 0; i < PARAM_K; i++) + for (int c = 0; c < PARAM_N; c++) { + skpv_local.vec[i].coeffs[c] = d_skpv[((size_t)i * batch_count + inst) * PARAM_N + c]; + pkpv_local.vec[i].coeffs[c] = d_pkpv[((size_t)i * batch_count + inst) * PARAM_N + c]; + } + + /* 从 d_coins 取出 publicseed (keygen 时, sha3_512 已展开, publicseed = 前 32 字节) */ + /* 实际上我们在 warp 采样时已用 sha3_512 展开, 这里需要重新计算 publicseed */ + uint8_t seeds[2 * PARAM_SYMBYTES]; + sha3_512(seeds, d_coins + (size_t)inst * 2 * PARAM_SYMBYTES, PARAM_SYMBYTES); + const uint8_t *publicseed = seeds; + + /* PK = pk_poly_compress(pkpv) || publicseed */ + uint8_t *pk = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES; + pack_pk(pk, &pkpv_local, publicseed); + + /* SK = polyvec_tobytes(skpv) || pk || H(pk) || z */ + uint8_t *sk = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES; + pack_sk(sk, &skpv_local); + + /* sk[indcpa_sk_bytes:] = pk */ + for (int i = 0; i < (int)PARAM_PUBLICKEYBYTES; i++) + sk[PARAM_INDCPA_SECRETKEYBYTES + i] = pk[i]; + + /* H(pk) */ + sha3_256(sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES, pk, PARAM_PUBLICKEYBYTES); + + /* z = coins[32:64] (第二个 32 字节作为随机 z) */ + const uint8_t *z_src = d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + PARAM_SYMBYTES; + uint8_t *z_dst = sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES + PARAM_SYMBYTES; + for (int i = 0; i < PARAM_SYMBYTES; i++) z_dst[i] = z_src[i]; +} + +#ifndef KEM_PACK_TPB +#define KEM_PACK_TPB 128 +#endif + +__global__ void 
batch_pack_sk_polyvec_kernel( + uint8_t * __restrict__ d_sk_bytes, + const int16_t * __restrict__ d_skpv, + int batch_count) +{ + int inst = blockIdx.x; + int poly = blockIdx.y; + int tid = threadIdx.x; + if (inst >= batch_count || poly >= PARAM_K) return; + + const int16_t *src = d_skpv + ((size_t)poly * batch_count + inst) * PARAM_N; + uint8_t *out = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES + + (size_t)poly * PARAM_POLYBYTES; + +#if ALGORITHM == ALGO_KYBER + for (int i = tid; i < PARAM_N / 2; i += blockDim.x) { + int16_t t0 = caddq(src[2 * i]); + int16_t t1 = caddq(src[2 * i + 1]); + out[3 * i + 0] = (uint8_t)t0; + out[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + out[3 * i + 2] = (uint8_t)(t1 >> 4); + } +#elif ALGORITHM == ALGO_AIGIS_ENC + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + int16_t t0 = caddq(src[8 * i + 0]); + int16_t t1 = caddq(src[8 * i + 1]); + int16_t t2 = caddq(src[8 * i + 2]); + int16_t t3 = caddq(src[8 * i + 3]); + int16_t t4 = caddq(src[8 * i + 4]); + int16_t t5 = caddq(src[8 * i + 5]); + int16_t t6 = caddq(src[8 * i + 6]); + int16_t t7 = caddq(src[8 * i + 7]); + out[13 * i + 0] = (uint8_t)t0; + out[13 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 5)); + out[13 * i + 2] = (uint8_t)(t1 >> 3); + out[13 * i + 3] = (uint8_t)((t1 >> 11) | (t2 << 2)); + out[13 * i + 4] = (uint8_t)((t2 >> 6) | (t3 << 7)); + out[13 * i + 5] = (uint8_t)(t3 >> 1); + out[13 * i + 6] = (uint8_t)((t3 >> 9) | (t4 << 4)); + out[13 * i + 7] = (uint8_t)(t4 >> 4); + out[13 * i + 8] = (uint8_t)((t4 >> 12) | (t5 << 1)); + out[13 * i + 9] = (uint8_t)((t5 >> 7) | (t6 << 6)); + out[13 * i + 10] = (uint8_t)(t6 >> 2); + out[13 * i + 11] = (uint8_t)((t6 >> 10) | (t7 << 3)); + out[13 * i + 12] = (uint8_t)(t7 >> 5); + } +#endif +} + +__global__ void batch_pack_pk_polyvec_kernel( + uint8_t * __restrict__ d_pk_bytes, + const int16_t * __restrict__ d_pkpv, + int batch_count) +{ + int inst = blockIdx.x; + int poly = blockIdx.y; + int tid = threadIdx.x; + if (inst >= 
batch_count || poly >= PARAM_K) return; + + const int16_t *src = d_pkpv + ((size_t)poly * batch_count + inst) * PARAM_N; + uint8_t *out = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES + + (size_t)poly * (PARAM_BITS_PK * PARAM_N / 8); + +#if ALGORITHM == ALGO_KYBER + for (int i = tid; i < PARAM_N / 2; i += blockDim.x) { + int16_t t0 = caddq(src[2 * i]); + int16_t t1 = caddq(src[2 * i + 1]); + out[3 * i + 0] = (uint8_t)t0; + out[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + out[3 * i + 2] = (uint8_t)(t1 >> 4); + } +#elif PARAM_BITS_PK == 9 + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[8*i+0]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c1 = (uint16_t)((((int32_t)caddq(src[8*i+1]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[8*i+2]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[8*i+3]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c4 = (uint16_t)((((int32_t)caddq(src[8*i+4]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c5 = (uint16_t)((((int32_t)caddq(src[8*i+5]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c6 = (uint16_t)((((int32_t)caddq(src[8*i+6]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + uint16_t c7 = (uint16_t)((((int32_t)caddq(src[8*i+7]) << 9) + PARAM_Q/2) / PARAM_Q) & 0x1FF; + out[9*i+0] = (uint8_t)c0; + out[9*i+1] = (uint8_t)((c0 >> 8) | (c1 << 1)); + out[9*i+2] = (uint8_t)((c1 >> 7) | (c2 << 2)); + out[9*i+3] = (uint8_t)((c2 >> 6) | (c3 << 3)); + out[9*i+4] = (uint8_t)((c3 >> 5) | (c4 << 4)); + out[9*i+5] = (uint8_t)((c4 >> 4) | (c5 << 5)); + out[9*i+6] = (uint8_t)((c5 >> 3) | (c6 << 6)); + out[9*i+7] = (uint8_t)((c6 >> 2) | (c7 << 7)); + out[9*i+8] = (uint8_t)(c7 >> 1); + } +#elif PARAM_BITS_PK == 10 + for (int i = tid; i < PARAM_N / 4; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[4*i+0]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c1 = 
(uint16_t)((((int32_t)caddq(src[4*i+1]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[4*i+2]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[4*i+3]) << 10) + PARAM_Q/2) / PARAM_Q) & 0x3FF; + out[5*i+0] = (uint8_t)c0; + out[5*i+1] = (uint8_t)((c0 >> 8) | (c1 << 2)); + out[5*i+2] = (uint8_t)((c1 >> 6) | (c2 << 4)); + out[5*i+3] = (uint8_t)((c2 >> 4) | (c3 << 6)); + out[5*i+4] = (uint8_t)(c3 >> 2); + } +#elif PARAM_BITS_PK == 11 + for (int i = tid; i < PARAM_N / 8; i += blockDim.x) { + uint16_t c0 = (uint16_t)((((int32_t)caddq(src[8*i+0]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c1 = (uint16_t)((((int32_t)caddq(src[8*i+1]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c2 = (uint16_t)((((int32_t)caddq(src[8*i+2]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c3 = (uint16_t)((((int32_t)caddq(src[8*i+3]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c4 = (uint16_t)((((int32_t)caddq(src[8*i+4]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c5 = (uint16_t)((((int32_t)caddq(src[8*i+5]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c6 = (uint16_t)((((int32_t)caddq(src[8*i+6]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + uint16_t c7 = (uint16_t)((((int32_t)caddq(src[8*i+7]) << 11) + PARAM_Q/2) / PARAM_Q) & 0x7FF; + out[11*i+ 0] = (uint8_t)c0; + out[11*i+ 1] = (uint8_t)((c0 >> 8) | (c1 << 3)); + out[11*i+ 2] = (uint8_t)((c1 >> 5) | (c2 << 6)); + out[11*i+ 3] = (uint8_t)(c2 >> 2); + out[11*i+ 4] = (uint8_t)((c2 >> 10) | (c3 << 1)); + out[11*i+ 5] = (uint8_t)((c3 >> 7) | (c4 << 4)); + out[11*i+ 6] = (uint8_t)((c4 >> 4) | (c5 << 7)); + out[11*i+ 7] = (uint8_t)(c5 >> 1); + out[11*i+ 8] = (uint8_t)((c5 >> 9) | (c6 << 2)); + out[11*i+ 9] = (uint8_t)((c6 >> 6) | (c7 << 5)); + out[11*i+10] = (uint8_t)(c7 >> 3); + } +#endif +} + +__global__ void batch_pack_keypair_finalize_kernel( + uint8_t * __restrict__ d_pk_bytes, + uint8_t * __restrict__ d_sk_bytes, + const uint8_t * 
__restrict__ d_publicseed, + const uint8_t * __restrict__ d_coins, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + uint8_t *pk = d_pk_bytes + (size_t)inst * PARAM_PUBLICKEYBYTES; + uint8_t *sk = d_sk_bytes + (size_t)inst * PARAM_SECRETKEYBYTES; + const uint8_t *rho = d_publicseed + (size_t)inst * PARAM_SYMBYTES; + + for (int i = 0; i < PARAM_SYMBYTES; i++) + pk[PARAM_PK_POLYVEC_BYTES + i] = rho[i]; + + for (int i = 0; i < (int)PARAM_PUBLICKEYBYTES; i++) + sk[PARAM_INDCPA_SECRETKEYBYTES + i] = pk[i]; + + sha3_256(sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES, + pk, PARAM_PUBLICKEYBYTES); + + const uint8_t *z_src = d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + PARAM_SYMBYTES; + uint8_t *z_dst = sk + PARAM_INDCPA_SECRETKEYBYTES + PARAM_PUBLICKEYBYTES + PARAM_SYMBYTES; + for (int i = 0; i < PARAM_SYMBYTES; i++) z_dst[i] = z_src[i]; +} + +/* ================================================================ + * 批量单实例 keygen kernel (完整流水线, 单线程 fallback) + * 用于 batch 较小时, 直接调用 kem_keypair 设备函数 + * ================================================================ */ +#ifndef KEM_KEYPAIR_LAUNCH_BOUNDS +#define KEM_KEYPAIR_LAUNCH_BOUNDS 1 +#endif + +#ifndef KEM_ENCAPS_LAUNCH_BOUNDS +#if ALGORITHM == ALGO_AIGIS_ENC +#define KEM_ENCAPS_LAUNCH_BOUNDS 1 +#else +#define KEM_ENCAPS_LAUNCH_BOUNDS 0 +#endif +#endif + +#ifndef KEM_DECAPS_LAUNCH_BOUNDS +#if ALGORITHM == ALGO_AIGIS_ENC +#define KEM_DECAPS_LAUNCH_BOUNDS 1 +#else +#define KEM_DECAPS_LAUNCH_BOUNDS 0 +#endif +#endif + +#if KEM_KEYPAIR_LAUNCH_BOUNDS +#define KEM_KEYPAIR_KERNEL_BOUNDS __launch_bounds__(KEM_KEYGEN_TPB, 1) +#else +#define KEM_KEYPAIR_KERNEL_BOUNDS +#endif + +#if KEM_ENCAPS_LAUNCH_BOUNDS +#define KEM_ENCAPS_KERNEL_BOUNDS __launch_bounds__(KEM_ENCAPS_TPB, 1) +#else +#define KEM_ENCAPS_KERNEL_BOUNDS +#endif + +#if KEM_DECAPS_LAUNCH_BOUNDS +#define KEM_DECAPS_KERNEL_BOUNDS __launch_bounds__(KEM_DECAPS_TPB, 1) +#else +#define 
KEM_DECAPS_KERNEL_BOUNDS +#endif + +__global__ KEM_KEYPAIR_KERNEL_BOUNDS void batch_kem_keypair_serial_kernel( + uint8_t * __restrict__ d_pk, + uint8_t * __restrict__ d_sk, + const uint8_t * __restrict__ d_coins, /* B * 2*SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_keypair( + d_pk + (size_t)inst * PARAM_PUBLICKEYBYTES, + d_sk + (size_t)inst * PARAM_SECRETKEYBYTES, + d_coins + (size_t)inst * 2 * PARAM_SYMBYTES + ); +} + +/* ================================================================ + * 批量单实例 encaps kernel + * ================================================================ */ +__global__ KEM_ENCAPS_KERNEL_BOUNDS void batch_kem_encaps_serial_kernel( + uint8_t * __restrict__ d_ct, + uint8_t * __restrict__ d_ss, + const uint8_t * __restrict__ d_pk, + const uint8_t * __restrict__ d_coins, /* B * SYMBYTES */ + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_encaps( + d_ct + (size_t)inst * PARAM_CIPHERTEXTBYTES, + d_ss + (size_t)inst * PARAM_SSBYTES, + d_pk + (size_t)inst * PARAM_PUBLICKEYBYTES, + d_coins + (size_t)inst * PARAM_SYMBYTES + ); +} + +/* ================================================================ + * 批量单实例 decaps kernel + * ================================================================ */ +__global__ KEM_DECAPS_KERNEL_BOUNDS void batch_kem_decaps_serial_kernel( + uint8_t * __restrict__ d_ss, + const uint8_t * __restrict__ d_ct, + const uint8_t * __restrict__ d_sk, + int batch_count) +{ + int inst = blockIdx.x * blockDim.x + threadIdx.x; + if (inst >= batch_count) return; + + kem_decaps( + d_ss + (size_t)inst * PARAM_SSBYTES, + d_ct + (size_t)inst * PARAM_CIPHERTEXTBYTES, + d_sk + (size_t)inst * PARAM_SECRETKEYBYTES + ); +} + +/* ================================================================ + * 批量 KEM 高性能流水线 + * + * batch_keygen_pipelined: + * 1. Warp 采样 (矩阵 A + s + e) + * 2. 批量 NTT(s) + * 3. 
2D grid 矩阵向量乘 (A*s → pkpv) + * 4. 批量 INVNTT(pkpv) + 加 e, caddq + * 5. 打包 pk/sk + * ================================================================ */ +static inline cudaError_t batch_keygen_pipelined( + uint8_t *d_pk_out, uint8_t *d_sk_out, + BatchKemBuffers *buf, + int batch_count, + cudaStream_t stream = 0) +{ + cudaError_t err; + + /* 生成随机种子 (device side — 在 host 侧用 cudaMemcpy 传入 d_coins_kg) */ + + /* Step 1: Warp 采样 */ + int blocks = (batch_count + WP_KG_WARPS_BLOCK - 1) / WP_KG_WARPS_BLOCK; +#if KEM_SPLIT_KEYGEN_SAMPLE + batch_keygen_seed_expand_kernel<<>>( + buf->d_publicseed_kg, buf->d_noiseseed_kg, buf->d_coins_kg, batch_count); + batch_keygen_mat_sample_kernel<<>>( + buf->d_mat, buf->d_publicseed_kg, batch_count); + batch_keygen_noise_sample_kernel<<>>( + buf->d_skpv, buf->d_e, buf->d_noiseseed_kg, batch_count); +#else + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_skpv, buf->d_e, + buf->d_publicseed_kg, buf->d_coins_kg, batch_count); +#endif + + /* Step 2: 批量 NTT(s) — d_skpv 中 K 个 poly 组 */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + batch_ntt_kernel<<>>(ptr, batch_count); + } + + /* Step 2b: caddq(s) */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + launch_batch_caddq(ptr, batch_count, stream); + } + + /* Step 3: 矩阵向量乘 A * s_hat → pkpv */ + launch_batch_matvec(buf->d_pkpv, buf->d_mat, buf->d_skpv, batch_count, stream); + + /* Step 4: INVNTT(pkpv) */ + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_pkpv + (size_t)i * batch_count * PARAM_N; + batch_invntt_kernel<<>>(ptr, batch_count); + } + + /* pkpv += e */ + for (int i = 0; i < PARAM_K; i++) { + launch_batch_add( + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_e + (size_t)i * batch_count * PARAM_N, + batch_count, stream); + } + + /* caddq(pkpv) */ + for (int i = 0; i < PARAM_K; i++) { + 
launch_batch_caddq(buf->d_pkpv + (size_t)i * batch_count * PARAM_N, batch_count, stream); + } + + /* Step 5: 打包 PK/SK */ + dim3 pack_grid(batch_count, PARAM_K); + batch_pack_sk_polyvec_kernel<<>>( + d_sk_out, buf->d_skpv, batch_count); + batch_pack_pk_polyvec_kernel<<>>( + d_pk_out, buf->d_pkpv, batch_count); + batch_pack_keypair_finalize_kernel<<>>( + d_pk_out, d_sk_out, buf->d_publicseed_kg, buf->d_coins_kg, batch_count); + + err = cudaGetLastError(); + return err; +} + +/* ================================================================ + * 简化批量 encaps/decaps (串行 kernel, 可进一步并行化) + * ================================================================ */ +static inline cudaError_t batch_keygen_pipelined_profile( + uint8_t *d_pk_out, uint8_t *d_sk_out, + BatchKemBuffers *buf, + int batch_count, + cudaStream_t stream = 0) +{ + cudaEvent_t ev0, ev1, ev2, ev3, ev4, ev5, ev6; + cudaEventCreate(&ev0); cudaEventCreate(&ev1); cudaEventCreate(&ev2); + cudaEventCreate(&ev3); cudaEventCreate(&ev4); cudaEventCreate(&ev5); cudaEventCreate(&ev6); + + cudaEventRecord(ev0, stream); + int blocks = (batch_count + WP_KG_WARPS_BLOCK - 1) / WP_KG_WARPS_BLOCK; +#if KEM_SPLIT_KEYGEN_SAMPLE + batch_keygen_seed_expand_kernel<<>>( + buf->d_publicseed_kg, buf->d_noiseseed_kg, buf->d_coins_kg, batch_count); + batch_keygen_mat_sample_kernel<<>>( + buf->d_mat, buf->d_publicseed_kg, batch_count); + batch_keygen_noise_sample_kernel<<>>( + buf->d_skpv, buf->d_e, buf->d_noiseseed_kg, batch_count); +#else + batch_keygen_warp_sample_kernel<<>>( + buf->d_mat, buf->d_skpv, buf->d_e, + buf->d_publicseed_kg, buf->d_coins_kg, batch_count); +#endif + cudaEventRecord(ev1, stream); + + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + batch_ntt_kernel<<>>(ptr, batch_count); + } + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_skpv + (size_t)i * batch_count * PARAM_N; + launch_batch_caddq(ptr, batch_count, stream); + } + cudaEventRecord(ev2, 
stream); + + launch_batch_matvec(buf->d_pkpv, buf->d_mat, buf->d_skpv, batch_count, stream); + cudaEventRecord(ev3, stream); + + for (int i = 0; i < PARAM_K; i++) { + int16_t *ptr = buf->d_pkpv + (size_t)i * batch_count * PARAM_N; + batch_invntt_kernel<<>>(ptr, batch_count); + } + cudaEventRecord(ev4, stream); + + for (int i = 0; i < PARAM_K; i++) { + launch_batch_add( + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_pkpv + (size_t)i * batch_count * PARAM_N, + buf->d_e + (size_t)i * batch_count * PARAM_N, + batch_count, stream); + } + for (int i = 0; i < PARAM_K; i++) + launch_batch_caddq(buf->d_pkpv + (size_t)i * batch_count * PARAM_N, batch_count, stream); + cudaEventRecord(ev5, stream); + + dim3 pack_grid(batch_count, PARAM_K); + batch_pack_sk_polyvec_kernel<<>>( + d_sk_out, buf->d_skpv, batch_count); + batch_pack_pk_polyvec_kernel<<>>( + d_pk_out, buf->d_pkpv, batch_count); + batch_pack_keypair_finalize_kernel<<>>( + d_pk_out, d_sk_out, buf->d_publicseed_kg, buf->d_coins_kg, batch_count); + cudaEventRecord(ev6, stream); + cudaEventSynchronize(ev6); + + float sample_ms, ntt_ms, matvec_ms, invntt_ms, add_ms, pack_ms, total_ms; + cudaEventElapsedTime(&sample_ms, ev0, ev1); + cudaEventElapsedTime(&ntt_ms, ev1, ev2); + cudaEventElapsedTime(&matvec_ms, ev2, ev3); + cudaEventElapsedTime(&invntt_ms, ev3, ev4); + cudaEventElapsedTime(&add_ms, ev4, ev5); + cudaEventElapsedTime(&pack_ms, ev5, ev6); + cudaEventElapsedTime(&total_ms, ev0, ev6); + printf(" Pipeline profile: sample=%.3f ntt=%.3f matvec=%.3f invntt=%.3f add=%.3f pack=%.3f total=%.3f ms\n", + sample_ms, ntt_ms, matvec_ms, invntt_ms, add_ms, pack_ms, total_ms); + + cudaEventDestroy(ev0); cudaEventDestroy(ev1); cudaEventDestroy(ev2); + cudaEventDestroy(ev3); cudaEventDestroy(ev4); cudaEventDestroy(ev5); cudaEventDestroy(ev6); + return cudaGetLastError(); +} + +static inline cudaError_t batch_encaps_serial( + uint8_t *d_ct, uint8_t *d_ss, + const uint8_t *d_pk, + BatchKemBuffers *buf, + int batch_count, 
#!/usr/bin/env bash
# Build the KEM test driver (main.cu) with hipcc, once per parameter set.
#
# Usage: build_hip.sh [target]
#   target — optional; one of the names in TARGETS below. When omitted, every
#            target is built. An unknown target is an error.
#
# Every tunable below can be overridden from the environment.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

HIPCC=${HIPCC:-hipcc}
ROCM_ARCH=${ROCM_ARCH:-gfx1100}
KEM_SERIAL_TPB=${KEM_SERIAL_TPB:-64}
KEM_KEYGEN_TPB=${KEM_KEYGEN_TPB:-${KEM_SERIAL_TPB}}
KEM_ENCAPS_TPB=${KEM_ENCAPS_TPB:-${KEM_SERIAL_TPB}}
KEM_DECAPS_TPB=${KEM_DECAPS_TPB:-${KEM_SERIAL_TPB}}
KEM_KEYPAIR_LAUNCH_BOUNDS=${KEM_KEYPAIR_LAUNCH_BOUNDS:-1}
# Empty means "do not pass -D", so the header's per-algorithm default applies.
KEM_ENCAPS_LAUNCH_BOUNDS=${KEM_ENCAPS_LAUNCH_BOUNDS:-}
KEM_DECAPS_LAUNCH_BOUNDS=${KEM_DECAPS_LAUNCH_BOUNDS:-}
WP_KG_WARPS_BLOCK=${WP_KG_WARPS_BLOCK:-4}
KEM_PACK_TPB=${KEM_PACK_TPB:-128}
BUILD_TYPE=${BUILD_TYPE:-Release}
CXX_STD=${CXX_STD:-c++17}
ROCM_WAVE32_FLAG=${ROCM_WAVE32_FLAG:-}
OPT_LEVEL=${OPT_LEVEL:-}
EXTRA_HIPCC_FLAGS=${EXTRA_HIPCC_FLAGS:-}

if [[ "${BUILD_TYPE}" == "Debug" ]]; then
  OPT_FLAGS=(-O0 -g)
else
  if [[ -n "${OPT_LEVEL}" ]]; then
    OPT_FLAGS=("-${OPT_LEVEL}")
  else
    OPT_FLAGS=(-O2)
  fi
fi

COMMON_FLAGS=(
  "${OPT_FLAGS[@]}"
  -std="${CXX_STD}"
  -x
  hip
  --offload-arch="${ROCM_ARCH}"
  -DKEM_SERIAL_TPB="${KEM_SERIAL_TPB}"
  -DKEM_KEYGEN_TPB="${KEM_KEYGEN_TPB}"
  -DKEM_ENCAPS_TPB="${KEM_ENCAPS_TPB}"
  -DKEM_DECAPS_TPB="${KEM_DECAPS_TPB}"
  -DKEM_KEYPAIR_LAUNCH_BOUNDS="${KEM_KEYPAIR_LAUNCH_BOUNDS}"
  -DWP_KG_WARPS_BLOCK="${WP_KG_WARPS_BLOCK}"
  -DKEM_PACK_TPB="${KEM_PACK_TPB}"
)

if [[ -n "${KEM_ENCAPS_LAUNCH_BOUNDS}" ]]; then
  COMMON_FLAGS+=(-DKEM_ENCAPS_LAUNCH_BOUNDS="${KEM_ENCAPS_LAUNCH_BOUNDS}")
fi

if [[ -n "${KEM_DECAPS_LAUNCH_BOUNDS}" ]]; then
  COMMON_FLAGS+=(-DKEM_DECAPS_LAUNCH_BOUNDS="${KEM_DECAPS_LAUNCH_BOUNDS}")
fi

if [[ -n "${ROCM_WAVE32_FLAG}" ]]; then
  COMMON_FLAGS+=("${ROCM_WAVE32_FLAG}")
fi

if [[ -n "${EXTRA_HIPCC_FLAGS}" ]]; then
  # Word splitting is intentional: the variable holds several flags.
  # shellcheck disable=SC2206
  EXTRA_FLAGS_ARRAY=(${EXTRA_HIPCC_FLAGS})
  COMMON_FLAGS+=("${EXTRA_FLAGS_ARRAY[@]}")
fi

# Each entry is name:ALGORITHM:PARAM_MODE.
declare -a TARGETS=(
  "kyber512:1:2"
  "kyber768:1:3"
  "kyber1024:1:4"
  "aigisenc1:2:1"
  "aigisenc2:2:2"
  "aigisenc3:2:3"
  "aigisenc4:2:4"
)

FILTER=${1:-}

if ! command -v "${HIPCC}" >/dev/null 2>&1; then
  # Diagnostics belong on stderr so they are not lost in captured stdout.
  echo "[错误] 未找到 hipcc,请先安装 ROCm 并把 hipcc 加入 PATH" >&2
  exit 1
fi

mkdir -p amd_results/build

built=0
for spec in "${TARGETS[@]}"; do
  IFS=':' read -r name alg mode <<<"${spec}"
  if [[ -n "${FILTER}" && "${name}" != "${FILTER}" ]]; then
    continue
  fi

  out="${name}_amd"
  echo "[build] ${out} (ALGORITHM=${alg} PARAM_MODE=${mode}, arch=${ROCM_ARCH}, opt=${OPT_FLAGS[*]}, KEM_TPB=${KEM_KEYGEN_TPB}/${KEM_ENCAPS_TPB}/${KEM_DECAPS_TPB}, bounds=${KEM_KEYPAIR_LAUNCH_BOUNDS}/${KEM_ENCAPS_LAUNCH_BOUNDS:-default}/${KEM_DECAPS_LAUNCH_BOUNDS:-default}, wpkg=${WP_KG_WARPS_BLOCK}, pack=${KEM_PACK_TPB})"
  # pipefail makes a compiler failure abort the script even though tee succeeds.
  "${HIPCC}" "${COMMON_FLAGS[@]}" \
    -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" \
    -o "${out}" main.cu \
    2>&1 | tee "amd_results/build/${out}.log"
  # Plain assignment: (( built++ )) returns non-zero for 0 and would trip set -e.
  built=$((built + 1))
done

# A filter that matched nothing must not be reported as a successful build.
if (( built == 0 )); then
  echo "[错误] 未知的构建目标: ${FILTER}" >&2
  exit 1
fi

echo
echo "HIP 构建完成"
cudaSuccess) { \ + fprintf(stderr, "CUDA error %s:%d — %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(_e)); \ + exit(1); \ + } \ +} while (0) + +static double get_time_ms(void) +{ + struct timespec ts; + timespec_get(&ts, TIME_UTC); + return ts.tv_sec * 1000.0 + ts.tv_nsec / 1e6; +} + +/* ================================================================ + * 算法名称 + * ================================================================ */ +static const char *algo_name(void) +{ +#if ALGORITHM == ALGO_KYBER + #if PARAM_MODE == 2 + return "Kyber-512"; + #elif PARAM_MODE == 3 + return "Kyber-768"; + #else + return "Kyber-1024"; + #endif +#elif ALGORITHM == ALGO_AIGIS_ENC + #if PARAM_MODE == 1 + return "Aigis-enc-1"; + #elif PARAM_MODE == 2 + return "Aigis-enc-2"; + #elif PARAM_MODE == 3 + return "Aigis-enc-3"; + #else + return "Aigis-enc-4"; + #endif +#endif +} + +/* ================================================================ + * 正确性测试: 单实例 CPU 调用 GPU kernel 验证 + * ================================================================ */ +static int test_correctness(void) +{ + printf("=== 正确性测试: %s ===\n", algo_name()); + printf(" PK=%u SK=%u CT=%u SS=%u 字节\n", + PARAM_PUBLICKEYBYTES, PARAM_SECRETKEYBYTES, + PARAM_CIPHERTEXTBYTES, PARAM_SSBYTES); + + /* Host 端分配 */ + uint8_t *h_pk = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_sk = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_ct = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_ss1 = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_ss2 = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_coins_kg = (uint8_t *)malloc(2 * PARAM_SYMBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc(PARAM_SYMBYTES); + + if (!h_pk || !h_sk || !h_ct || !h_ss1 || !h_ss2 || !h_coins_kg || !h_coins_enc) { + fprintf(stderr, "malloc failed\n"); + return -1; + } + + /* 生成伪随机种子 (测试用,实际应用请使用安全随机源) */ + srand(42); + for (int i = 0; i < 2 * PARAM_SYMBYTES; i++) + h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (int i = 0; i < 
PARAM_SYMBYTES; i++) + h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + /* Device 端分配 */ + uint8_t *d_pk, *d_sk, *d_ct, *d_ss1, *d_ss2; + uint8_t *d_coins_kg, *d_coins_enc; + CUDA_CHECK(cudaMalloc(&d_pk, PARAM_PUBLICKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_sk, PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct, PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss1, PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss2, PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, 2 * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, PARAM_SYMBYTES)); + + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, 2 * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + + /* 串行设备 kernel 验证 (batch_count=1) */ + batch_kem_keypair_serial_kernel<<<1, 1>>>(d_pk, d_sk, d_coins_kg, 1); + CUDA_CHECK(cudaGetLastError()); + batch_kem_encaps_serial_kernel<<<1, 1>>>(d_ct, d_ss1, d_pk, d_coins_enc, 1); + CUDA_CHECK(cudaGetLastError()); + batch_kem_decaps_serial_kernel<<<1, 1>>>(d_ss2, d_ct, d_sk, 1); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + + /* 取回结果 */ + CUDA_CHECK(cudaMemcpy(h_ss1, d_ss1, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_ss2, d_ss2, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + + /* 验证 ss1 == ss2 */ + int ok = (memcmp(h_ss1, h_ss2, PARAM_SSBYTES) == 0); + printf(" KEM 正确性: %s\n", ok ? "PASS" : "FAIL"); + + if (!ok) { + printf(" [encaps ss] "); + for (int i = 0; i < 8; i++) printf("%02x", h_ss1[i]); + printf("...\n"); + printf(" [decaps ss] "); + for (int i = 0; i < 8; i++) printf("%02x", h_ss2[i]); + printf("...\n"); + } + + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); + cudaFree(d_ss1); cudaFree(d_ss2); + cudaFree(d_coins_kg); cudaFree(d_coins_enc); + free(h_pk); free(h_sk); free(h_ct); + free(h_ss1); free(h_ss2); + free(h_coins_kg); free(h_coins_enc); + + return ok ? 
0 : 1; +} + +/* ================================================================ + * 批量吞吐量测试 + * ================================================================ */ +static void bench_batch(int batch_count, int n_ops, int use_pipeline, int profile_pipeline = 0) +{ + printf("\n--- batch=%d n_ops=%d mode=%s ---\n", + batch_count, n_ops, use_pipeline ? "pipeline" : "serial"); + + /* 分配设备内存 */ + uint8_t *d_pk, *d_sk, *d_ct, *d_ss; + uint8_t *d_coins_kg, *d_coins_enc; + + CUDA_CHECK(cudaMalloc(&d_pk, (size_t)batch_count * PARAM_PUBLICKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_sk, (size_t)batch_count * PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct, (size_t)batch_count * PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss, (size_t)batch_count * PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, (size_t)batch_count * 2 * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, (size_t)batch_count * PARAM_SYMBYTES)); + + /* 生成随机种子 */ + uint8_t *h_coins_kg = (uint8_t *)malloc((size_t)batch_count * 2 * PARAM_SYMBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc((size_t)batch_count * PARAM_SYMBYTES); + if (!h_coins_kg || !h_coins_enc) { fprintf(stderr, "OOM\n"); exit(1); } + + srand(1234); + for (size_t i = 0; i < (size_t)batch_count * 2 * PARAM_SYMBYTES; i++) + h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (size_t i = 0; i < (size_t)batch_count * PARAM_SYMBYTES; i++) + h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, (size_t)batch_count * 2 * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, (size_t)batch_count * PARAM_SYMBYTES, cudaMemcpyHostToDevice)); + + BatchKemBuffers buf = {}; + if (use_pipeline) { + /* 修复 batch_kem_alloc 中的双 cudaMalloc bug: 直接内联分配 */ + buf.max_batch = batch_count; + CUDA_CHECK(cudaMalloc(&buf.d_mat, (size_t)PARAM_K * PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_skpv, (size_t)PARAM_K * batch_count * PARAM_N * 
sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_pkpv, (size_t)PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_e, (size_t)PARAM_K * batch_count * PARAM_N * sizeof(int16_t))); + CUDA_CHECK(cudaMalloc(&buf.d_publicseed_kg, (size_t)batch_count * PARAM_SYMBYTES)); + CUDA_CHECK(cudaMalloc(&buf.d_noiseseed_kg, (size_t)batch_count * PARAM_SYMBYTES)); + buf.d_pk_bytes = d_pk; + buf.d_sk_bytes = d_sk; + buf.d_ct_bytes = d_ct; + buf.d_ss_bytes = d_ss; + buf.d_coins_kg = d_coins_kg; + buf.d_coins_enc = d_coins_enc; + } + + /* ---- Keygen ---- */ + CUDA_CHECK(cudaDeviceSynchronize()); + double t0 = get_time_ms(); + + for (int op = 0; op < n_ops; op++) { + if (use_pipeline) { + if (profile_pipeline && op == 0) + batch_keygen_pipelined_profile(d_pk, d_sk, &buf, batch_count); + else + batch_keygen_pipelined(d_pk, d_sk, &buf, batch_count); + } else { + int tpb = KEM_KEYGEN_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_keypair_serial_kernel<<>>(d_pk, d_sk, d_coins_kg, batch_count); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_kg = (get_time_ms() - t0) / n_ops; + double ops_kg = batch_count * 1000.0 / t_kg; + printf(" Keygen: %7.1f ms/batch → %.0f ops/sec\n", t_kg, ops_kg); + + /* ---- Encaps ---- */ + t0 = get_time_ms(); + for (int op = 0; op < n_ops; op++) { + if (use_pipeline) { + batch_encaps_serial(d_ct, d_ss, d_pk, &buf, batch_count); + } else { + int tpb = KEM_ENCAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_encaps_serial_kernel<<>>(d_ct, d_ss, d_pk, d_coins_enc, batch_count); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_enc = (get_time_ms() - t0) / n_ops; + double ops_enc = batch_count * 1000.0 / t_enc; + printf(" Encaps: %7.1f ms/batch → %.0f ops/sec\n", t_enc, ops_enc); + + /* ---- Decaps ---- */ + t0 = get_time_ms(); + for (int op = 0; op < n_ops; op++) { + batch_decaps_serial(d_ss, d_ct, d_sk, batch_count); + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double t_dec = 
/*
 * Compare the cost of allocating device buffers every round against
 * allocating them once and reusing them.
 *
 * Both variants run `rounds` rounds; each round uploads the same coins and
 * calls run_serial_kem_round with `n_ops`.  The first timed region also
 * contains the six cudaMalloc / cudaFree calls per round, the second does not.
 *
 * batch_count — instances per kernel launch
 * rounds      — number of timed rounds per variant
 * n_ops       — value forwarded to run_serial_kem_round for each round
 *
 * Prints totals, per-round time, throughput and the speedup; exits the
 * process on host allocation failure or (via CUDA_CHECK) on a CUDA error.
 */
static void bench_reuse_buffers(int batch_count, int rounds, int n_ops)
{
    printf("\n=== Buffer reuse benchmark: %s ===\n", algo_name());
    printf("batch=%d rounds=%d n_ops_per_round=%d\n", batch_count, rounds, n_ops);

    /* Buffer sizes for one batch. */
    size_t pk_bytes = (size_t)batch_count * PARAM_PUBLICKEYBYTES;
    size_t sk_bytes = (size_t)batch_count * PARAM_SECRETKEYBYTES;
    size_t ct_bytes = (size_t)batch_count * PARAM_CIPHERTEXTBYTES;
    size_t ss_bytes = (size_t)batch_count * PARAM_SSBYTES;
    size_t kg_bytes = (size_t)batch_count * 2 * PARAM_SYMBYTES;
    size_t enc_bytes = (size_t)batch_count * PARAM_SYMBYTES;

    uint8_t *h_coins_kg = (uint8_t *)malloc(kg_bytes);
    uint8_t *h_coins_enc = (uint8_t *)malloc(enc_bytes);
    if (!h_coins_kg || !h_coins_enc) { fprintf(stderr, "OOM\n"); exit(1); }
    /* Fixed seed: identical inputs for both variants and across runs. */
    srand(9102);
    for (size_t i = 0; i < kg_bytes; i++) h_coins_kg[i] = (uint8_t)(rand() & 0xFF);
    for (size_t i = 0; i < enc_bytes; i++) h_coins_enc[i] = (uint8_t)(rand() & 0xFF);

    /* ---- Variant 1: allocate and free all device buffers every round ---- */
    /* Drain pending GPU work so it is not charged to the timed region. */
    CUDA_CHECK(cudaDeviceSynchronize());
    double t0 = get_time_ms();
    for (int r = 0; r < rounds; r++) {
        uint8_t *d_pk, *d_sk, *d_ct, *d_ss, *d_coins_kg, *d_coins_enc;
        CUDA_CHECK(cudaMalloc(&d_pk, pk_bytes));
        CUDA_CHECK(cudaMalloc(&d_sk, sk_bytes));
        CUDA_CHECK(cudaMalloc(&d_ct, ct_bytes));
        CUDA_CHECK(cudaMalloc(&d_ss, ss_bytes));
        CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes));
        CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes));
        CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice));
        CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice));

        run_serial_kem_round(d_pk, d_sk, d_ct, d_ss, d_coins_kg, d_coins_enc, batch_count, n_ops);
        /* Wait for the kernels before the buffers they use are freed. */
        CUDA_CHECK(cudaDeviceSynchronize());

        cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss);
        cudaFree(d_coins_kg); cudaFree(d_coins_enc);
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    double alloc_each_ms = get_time_ms() - t0;

    /* ---- Variant 2: allocate once (outside the timed region), then reuse ---- */
    uint8_t *d_pk, *d_sk, *d_ct, *d_ss, *d_coins_kg, *d_coins_enc;
    CUDA_CHECK(cudaMalloc(&d_pk, pk_bytes));
    CUDA_CHECK(cudaMalloc(&d_sk, sk_bytes));
    CUDA_CHECK(cudaMalloc(&d_ct, ct_bytes));
    CUDA_CHECK(cudaMalloc(&d_ss, ss_bytes));
    CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes));
    CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes));

    CUDA_CHECK(cudaDeviceSynchronize());
    t0 = get_time_ms();
    for (int r = 0; r < rounds; r++) {
        /* The uploads stay inside the loop so both variants do the same copies. */
        CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice));
        CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice));
        run_serial_kem_round(d_pk, d_sk, d_ct, d_ss, d_coins_kg, d_coins_enc, batch_count, n_ops);
        CUDA_CHECK(cudaDeviceSynchronize());
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    double reuse_ms = get_time_ms() - t0;

    /* Instance count used as the numerator of both throughput figures. */
    double total_instances = (double)batch_count * (double)rounds * (double)n_ops;
    printf(" Alloc-each-round: total=%8.1f ms | per_round=%7.3f ms | full-kem throughput=%.0f instances/sec\n",
           alloc_each_ms, alloc_each_ms / rounds, total_instances * 1000.0 / alloc_each_ms);
    printf(" Reuse buffers: total=%8.1f ms | per_round=%7.3f ms | full-kem throughput=%.0f instances/sec\n",
           reuse_ms, reuse_ms / rounds, total_instances * 1000.0 / reuse_ms);
    printf(" Reuse speedup: %.3fx\n", alloc_each_ms / reuse_ms);

    cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss);
    cudaFree(d_coins_kg); cudaFree(d_coins_enc);
    free(h_coins_kg); free(h_coins_enc);
}
(size_t i = 0; i < kg_bytes; i++) h_coins_kg[i] = (uint8_t)(rand() & 0xFF); + for (size_t i = 0; i < enc_bytes; i++) h_coins_enc[i] = (uint8_t)(rand() & 0xFF); + + for (int s = 0; s < nstreams; s++) { + CUDA_CHECK(cudaStreamCreate(&streams[s])); + CUDA_CHECK(cudaMalloc(&d_pk[s], (size_t)batch_count * PARAM_PUBLICKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_sk[s], (size_t)batch_count * PARAM_SECRETKEYBYTES)); + CUDA_CHECK(cudaMalloc(&d_ct[s], (size_t)batch_count * PARAM_CIPHERTEXTBYTES)); + CUDA_CHECK(cudaMalloc(&d_ss[s], (size_t)batch_count * PARAM_SSBYTES)); + CUDA_CHECK(cudaMalloc(&d_coins_kg[s], kg_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc[s], enc_bytes)); + CUDA_CHECK(cudaMemcpyAsync(d_coins_kg[s], h_coins_kg, kg_bytes, cudaMemcpyHostToDevice, streams[s])); + CUDA_CHECK(cudaMemcpyAsync(d_coins_enc[s], h_coins_enc, enc_bytes, cudaMemcpyHostToDevice, streams[s])); + } + CUDA_CHECK(cudaDeviceSynchronize()); + + double total_ops = (double)batch_count * (double)nstreams; + + CUDA_CHECK(cudaDeviceSynchronize()); + double t0 = get_time_ms(); + int kg_tpb = KEM_KEYGEN_TPB; + int kg_blocks = (batch_count + kg_tpb - 1) / kg_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_keypair_serial_kernel<<>>( + d_pk[s], d_sk[s], d_coins_kg[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaGetLastError()); + double t_kg = (get_time_ms() - t0) / n_ops; + printf(" Keygen: %7.1f ms/round -> %.0f ops/sec\n", t_kg, total_ops * 1000.0 / t_kg); + + t0 = get_time_ms(); + int enc_tpb = KEM_ENCAPS_TPB; + int enc_blocks = (batch_count + enc_tpb - 1) / enc_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_encaps_serial_kernel<<>>( + d_ct[s], d_ss[s], d_pk[s], d_coins_enc[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaGetLastError()); + double t_enc = (get_time_ms() - t0) / n_ops; + printf(" Encaps: %7.1f ms/round -> %.0f ops/sec\n", t_enc, total_ops * 1000.0 / 
t_enc); + + t0 = get_time_ms(); + int dec_tpb = KEM_DECAPS_TPB; + int dec_blocks = (batch_count + dec_tpb - 1) / dec_tpb; + for (int op = 0; op < n_ops; op++) + for (int s = 0; s < nstreams; s++) + batch_kem_decaps_serial_kernel<<>>( + d_ss[s], d_ct[s], d_sk[s], batch_count); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaGetLastError()); + double t_dec = (get_time_ms() - t0) / n_ops; + printf(" Decaps: %7.1f ms/round -> %.0f ops/sec\n", t_dec, total_ops * 1000.0 / t_dec); + + for (int s = 0; s < nstreams; s++) { + cudaFree(d_pk[s]); cudaFree(d_sk[s]); cudaFree(d_ct[s]); cudaFree(d_ss[s]); + cudaFree(d_coins_kg[s]); cudaFree(d_coins_enc[s]); + cudaStreamDestroy(streams[s]); + } + free(h_coins_kg); free(h_coins_enc); + free(streams); free(d_pk); free(d_sk); free(d_ct); free(d_ss); free(d_coins_kg); free(d_coins_enc); +} + +static void bench_sweep(void) +{ + int sizes[] = { 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072 }; + int n = (int)(sizeof(sizes) / sizeof(sizes[0])); + printf("\n=== Batch size 扫描: %s ===\n", algo_name()); + for (int i = 0; i < n; i++) { + bench_batch(sizes[i], 3, 0); + } +} + +static const char *arg_value(int argc, char **argv, const char *name) +{ + for (int i = 1; i + 1 < argc; i++) { + if (strcmp(argv[i], name) == 0) return argv[i + 1]; + } + return NULL; +} + +static int has_arg(int argc, char **argv, const char *name) +{ + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], name) == 0) return 1; + } + return 0; +} + +static int read_file_all_host(const char *path, uint8_t **out, size_t *out_len) +{ + FILE *f = fopen(path, "rb"); + long n; + uint8_t *buf; + if (!f) { + fprintf(stderr, "open failed: %s\n", path); + return -1; + } + if (fseek(f, 0, SEEK_END) != 0) { + fclose(f); + return -1; + } + n = ftell(f); + if (n < 0) { + fclose(f); + return -1; + } + if (fseek(f, 0, SEEK_SET) != 0) { + fclose(f); + return -1; + } + buf = (uint8_t *)malloc((size_t)n + 1u); + if (!buf) { + fclose(f); + return -1; + } + if 
((size_t)n > 0 && fread(buf, 1, (size_t)n, f) != (size_t)n) { + free(buf); + fclose(f); + return -1; + } + fclose(f); + buf[n] = 0; + *out = buf; + *out_len = (size_t)n; + return 0; +} + +static int read_file_exact_host(const char *path, uint8_t *buf, size_t len) +{ + uint8_t *tmp = NULL; + size_t n = 0; + int rc = read_file_all_host(path, &tmp, &n); + if (rc != 0) return rc; + if (n != len) { + fprintf(stderr, "size mismatch: %s expected %zu got %zu\n", path, len, n); + free(tmp); + return -1; + } + memcpy(buf, tmp, len); + free(tmp); + return 0; +} + +static int write_file_all_host(const char *path, const uint8_t *buf, size_t len) +{ + FILE *f = fopen(path, "wb"); + if (!f) { + fprintf(stderr, "write open failed: %s\n", path); + return -1; + } + if (len > 0 && fwrite(buf, 1, len, f) != len) { + fclose(f); + return -1; + } + fclose(f); + return 0; +} + +static void fill_random_host(uint8_t *buf, size_t len) +{ + FILE *f = fopen("/dev/urandom", "rb"); + if (f) { + size_t n = fread(buf, 1, len, f); + fclose(f); + if (n == len) return; + } + srand((unsigned)time(NULL)); + for (size_t i = 0; i < len; i++) buf[i] = (uint8_t)(rand() & 0xff); +} + +static void duplicate_record(uint8_t *dst, const uint8_t *src, size_t item_len, int batch_count) +{ + for (int i = 0; i < batch_count; i++) { + memcpy(dst + (size_t)i * item_len, src, item_len); + } +} + +static int run_kem_api_mode(int argc, char **argv, int batch_count) +{ + const int do_keygen = has_arg(argc, argv, "--api-kem-keygen"); + const int do_encaps = has_arg(argc, argv, "--api-kem-encaps"); + const int do_decaps = has_arg(argc, argv, "--api-kem-decaps"); + if (!do_keygen && !do_encaps && !do_decaps) return 0; + if ((do_keygen ? 1 : 0) + (do_encaps ? 1 : 0) + (do_decaps ? 
1 : 0) != 1) { + fprintf(stderr, "select exactly one KEM API mode\n"); + return 2; + } + if (batch_count < 1) batch_count = 1; + + const size_t pk_batch_bytes = (size_t)batch_count * PARAM_PUBLICKEYBYTES; + const size_t sk_batch_bytes = (size_t)batch_count * PARAM_SECRETKEYBYTES; + const size_t ct_batch_bytes = (size_t)batch_count * PARAM_CIPHERTEXTBYTES; + const size_t ss_batch_bytes = (size_t)batch_count * PARAM_SSBYTES; + const size_t kg_bytes = (size_t)batch_count * 2 * PARAM_SYMBYTES; + const size_t enc_bytes = (size_t)batch_count * PARAM_SYMBYTES; + + if (do_keygen) { + const char *pk_out = arg_value(argc, argv, "--pk-out"); + const char *sk_out = arg_value(argc, argv, "--sk-out"); + if (!pk_out || !sk_out) { + fprintf(stderr, "--api-kem-keygen requires --pk-out and --sk-out\n"); + return 2; + } + + uint8_t *h_pk = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_sk = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_coins_kg = (uint8_t *)malloc(kg_bytes); + uint8_t *d_pk = NULL, *d_sk = NULL, *d_coins_kg = NULL; + if (!h_pk || !h_sk || !h_coins_kg) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_pk); free(h_sk); free(h_coins_kg); + return 2; + } + fill_random_host(h_coins_kg, kg_bytes); + CUDA_CHECK(cudaMalloc(&d_pk, pk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_sk, sk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_kg, kg_bytes)); + CUDA_CHECK(cudaMemcpy(d_coins_kg, h_coins_kg, kg_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_KEYGEN_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_keypair_serial_kernel<<>>(d_pk, d_sk, d_coins_kg, batch_count); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_pk, d_pk, PARAM_PUBLICKEYBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_sk, d_sk, PARAM_SECRETKEYBYTES, cudaMemcpyDeviceToHost)); + int rc = 0; + if (write_file_all_host(pk_out, h_pk, PARAM_PUBLICKEYBYTES) != 0 || + write_file_all_host(sk_out, h_sk, PARAM_SECRETKEYBYTES) != 
0) rc = 2; + cudaFree(d_pk); cudaFree(d_sk); cudaFree(d_coins_kg); + free(h_pk); free(h_sk); free(h_coins_kg); + if (rc == 0) printf("API KEM keygen PASS batch=%d pk=%u sk=%u\n", batch_count, PARAM_PUBLICKEYBYTES, PARAM_SECRETKEYBYTES); + return rc == 0 ? 1 : rc; + } + + if (do_encaps) { + const char *pk_in = arg_value(argc, argv, "--pk-in"); + const char *ct_out = arg_value(argc, argv, "--ct-out"); + const char *ss_out = arg_value(argc, argv, "--ss-out"); + if (!pk_in || !ct_out || !ss_out) { + fprintf(stderr, "--api-kem-encaps requires --pk-in, --ct-out, and --ss-out\n"); + return 2; + } + + uint8_t *h_pk_one = (uint8_t *)malloc(PARAM_PUBLICKEYBYTES); + uint8_t *h_pk = (uint8_t *)malloc(pk_batch_bytes); + uint8_t *h_ct = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_ss = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *h_coins_enc = (uint8_t *)malloc(enc_bytes); + uint8_t *d_pk = NULL, *d_ct = NULL, *d_ss = NULL, *d_coins_enc = NULL; + if (!h_pk_one || !h_pk || !h_ct || !h_ss || !h_coins_enc) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + return 2; + } + if (read_file_exact_host(pk_in, h_pk_one, PARAM_PUBLICKEYBYTES) != 0) { + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + return 2; + } + duplicate_record(h_pk, h_pk_one, PARAM_PUBLICKEYBYTES, batch_count); + fill_random_host(h_coins_enc, enc_bytes); + CUDA_CHECK(cudaMalloc(&d_pk, pk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ss, ss_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_coins_enc, enc_bytes)); + CUDA_CHECK(cudaMemcpy(d_pk, h_pk, pk_batch_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_coins_enc, h_coins_enc, enc_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_ENCAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_encaps_serial_kernel<<>>(d_ct, d_ss, d_pk, d_coins_enc, batch_count); + CUDA_CHECK(cudaGetLastError()); + 
CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_ct, d_ct, PARAM_CIPHERTEXTBYTES, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_ss, d_ss, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + int rc = 0; + if (write_file_all_host(ct_out, h_ct, PARAM_CIPHERTEXTBYTES) != 0 || + write_file_all_host(ss_out, h_ss, PARAM_SSBYTES) != 0) rc = 2; + cudaFree(d_pk); cudaFree(d_ct); cudaFree(d_ss); cudaFree(d_coins_enc); + free(h_pk_one); free(h_pk); free(h_ct); free(h_ss); free(h_coins_enc); + if (rc == 0) printf("API KEM encaps PASS batch=%d ct=%u ss=%u\n", batch_count, PARAM_CIPHERTEXTBYTES, PARAM_SSBYTES); + return rc == 0 ? 1 : rc; + } + + if (do_decaps) { + const char *sk_in = arg_value(argc, argv, "--sk-in"); + const char *ct_in = arg_value(argc, argv, "--ct-in"); + const char *ss_out = arg_value(argc, argv, "--ss-out"); + if (!sk_in || !ct_in || !ss_out) { + fprintf(stderr, "--api-kem-decaps requires --sk-in, --ct-in, and --ss-out\n"); + return 2; + } + + uint8_t *h_sk_one = (uint8_t *)malloc(PARAM_SECRETKEYBYTES); + uint8_t *h_ct_one = (uint8_t *)malloc(PARAM_CIPHERTEXTBYTES); + uint8_t *h_sk = (uint8_t *)malloc(sk_batch_bytes); + uint8_t *h_ct = (uint8_t *)malloc(ct_batch_bytes); + uint8_t *h_ss = (uint8_t *)malloc(PARAM_SSBYTES); + uint8_t *d_sk = NULL, *d_ct = NULL, *d_ss = NULL; + if (!h_sk_one || !h_ct_one || !h_sk || !h_ct || !h_ss) { + fprintf(stderr, "KEM API malloc failed\n"); + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + return 2; + } + if (read_file_exact_host(sk_in, h_sk_one, PARAM_SECRETKEYBYTES) != 0 || + read_file_exact_host(ct_in, h_ct_one, PARAM_CIPHERTEXTBYTES) != 0) { + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + return 2; + } + duplicate_record(h_sk, h_sk_one, PARAM_SECRETKEYBYTES, batch_count); + duplicate_record(h_ct, h_ct_one, PARAM_CIPHERTEXTBYTES, batch_count); + CUDA_CHECK(cudaMalloc(&d_sk, sk_batch_bytes)); + CUDA_CHECK(cudaMalloc(&d_ct, ct_batch_bytes)); + 
CUDA_CHECK(cudaMalloc(&d_ss, ss_batch_bytes)); + CUDA_CHECK(cudaMemcpy(d_sk, h_sk, sk_batch_bytes, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_ct, h_ct, ct_batch_bytes, cudaMemcpyHostToDevice)); + int tpb = KEM_DECAPS_TPB; + int blocks = (batch_count + tpb - 1) / tpb; + batch_kem_decaps_serial_kernel<<>>(d_ss, d_ct, d_sk, batch_count); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(h_ss, d_ss, PARAM_SSBYTES, cudaMemcpyDeviceToHost)); + int rc = write_file_all_host(ss_out, h_ss, PARAM_SSBYTES) != 0 ? 2 : 0; + cudaFree(d_sk); cudaFree(d_ct); cudaFree(d_ss); + free(h_sk_one); free(h_ct_one); free(h_sk); free(h_ct); free(h_ss); + if (rc == 0) printf("API KEM decaps PASS batch=%d ss=%u\n", batch_count, PARAM_SSBYTES); + return rc == 0 ? 1 : rc; + } + + return 0; +} + +/* ================================================================ + * 主函数 + * ================================================================ */ +int main(int argc, char **argv) +{ + /* 解析参数 */ + int batch_count = 65536; + int n_ops = 5; + int do_sweep = 0; + int run_pipeline = 0; + int do_correctness = 1; + int nstreams = 1; + int profile_pipeline = 0; + int reuse_rounds = 0; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "--batch") == 0 && i + 1 < argc) + batch_count = atoi(argv[++i]); + else if (strcmp(argv[i], "--n-ops") == 0 && i + 1 < argc) + n_ops = atoi(argv[++i]); + else if (strcmp(argv[i], "--sweep") == 0) + do_sweep = 1; + else if (strcmp(argv[i], "--serial-only") == 0) + run_pipeline = 0; + else if (strcmp(argv[i], "--pipeline") == 0) + run_pipeline = 1; + else if (strcmp(argv[i], "--no-correctness") == 0) + do_correctness = 0; + else if (strcmp(argv[i], "--streams") == 0 && i + 1 < argc) + nstreams = atoi(argv[++i]); + else if (strcmp(argv[i], "--profile-pipeline") == 0) + profile_pipeline = 1; + else if (strcmp(argv[i], "--reuse-bench") == 0 && i + 1 < argc) + reuse_rounds = atoi(argv[++i]); + } + + /* 打印设备信息 */ + int 
dev; + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDevice(&dev)); + CUDA_CHECK(cudaGetDeviceProperties(&prop, dev)); + #if GPU_USE_HIP + printf("GPU: %s (%s, %d CUs, %.1f GB VRAM)\n", + prop.name, + prop.gcnArchName, + prop.multiProcessorCount, + prop.totalGlobalMem / 1e9); + #else + printf("GPU: %s (SM %d.%d, %d SMs, %.1f GB VRAM)\n", + prop.name, prop.major, prop.minor, + prop.multiProcessorCount, + prop.totalGlobalMem / 1e9); + #endif + printf("Runtime: %s\n", GPU_RUNTIME_NAME); + printf("Algorithm: %s K=%d Q=%d\n", algo_name(), PARAM_K, PARAM_Q); + + /* 设置 GPU 堆栈大小 (kem 函数需要 ~20KB 堆栈) */ + { + cudaError_t se = cudaDeviceSetLimit(cudaLimitStackSize, 64 * 1024); + if (se != cudaSuccess) { + fprintf(stderr, "Warning: cudaDeviceSetLimit(stack, 64KB) failed: %s\n", + cudaGetErrorString(se)); + cudaGetLastError(); /* 清除错误状态 */ + } + } + + int api_rc = run_kem_api_mode(argc, argv, batch_count); + if (api_rc != 0) return api_rc == 1 ? 0 : api_rc; + + /* 正确性测试 */ + if (do_correctness) { + int ret = test_correctness(); + if (ret != 0) { + fprintf(stderr, "正确性测试失败,中止性能测试\n"); + return ret; + } + printf("\n"); + } + + /* 吞吐量测试 */ + if (reuse_rounds > 0) { + bench_reuse_buffers(batch_count, reuse_rounds, n_ops); + } else if (do_sweep) { + bench_sweep(); + } else { + printf("=== 吞吐量测试: %s ===\n", algo_name()); + if (nstreams > 1) + bench_batch_streams(batch_count, n_ops, nstreams); + else + bench_batch(batch_count, n_ops, 0, 0); /* serial mode default */ + if (run_pipeline) { + /* 流水线模式 */ + bench_batch(batch_count, n_ops, 1, profile_pipeline); + } + } + + printf("\n完成.\n"); + return 0; +} diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/parse_kem_results.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/parse_kem_results.py new file mode 100644 index 000000000..2585b64ad --- /dev/null +++ 
#!/usr/bin/env python3
"""Convert KEM benchmark logs into one CSV table written to stdout.

Usage: parse_kem_results.py <log_dir>

Every ``*.log`` file directly inside ``<log_dir>`` is scanned for:

* banner lines (``GPU:``, ``Runtime:``, ``Algorithm:``),
* section headers of the form ``--- batch=N n_ops=M mode=X [streams=S] ---``,
* per-stage result lines (``Keygen`` / ``Encaps`` / ``Decaps``) and an
  optional ``Pipeline profile:`` line that follow a section header.

Each section header produces one CSV row; the lines after it fill that row
until the next header or the end of the log.
"""
import csv
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional

# Columns filled from a "Pipeline profile:" line, in the order of the capture
# groups of ``profile_re``.
PROFILE_KEYS = [
    "profile_sample_ms",
    "profile_ntt_ms",
    "profile_matvec_ms",
    "profile_invntt_ms",
    "profile_add_ms",
    "profile_pack_ms",
    "profile_total_ms",
]

# Single source of truth for the CSV column order (also the key set of a row).
FIELDNAMES = [
    "algorithm",
    "k",
    "q",
    "runtime",
    "gpu",
    "batch",
    "n_ops",
    "mode",
    "streams",
    "keygen_ms",
    "keygen_ops_s",
    "encaps_ms",
    "encaps_ops_s",
    "decaps_ms",
    "decaps_ops_s",
    "correctness",
    *PROFILE_KEYS,
    "status",
    "log",
]

gpu_re = re.compile(r"^GPU:\s+(.+?)\s+\(")
runtime_re = re.compile(r"^Runtime:\s+(\S+)")
algorithm_re = re.compile(r"^Algorithm:\s+(.+?)\s+K=(\d+)\s+Q=(\d+)")
batch_re = re.compile(
    r"^---\s+batch=(\d+)\s+n_ops=(\d+)\s+mode=([^\s]+)(?:\s+streams=(\d+))?"
)
# The token between the time and the throughput is "->" or a Unicode arrow.
# Any run of non-space characters is accepted so that an arrow which was
# decoded into several replacement characters still parses.
op_re = re.compile(
    r"^\s+(Keygen|Encaps|Decaps):\s+([0-9.]+)\s+ms/(?:batch|round)\s+\S+\s+"
    r"([0-9.]+)\s+ops/sec"
)
profile_re = re.compile(
    r"Pipeline profile:\s+sample=([0-9.]+)\s+ntt=([0-9.]+)\s+matvec=([0-9.]+)\s+"
    r"invntt=([0-9.]+)\s+add=([0-9.]+)\s+pack=([0-9.]+)\s+total=([0-9.]+)\s+ms"
)


def new_row(
    log_name: str,
    common: dict,
    batch: str,
    n_ops: str,
    mode: str,
    streams: Optional[str],
) -> dict:
    """Return a fresh CSV row for one benchmark section.

    Banner values are copied from *common* as they stand at the time of the
    call. Measurement columns start empty, ``status`` starts as ``"PASS"``,
    and a missing *streams* value is recorded as ``"1"``.
    """
    row = dict.fromkeys(FIELDNAMES, "")
    for key in ("algorithm", "k", "q", "runtime", "gpu"):
        row[key] = common.get(key, "")
    row.update(
        batch=batch,
        n_ops=n_ops,
        mode=mode,
        streams=streams or "1",
        correctness=common.get("correctness", "UNKNOWN"),
        status="PASS",
        log=log_name,
    )
    return row


def parse_log_text(text: str, log_name: str) -> List[Dict[str, str]]:
    """Parse the full text of one log and return its rows in file order.

    ``correctness`` is decided once for the whole log: ``"FAIL"`` if the text
    contains ``FAIL``, otherwise ``"PASS"`` if it contains both ``KEM`` and
    ``PASS``, otherwise ``"UNKNOWN"``.

    ``status`` is ``"FAIL"`` for every row of the log when correctness failed
    or when the log mentions ``exit_code=`` without any ``exit_code=0``.
    """
    common = {"correctness": "UNKNOWN"}
    if "FAIL" in text:
        common["correctness"] = "FAIL"
    elif "KEM" in text and "PASS" in text:
        common["correctness"] = "PASS"

    rows = []  # type: List[Dict[str, str]]
    current = None  # row of the section currently being filled

    for line in text.splitlines():
        m = gpu_re.search(line)
        if m:
            common["gpu"] = m.group(1)
            continue
        m = runtime_re.search(line)
        if m:
            common["runtime"] = m.group(1)
            continue
        m = algorithm_re.search(line)
        if m:
            common["algorithm"], common["k"], common["q"] = m.groups()
            continue
        m = batch_re.search(line)
        if m:
            # A new section header closes the previous section.
            if current:
                rows.append(current)
            current = new_row(log_name, common, *m.groups())
            continue
        if current is None:
            # Result and profile lines only mean something inside a section.
            continue
        m = op_re.search(line)
        if m:
            stage = m.group(1).lower()
            current[f"{stage}_ms"] = m.group(2)
            current[f"{stage}_ops_s"] = m.group(3)
            continue
        m = profile_re.search(line)
        if m:
            current.update(zip(PROFILE_KEYS, m.groups()))

    if current:
        rows.append(current)

    # A failure is a property of the whole run, so flag every row that came
    # from this log rather than only the last section.
    bad_exit = "exit_code=" in text and "exit_code=0" not in text
    if bad_exit or common["correctness"] == "FAIL":
        for row in rows:
            row["status"] = "FAIL"
    return rows


def main(argv: Optional[List[str]] = None) -> int:
    """Run the command-line tool and return the process exit code.

    *argv* holds the arguments without the program name and defaults to
    ``sys.argv[1:]``. Returns 2 after printing a usage message when the
    argument count is not exactly one, and 0 after writing the CSV.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 1:
        print("usage: parse_kem_results.py <log_dir>", file=sys.stderr)
        return 2

    rows = []  # type: List[Dict[str, str]]
    for log_path in sorted(Path(args[0]).glob("*.log")):
        # Decode explicitly as UTF-8: the logs contain non-ASCII characters
        # and the locale default would make parsing host-dependent.
        text = log_path.read_text(encoding="utf-8", errors="replace")
        rows.extend(parse_log_text(text, log_path.name))

    writer = csv.DictWriter(sys.stdout, fieldnames=FIELDNAMES, lineterminator="\n")
    writer.writeheader()
    writer.writerows(rows)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/profile_kem_one_amd.sh new file mode 100644 index 000000000..61c92ad36 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/profile_kem_one_amd.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "$0")" + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +exe="${1:-kyber768_amd}" +batch="${2:-8192}" +iters="${3:-3}" + +mkdir -p amd_results/profile + +if [[ ! -x "./${exe}" ]]; then + echo "error: ./${exe} not found or not executable" >&2 + exit 1 +fi + +plain_log="amd_results/profile/${exe}_b${batch}_profile.log" +rocprof_dir="amd_results/profile/${exe}_b${batch}_rocprof" + +echo "[profile] app-level pipeline profile: ${exe} batch=${batch}" +stdbuf -oL -eL "./${exe}" --batch "${batch}" --n-ops "${iters}" \ + --no-correctness --pipeline --profile-pipeline \ + 2>&1 | tee "${plain_log}" + +if command -v rocprofv3 >/dev/null 2>&1; then + echo "[profile] rocprofv3 output: ${rocprof_dir}" + rm -rf "${rocprof_dir}" + mkdir -p "${rocprof_dir}" + rocprofv3 --output-directory "${rocprof_dir}" --timestamp on -- \ + "./${exe}" --batch "${batch}" --n-ops 1 --no-correctness --pipeline \ + 2>&1 | tee "amd_results/profile/${exe}_b${batch}_rocprof.log" +else + echo "[profile] rocprofv3 not found; skipped ROCm trace" +fi diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_all_bounds_probe_amd.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_all_bounds_probe_amd.sh new file mode 100644 index 000000000..938fdd59b --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/run_kem_all_bounds_probe_amd.sh @@ -0,0 +1,207 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname 
"$0")" + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +N_OPS=${N_OPS:-20} +REPEATS=${REPEATS:-2} +DO_CORRECTNESS=${DO_CORRECTNESS:-1} +KYBER_BATCH=${KYBER_BATCH:-32768} +AIGIS_BATCH=${AIGIS_BATCH:-65536} +KEM_KEYGEN_TPB=${KEM_KEYGEN_TPB:-256} +KEM_ENCAPS_TPB=${KEM_ENCAPS_TPB:-128} +KEM_DECAPS_TPB=${KEM_DECAPS_TPB:-128} +OPT_LEVEL=${OPT_LEVEL:-O2} + +stamp="$(date +%Y%m%d_%H%M%S)" +out_dir="amd_results/all_bounds_probe_${stamp}" +mkdir -p "${out_dir}" + +summary="${out_dir}/all_bounds_probe_raw.csv" +avg_summary="${out_dir}/all_bounds_probe_avg.csv" +best_summary="${out_dir}/all_bounds_probe_best.csv" + +echo "target,algorithm_group,tag,repeat,batch,n_ops,opt,kg_tpb,enc_tpb,dec_tpb,keypair_bounds,encaps_bounds,decaps_bounds,keygen_ops_s,encaps_ops_s,decaps_ops_s,status,log" > "${summary}" + +extract_metric() { + local label="$1" + local log="$2" + grep -E " ${label}:" "${log}" \ + | tail -1 \ + | grep -oE '[0-9]+ ops/sec' \ + | tail -1 \ + | awk '{print $1}' +} + +group_for_target() { + local target="$1" + if [[ "${target}" == kyber* ]]; then + echo "Kyber" + else + echo "Aigis-enc" + fi +} + +batch_for_target() { + local target="$1" + if [[ "${target}" == kyber* ]]; then + echo "${KYBER_BATCH}" + else + echo "${AIGIS_BATCH}" + fi +} + +run_candidate() { + local target="$1" + local kb="$2" + local eb="$3" + local db="$4" + local tag="bounds${kb}${eb}${db}" + local group + local batch + group="$(group_for_target "${target}")" + batch="$(batch_for_target "${target}")" + + for rep in $(seq 1 "${REPEATS}"); do + local log="${out_dir}/${target}_${tag}_r${rep}.log" + local status="PASS" + { + echo "========== ${target} ${tag} repeat=${rep}/${REPEATS} ==========" + echo "timestamp=$(date '+%Y-%m-%d %H:%M:%S')" + echo "batch=${batch} n_ops=${N_OPS} opt=${OPT_LEVEL}" + echo "tpb=${KEM_KEYGEN_TPB}/${KEM_ENCAPS_TPB}/${KEM_DECAPS_TPB} bounds=${kb}/${eb}/${db}" + OPT_LEVEL="${OPT_LEVEL}" \ + 
KEM_KEYGEN_TPB="${KEM_KEYGEN_TPB}" KEM_ENCAPS_TPB="${KEM_ENCAPS_TPB}" KEM_DECAPS_TPB="${KEM_DECAPS_TPB}" \ + KEM_KEYPAIR_LAUNCH_BOUNDS="${kb}" KEM_ENCAPS_LAUNCH_BOUNDS="${eb}" KEM_DECAPS_LAUNCH_BOUNDS="${db}" \ + bash build_hip.sh "${target}" + + if [[ "${DO_CORRECTNESS}" == "1" ]]; then + "./${target}_amd" --batch 128 --n-ops 1 + fi + + "./${target}_amd" --batch "${batch}" --n-ops "${N_OPS}" --no-correctness + } > "${log}" 2>&1 || status="FAIL" + + local kg_ops="" + local enc_ops="" + local dec_ops="" + if [[ "${status}" == "PASS" ]]; then + kg_ops="$(extract_metric Keygen "${log}" || true)" + enc_ops="$(extract_metric Encaps "${log}" || true)" + dec_ops="$(extract_metric Decaps "${log}" || true)" + fi + + echo "${target},${group},${tag},${rep},${batch},${N_OPS},${OPT_LEVEL},${KEM_KEYGEN_TPB},${KEM_ENCAPS_TPB},${KEM_DECAPS_TPB},${kb},${eb},${db},${kg_ops},${enc_ops},${dec_ops},${status},${log}" | tee -a "${summary}" + done +} + +echo "[all-bounds] out=${out_dir}" +echo "[all-bounds] n_ops=${N_OPS} repeats=${REPEATS} correctness=${DO_CORRECTNESS}" +echo "[all-bounds] Kyber batch=${KYBER_BATCH}, Aigis batch=${AIGIS_BATCH}" + +targets=(kyber512 kyber768 kyber1024 aigisenc1 aigisenc2 aigisenc3 aigisenc4) +bounds=(000 001 010 011 100 101 110 111) + +for target in "${targets[@]}"; do + echo + echo "[target] ${target}" + for b in "${bounds[@]}"; do + run_candidate "${target}" "${b:0:1}" "${b:1:1}" "${b:2:1}" + done +done + +python3 - "${summary}" "${avg_summary}" "${best_summary}" <<'PY' +import csv +import sys +from collections import defaultdict + +raw_path, avg_path, best_path = sys.argv[1:4] +rows = [] +with open(raw_path, newline="", encoding="utf-8") as f: + for row in csv.DictReader(f): + if row.get("status") != "PASS": + continue + try: + row["_keygen"] = float(row["keygen_ops_s"]) + row["_encaps"] = float(row["encaps_ops_s"]) + row["_decaps"] = float(row["decaps_ops_s"]) + except (TypeError, ValueError): + continue + rows.append(row) + +groups = 
defaultdict(list) +for row in rows: + key = ( + row["target"], row["algorithm_group"], row["tag"], row["batch"], row["n_ops"], + row["opt"], row["kg_tpb"], row["enc_tpb"], row["dec_tpb"], + row["keypair_bounds"], row["encaps_bounds"], row["decaps_bounds"], + ) + groups[key].append(row) + +avg_rows = [] +for key, vals in groups.items(): + target, group, tag, batch, n_ops, opt, kg_tpb, enc_tpb, dec_tpb, kb, eb, db = key + count = len(vals) + kg = sum(v["_keygen"] for v in vals) / count + enc = sum(v["_encaps"] for v in vals) / count + dec = sum(v["_decaps"] for v in vals) / count + # Balanced score avoids selecting configs that improve one operation while + # badly hurting another. Encaps still gets a small extra weight because the + # current optimization signal is launch-bounds-sensitive encaps. + score = 0.30 * kg + 0.40 * enc + 0.30 * dec + avg_rows.append({ + "target": target, + "algorithm_group": group, + "tag": tag, + "batch": batch, + "n_ops": n_ops, + "opt": opt, + "kg_tpb": kg_tpb, + "enc_tpb": enc_tpb, + "dec_tpb": dec_tpb, + "keypair_bounds": kb, + "encaps_bounds": eb, + "decaps_bounds": db, + "repeats": count, + "keygen_avg_ops_s": round(kg), + "encaps_avg_ops_s": round(enc), + "decaps_avg_ops_s": round(dec), + "balanced_score": round(score), + }) + +fieldnames = [ + "target", "algorithm_group", "tag", "batch", "n_ops", "opt", + "kg_tpb", "enc_tpb", "dec_tpb", "keypair_bounds", "encaps_bounds", + "decaps_bounds", "repeats", "keygen_avg_ops_s", "encaps_avg_ops_s", + "decaps_avg_ops_s", "balanced_score", +] + +avg_rows.sort(key=lambda r: (r["target"], -int(r["balanced_score"]))) +with open(avg_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + w.writerows(avg_rows) + +best_by_target = {} +for row in avg_rows: + target = row["target"] + if target not in best_by_target or int(row["balanced_score"]) > int(best_by_target[target]["balanced_score"]): + best_by_target[target] = row + +with 
#!/usr/bin/env bash
# Build, benchmark and (when rocprofv3 is installed) trace every KEM target
# twice: once with the "baseline" launch-bounds setting (100) and once with the
# per-target "tuned" setting, then hand the result directory to
# summarize_profile_compare.py.
#
# All knobs below can be overridden from the environment.
set -euo pipefail

# Run relative to the script directory so build_hip.sh and the binaries resolve.
cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

N_OPS=${N_OPS:-20}                    # --n-ops for the timed benchmark run
PROFILE_N_OPS=${PROFILE_N_OPS:-1}     # --n-ops for the rocprofv3 traced run
DO_CORRECTNESS=${DO_CORRECTNESS:-0}   # "1" adds a small correctness run after the build
KYBER_BATCH=${KYBER_BATCH:-32768}
AIGIS_BATCH=${AIGIS_BATCH:-65536}
KEM_KEYGEN_TPB=${KEM_KEYGEN_TPB:-256}
KEM_ENCAPS_TPB=${KEM_ENCAPS_TPB:-128}
KEM_DECAPS_TPB=${KEM_DECAPS_TPB:-128}
OPT_LEVEL=${OPT_LEVEL:-O2}

stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/profile_compare_${stamp}"
mkdir -p "${out_dir}"

# One row per (target, config) run; consumed by summarize_profile_compare.py.
runs_csv="${out_dir}/profile_compare_runs.csv"
echo "target,config,bounds,batch,n_ops,opt,kg_tpb,enc_tpb,dec_tpb,keypair_bounds,encaps_bounds,decaps_bounds,keygen_ops_s,encaps_ops_s,decaps_ops_s,status,run_dir" > "${runs_csv}"

# extract_metric LABEL LOG
# Print the integer that precedes " ops/sec" on the last line of LOG matching
# " LABEL:". Returns non-zero (because of pipefail) when nothing matches, so
# callers append "|| true".
extract_metric() {
  local label="$1"
  local log="$2"
  grep -E " ${label}:" "${log}" \
    | tail -1 \
    | grep -oE '[0-9]+ ops/sec' \
    | tail -1 \
    | awk '{print $1}'
}

# batch_for_target TARGET
# Print KYBER_BATCH for targets whose name starts with "kyber", else AIGIS_BATCH.
batch_for_target() {
  local target="$1"
  if [[ "${target}" == kyber* ]]; then
    echo "${KYBER_BATCH}"
  else
    echo "${AIGIS_BATCH}"
  fi
}

# tuned_bounds_for_target TARGET
# Print the three-digit launch-bounds string for TARGET. The digits are split
# by run_one into keypair / encaps / decaps flags, in that order.
# Unknown targets fall back to "100", the same value used for the baseline run.
tuned_bounds_for_target() {
  case "$1" in
    kyber512) echo "001" ;;
    kyber768) echo "010" ;;
    kyber1024) echo "110" ;;
    aigisenc1) echo "101" ;;
    aigisenc2) echo "110" ;;
    aigisenc3) echo "101" ;;
    aigisenc4) echo "101" ;;
    *) echo "100" ;;
  esac
}

# run_one TARGET CONFIG BOUNDS
# Build TARGET with the given launch-bounds digits, optionally run a
# correctness pass, benchmark it, trace it with rocprofv3 when available, and
# append one row to runs_csv. Each step only runs while status is still PASS,
# so a failed build never benchmarks a binary left over from an earlier run.
run_one() {
  local target="$1"
  local config="$2"
  local bounds="$3"
  # One character per operation: keypair, encaps, decaps.
  local kb="${bounds:0:1}"
  local eb="${bounds:1:1}"
  local db="${bounds:2:1}"
  local batch
  batch="$(batch_for_target "${target}")"
  local run_name="${target}_${config}_bounds${bounds}"
  local run_dir="${out_dir}/${run_name}"
  mkdir -p "${run_dir}/rocprofv3"
  local status="PASS"

  {
    echo "target=${target}"
    echo "config=${config}"
    echo "bounds=${bounds}"
    echo "batch=${batch}"
    echo "n_ops=${N_OPS}"
    echo "profile_n_ops=${PROFILE_N_OPS}"
    echo "opt=${OPT_LEVEL}"
    echo "tpb=${KEM_KEYGEN_TPB}/${KEM_ENCAPS_TPB}/${KEM_DECAPS_TPB}"
    echo "timestamp=$(date '+%Y-%m-%d %H:%M:%S')"
    hipcc --version || true
  } > "${run_dir}/metadata.txt" 2>&1

  echo
  echo "[run] ${target} ${config} bounds=${bounds} batch=${batch}"

  # The build reads its configuration from these environment variables.
  OPT_LEVEL="${OPT_LEVEL}" \
  KEM_KEYGEN_TPB="${KEM_KEYGEN_TPB}" KEM_ENCAPS_TPB="${KEM_ENCAPS_TPB}" KEM_DECAPS_TPB="${KEM_DECAPS_TPB}" \
  KEM_KEYPAIR_LAUNCH_BOUNDS="${kb}" KEM_ENCAPS_LAUNCH_BOUNDS="${eb}" KEM_DECAPS_LAUNCH_BOUNDS="${db}" \
    bash build_hip.sh "${target}" > "${run_dir}/build.log" 2>&1 || status="FAIL"

  if [[ "${status}" == "PASS" && "${DO_CORRECTNESS}" == "1" ]]; then
    "./${target}_amd" --batch 128 --n-ops 1 > "${run_dir}/correctness.log" 2>&1 || status="FAIL"
  fi

  if [[ "${status}" == "PASS" ]]; then
    "./${target}_amd" --batch "${batch}" --n-ops "${N_OPS}" --no-correctness \
      > "${run_dir}/benchmark.log" 2>&1 || status="FAIL"
  fi

  # Metrics stay empty for failed runs or when the log has no matching line.
  local kg_ops=""
  local enc_ops=""
  local dec_ops=""
  if [[ "${status}" == "PASS" ]]; then
    kg_ops="$(extract_metric Keygen "${run_dir}/benchmark.log" || true)"
    enc_ops="$(extract_metric Encaps "${run_dir}/benchmark.log" || true)"
    dec_ops="$(extract_metric Decaps "${run_dir}/benchmark.log" || true)"

    # Tracing is best-effort: a rocprofv3 failure does not change the status.
    if command -v rocprofv3 >/dev/null 2>&1; then
      rocprofv3 \
        --kernel-trace \
        --hip-trace \
        --output-format csv \
        --output-directory "${run_dir}/rocprofv3" \
        -- \
        "./${target}_amd" --batch "${batch}" --n-ops "${PROFILE_N_OPS}" --no-correctness \
        > "${run_dir}/rocprofv3.log" 2>&1 || true
    else
      echo "rocprofv3 not found" > "${run_dir}/rocprofv3.log"
    fi
  fi

  # run_dir column holds the directory name relative to out_dir.
  echo "${target},${config},${bounds},${batch},${N_OPS},${OPT_LEVEL},${KEM_KEYGEN_TPB},${KEM_ENCAPS_TPB},${KEM_DECAPS_TPB},${kb},${eb},${db},${kg_ops},${enc_ops},${dec_ops},${status},${run_name}" | tee -a "${runs_csv}"
}

targets=(kyber512 kyber768 kyber1024 aigisenc1 aigisenc2 aigisenc3 aigisenc4)

echo "[profile-compare] out=${out_dir}"
echo "[profile-compare] N_OPS=${N_OPS} PROFILE_N_OPS=${PROFILE_N_OPS} DO_CORRECTNESS=${DO_CORRECTNESS}"

for target in "${targets[@]}"; do
  run_one "${target}" baseline 100
  run_one "${target}" tuned "$(tuned_bounds_for_target "${target}")"
done

# With pipefail, a summarizer failure stops the script here.
python3 summarize_profile_compare.py "${out_dir}" | tee "${out_dir}/summarize_profile_compare.log"

echo
echo "[done] ${out_dir}"
echo "[show] runs:"
cat "${runs_csv}"
echo
echo "[show] key kernel compare:"
cat "${out_dir}/key_kernel_compare.csv"
#!/usr/bin/env bash
# Re-run a short list of hand-picked build configurations REPEATS times each to
# check that the throughput differences seen during tuning are repeatable.
#
# Usage: run_kem_confirm_amd.sh [TARGET]   (default target: kyber768)
# Environment overrides: BATCH, N_OPS, REPEATS.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

TARGET=${1:-kyber768}
BATCH=${BATCH:-32768}
N_OPS=${N_OPS:-50}
REPEATS=${REPEATS:-3}

stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/confirm_${TARGET}_${stamp}"
mkdir -p "${out_dir}"

summary="${out_dir}/confirm_summary.csv"
echo "target,tag,repeat,batch,n_ops,opt,kg_tpb,enc_tpb,dec_tpb,keypair_bounds,encaps_bounds,decaps_bounds,keygen_ops_s,encaps_ops_s,decaps_ops_s,status,log" > "${summary}"

# extract_metric LABEL LOG
# Print the integer that precedes " ops/sec" on the last line of LOG matching
# " LABEL:". Returns non-zero (pipefail) when nothing matches.
extract_metric() {
  local label="$1"
  local log="$2"
  grep -E " ${label}:" "${log}" \
    | tail -1 \
    | grep -oE '[0-9]+ ops/sec' \
    | tail -1 \
    | awk '{print $1}'
}

# run_candidate TAG OPT KG_TPB ENC_TPB DEC_TPB KEYPAIR_BOUNDS ENCAPS_BOUNDS DECAPS_BOUNDS
# Rebuild TARGET with the given settings and run correctness + benchmark
# REPEATS times, appending one CSV row per repeat to the summary.
run_candidate() {
  local tag="$1"
  local opt="$2"
  local kg="$3"
  local enc="$4"
  local dec="$5"
  local kb="$6"
  local eb="$7"
  local db="$8"
  local rep

  for rep in $(seq 1 "${REPEATS}"); do
    local log="${out_dir}/${TARGET}_${tag}_r${rep}.log"
    local status="PASS"
    {
      echo "========== ${TARGET} ${tag} repeat=${rep}/${REPEATS} =========="
      echo "timestamp=$(date '+%Y-%m-%d %H:%M:%S')"
      echo "batch=${BATCH} n_ops=${N_OPS}"
      echo "OPT_LEVEL=${opt} KEM_KEYGEN_TPB=${kg} KEM_ENCAPS_TPB=${enc} KEM_DECAPS_TPB=${dec}"
      echo "bounds=${kb}/${eb}/${db}"
      # errexit is ignored inside a group that is the left operand of "||",
      # so the steps are chained explicitly. Without the "&&" a failed build
      # would be skipped over and a binary left by a previous candidate would
      # be benchmarked and recorded as PASS.
      OPT_LEVEL="${opt}" \
      KEM_KEYGEN_TPB="${kg}" KEM_ENCAPS_TPB="${enc}" KEM_DECAPS_TPB="${dec}" \
      KEM_KEYPAIR_LAUNCH_BOUNDS="${kb}" KEM_ENCAPS_LAUNCH_BOUNDS="${eb}" KEM_DECAPS_LAUNCH_BOUNDS="${db}" \
        bash build_hip.sh "${TARGET}" \
        && "./${TARGET}_amd" --batch 128 --n-ops 1 \
        && "./${TARGET}_amd" --batch "${BATCH}" --n-ops "${N_OPS}" --no-correctness
    } > "${log}" 2>&1 || status="FAIL"

    # Metrics stay empty for failed repeats or unparsable logs.
    local kg_ops=""
    local enc_ops=""
    local dec_ops=""
    if [[ "${status}" == "PASS" ]]; then
      kg_ops="$(extract_metric Keygen "${log}" || true)"
      enc_ops="$(extract_metric Encaps "${log}" || true)"
      dec_ops="$(extract_metric Decaps "${log}" || true)"
    fi

    echo "${TARGET},${tag},${rep},${BATCH},${N_OPS},${opt},${kg},${enc},${dec},${kb},${eb},${db},${kg_ops},${enc_ops},${dec_ops},${status},${log}" | tee -a "${summary}"
  done
}

echo "[confirm] target=${TARGET} batch=${BATCH} n_ops=${N_OPS} repeats=${REPEATS} out=${out_dir}"

# Current stable baseline from the previous final report.
run_candidate baseline_o2_256_128_128_b100 O2 256 128 128 1 0 0

# Best balanced candidate from the first Kyber-768 tuning pass:
# encaps launch bounds improves encaps while keygen/decaps stay close to baseline.
run_candidate balanced_encbounds_o2_b010 O2 256 128 128 0 1 0

# Slightly higher encaps candidate in the first pass, kept separate to check
# whether keypair launch bounds changes repeat-to-repeat stability.
run_candidate encbest_o2_b110 O2 256 128 128 1 1 0

# Best keygen candidate observed in the first pass.
run_candidate keygenbest_o3_b100 O3 256 128 128 1 0 0

# Best decaps candidate observed in the first pass.
run_candidate decbest_o3_kg512_b100 O3 512 128 128 1 0 0

echo
echo "[done] summary=${summary}"
echo "[hint] show summary:"
echo "cat ${summary}"
echo "[hint] rank encaps:"
# Column 14 of the summary is encaps_ops_s.
echo "sort -t, -k14,14nr ${summary} | head"
#!/usr/bin/env bash
# Produce the final KEM throughput table: record the tool environment, then
# build and benchmark every target with its selected launch-bounds setting,
# and extract the headline lines into kem_final_extract.txt.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/final_report_${stamp}"
mkdir -p "${out_dir}"

# Environment snapshot; every probe is best-effort ("|| true") so a missing
# tool does not abort the report.
meta="${out_dir}/environment.txt"
{
  echo "timestamp=${stamp}"
  echo "pwd=$(pwd)"
  echo
  echo "== hipcc --version =="
  hipcc --version || true
  echo
  echo "== tools =="
  which hipcc || true
  which rocprofv3 || true
  which rocm-smi || true
  echo
  echo "== rocm-smi =="
  rocm-smi --showproductname --showdriverversion --showvbios --showmeminfo vram || true
} 2>&1 | tee "${meta}"

# Combined log of all runs; truncated here, appended to by run_one.
summary="${out_dir}/kem_final_summary.log"
: > "${summary}"

# run_one TARGET BATCH ITERS KG_TPB ENC_TPB DEC_TPB [KEYPAIR_BOUNDS [ENCAPS_BOUNDS [DECAPS_BOUNDS]]]
# Build TARGET and benchmark it once. Bounds default to 1/0/0 when omitted.
# Output goes to the terminal, a per-run log, and the combined summary.
# The group is the head of a pipeline (not a "||" list), so with
# "set -e -o pipefail" a failing build or benchmark stops the script.
run_one() {
  local target="$1"
  local batch="$2"
  local iters="$3"
  local kg_tpb="$4"
  local enc_tpb="$5"
  local dec_tpb="$6"
  local keypair_bounds="${7:-1}"
  local encaps_bounds="${8:-0}"
  local decaps_bounds="${9:-0}"
  local log="${out_dir}/${target}_b${batch}_n${iters}_tpb${kg_tpb}_${enc_tpb}_${dec_tpb}_bounds${keypair_bounds}${encaps_bounds}${decaps_bounds}.log"

  {
    echo "========== ${target} =========="
    echo "timestamp=$(date '+%Y-%m-%d %H:%M:%S')"
    echo "batch=${batch} n_ops=${iters} KEM_KEYGEN_TPB=${kg_tpb} KEM_ENCAPS_TPB=${enc_tpb} KEM_DECAPS_TPB=${dec_tpb}"
    echo "KEM_KEYPAIR_LAUNCH_BOUNDS=${keypair_bounds} KEM_ENCAPS_LAUNCH_BOUNDS=${encaps_bounds} KEM_DECAPS_LAUNCH_BOUNDS=${decaps_bounds}"
    KEM_KEYGEN_TPB="${kg_tpb}" KEM_ENCAPS_TPB="${enc_tpb}" KEM_DECAPS_TPB="${dec_tpb}" \
    KEM_KEYPAIR_LAUNCH_BOUNDS="${keypair_bounds}" KEM_ENCAPS_LAUNCH_BOUNDS="${encaps_bounds}" KEM_DECAPS_LAUNCH_BOUNDS="${decaps_bounds}" \
      bash build_hip.sh "${target}"
    "./${target}_amd" --batch "${batch}" --n-ops "${iters}" --no-correctness
    echo
  } 2>&1 | tee "${log}" | tee -a "${summary}"
}

# Stable final KEM table. Kyber uses batch 32768, Aigis-enc uses batch 65536.
# Bounds are selected from the 2026-06-14 all-parameter bounds probe using the
# balanced score: 0.30*keygen + 0.40*encaps + 0.30*decaps.
run_one kyber512 32768 20 256 128 128 0 0 1
run_one kyber768 32768 20 256 128 128 0 1 0
run_one kyber1024 32768 20 256 128 128 1 1 0

run_one aigisenc1 65536 20 256 128 128 1 0 1
run_one aigisenc2 65536 20 256 128 128 1 1 0
run_one aigisenc3 65536 20 256 128 128 1 0 1
run_one aigisenc4 65536 20 256 128 128 1 0 1

# Keep only the headline lines of each run for the report.
echo "[extract] ${out_dir}/kem_final_extract.txt"
grep -E "Algorithm:|Keygen:|Encaps:|Decaps:" "${summary}" | tee "${out_dir}/kem_final_extract.txt"

echo "[done] final report directory: ${out_dir}"
#!/usr/bin/env bash
# Build one KEM target, benchmark it while sampling rocm-smi in the background,
# then collect a rocprofv3 kernel/HIP trace of a single pipeline iteration.
#
# Usage: run_kem_resource_profile_amd.sh [TARGET [BATCH [ITERS]]]
# Threads-per-block come from KEM_KEYGEN_TPB / KEM_ENCAPS_TPB / KEM_DECAPS_TPB.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

target="${1:-kyber768}"
batch="${2:-32768}"
iters="${3:-200}"
kg_tpb="${KEM_KEYGEN_TPB:-256}"
enc_tpb="${KEM_ENCAPS_TPB:-128}"
dec_tpb="${KEM_DECAPS_TPB:-128}"
stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/resource_profile_${target}_${stamp}"
mkdir -p "${out_dir}"

{
  echo "timestamp=${stamp}"
  echo "target=${target}"
  echo "batch=${batch}"
  echo "iters=${iters}"
  echo "KEM_KEYGEN_TPB=${kg_tpb}"
  echo "KEM_ENCAPS_TPB=${enc_tpb}"
  echo "KEM_DECAPS_TPB=${dec_tpb}"
  hipcc --version || true
} 2>&1 | tee "${out_dir}/metadata.txt"

# With pipefail a build failure stops the script before any measurement.
KEM_KEYGEN_TPB="${kg_tpb}" KEM_ENCAPS_TPB="${enc_tpb}" KEM_DECAPS_TPB="${dec_tpb}" \
  bash build_hip.sh "${target}" 2>&1 | tee "${out_dir}/build.log"

# Background sampler: up to 120 rocm-smi snapshots, 0.2 s apart, written to
# rocm_smi_during.log while the benchmark below runs.
# NOTE(review): "%3N" (milliseconds) is a GNU date extension.
(
  for i in $(seq 1 120); do
    echo "===== sample ${i} $(date '+%H:%M:%S.%3N') ====="
    rocm-smi --showuse --showmemuse --showtemp --showpower
    sleep 0.2
  done
) > "${out_dir}/rocm_smi_during.log" &
smi_pid=$!

"./${target}_amd" --batch "${batch}" --n-ops "${iters}" --no-correctness \
  2>&1 | tee "${out_dir}/benchmark.log"

# The sampler is not stopped early; this waits for all of its samples.
wait "${smi_pid}" || true

# Trace a single iteration; profiling is best-effort.
mkdir -p "${out_dir}/rocprofv3"
rocprofv3 \
  --kernel-trace \
  --hip-trace \
  --output-format csv \
  --output-directory "${out_dir}/rocprofv3" \
  -- \
  "./${target}_amd" --batch "${batch}" --n-ops 1 --no-correctness --pipeline \
  2>&1 | tee "${out_dir}/rocprofv3.log" || true

python3 summarize_rocprofv3_trace.py "${out_dir}/rocprofv3" \
  > "${out_dir}/rocprofv3_summary.csv" || true

# Keep only the GPU[0] utilisation / memory / power / temperature lines.
grep -E "GPU\\[0\\].*(GPU use|VRAM|Power|Temperature)" "${out_dir}/rocm_smi_during.log" \
  > "${out_dir}/rocm_smi_gpu0_extract.log" || true

echo "[done] resource profile directory: ${out_dir}"
#!/usr/bin/env bash
# Sweep build-time knobs (optimisation level, threads per block, launch bounds,
# pipeline parameters) for one KEM target and record throughput per config.
#
# Usage: run_kem_tune_amd.sh [TARGET]   (default target: kyber768)
# Environment overrides: BATCH, N_OPS, DO_CORRECTNESS.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

TARGET=${1:-kyber768}
BATCH=${BATCH:-32768}
N_OPS=${N_OPS:-20}
DO_CORRECTNESS=${DO_CORRECTNESS:-1}

stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/tune_${TARGET}_${stamp}"
mkdir -p "${out_dir}"

summary="${out_dir}/tune_summary.csv"
echo "target,batch,n_ops,opt,kg_tpb,enc_tpb,dec_tpb,keypair_bounds,encaps_bounds,decaps_bounds,wp_kg_warps,pack_tpb,keygen_ops_s,encaps_ops_s,decaps_ops_s,status,log" > "${summary}"

# extract_metric LABEL LOG
# Print the integer that precedes " ops/sec" on the last line of LOG matching
# " LABEL:". Returns non-zero (pipefail) when nothing matches.
extract_metric() {
  local label="$1"
  local log="$2"
  grep -E " ${label}:" "${log}" \
    | tail -1 \
    | grep -oE '[0-9]+ ops/sec' \
    | tail -1 \
    | awk '{print $1}'
}

# run_config OPT KG_TPB ENC_TPB DEC_TPB KEYPAIR_BOUNDS ENCAPS_BOUNDS DECAPS_BOUNDS WP_WARPS PACK_TPB
# Build TARGET with one configuration, optionally run the correctness pass,
# benchmark it, and append one CSV row to the summary.
run_config() {
  local opt="$1"
  local kg="$2"
  local enc="$3"
  local dec="$4"
  local kb="$5"
  local eb="$6"
  local db="$7"
  local wp="$8"
  local pack="$9"
  local tag="opt${opt}_kg${kg}_enc${enc}_dec${dec}_b${kb}${eb}${db}_wp${wp}_pack${pack}"
  local log="${out_dir}/${TARGET}_${tag}.log"
  local status="PASS"

  {
    echo "========== ${TARGET} ${tag} =========="
    echo "timestamp=$(date '+%Y-%m-%d %H:%M:%S')"
    echo "batch=${BATCH} n_ops=${N_OPS}"
    echo "OPT_LEVEL=${opt} KEM_KEYGEN_TPB=${kg} KEM_ENCAPS_TPB=${enc} KEM_DECAPS_TPB=${dec}"
    echo "bounds=${kb}/${eb}/${db} WP_KG_WARPS_BLOCK=${wp} KEM_PACK_TPB=${pack}"
    # errexit is ignored inside a group that is the left operand of "||",
    # so every step is chained with "&&". Otherwise a failed build or failed
    # correctness run would be skipped over and the binary from the previous
    # configuration would be benchmarked and recorded as PASS.
    OPT_LEVEL="${opt}" \
    KEM_KEYGEN_TPB="${kg}" KEM_ENCAPS_TPB="${enc}" KEM_DECAPS_TPB="${dec}" \
    KEM_KEYPAIR_LAUNCH_BOUNDS="${kb}" KEM_ENCAPS_LAUNCH_BOUNDS="${eb}" KEM_DECAPS_LAUNCH_BOUNDS="${db}" \
    WP_KG_WARPS_BLOCK="${wp}" KEM_PACK_TPB="${pack}" \
      bash build_hip.sh "${TARGET}" \
      && { [[ "${DO_CORRECTNESS}" != "1" ]] || "./${TARGET}_amd" --batch 128 --n-ops 1; } \
      && "./${TARGET}_amd" --batch "${BATCH}" --n-ops "${N_OPS}" --no-correctness
  } > "${log}" 2>&1 || status="FAIL"

  # Metrics stay empty for failed configurations or unparsable logs.
  local kg_ops=""
  local enc_ops=""
  local dec_ops=""
  if [[ "${status}" == "PASS" ]]; then
    kg_ops="$(extract_metric Keygen "${log}" || true)"
    enc_ops="$(extract_metric Encaps "${log}" || true)"
    dec_ops="$(extract_metric Decaps "${log}" || true)"
  fi

  echo "${TARGET},${BATCH},${N_OPS},${opt},${kg},${enc},${dec},${kb},${eb},${db},${wp},${pack},${kg_ops},${enc_ops},${dec_ops},${status},${log}" | tee -a "${summary}"
}

echo "[tune] target=${TARGET} batch=${BATCH} n_ops=${N_OPS} out=${out_dir}"

# Serial-path tuning. This is the current final-report path and therefore the
# first optimization surface to lock down.
for opt in O2 O3; do
  for kg in 128 256 512; do
    for enc in 64 128; do
      for dec in 64 128; do
        run_config "${opt}" "${kg}" "${enc}" "${dec}" 1 0 0 4 128
      done
    done
  done
done

# Targeted launch-bounds checks for the current best neighborhood.
for kb in 0 1; do
  for eb in 0 1; do
    for db in 0 1; do
      run_config O2 256 128 128 "${kb}" "${eb}" "${db}" 4 128
    done
  done
done

# Pipeline-only knobs. These runs still print serial first, but the added
# pipeline profile makes it easy to see whether the sampling path improves.
pipeline_log="${out_dir}/pipeline_candidates.log"
: > "${pipeline_log}"
for wp in 2 4 8; do
  for pack in 64 128 256; do
    tag="pipeline_wp${wp}_pack${pack}"
    log="${out_dir}/${TARGET}_${tag}.log"
    {
      echo "========== ${TARGET} ${tag} =========="
      # Best-effort run, but only profile when the build succeeded so a stale
      # binary is never attributed to this wp/pack combination.
      OPT_LEVEL=O2 KEM_KEYGEN_TPB=256 KEM_ENCAPS_TPB=128 KEM_DECAPS_TPB=128 \
      WP_KG_WARPS_BLOCK="${wp}" KEM_PACK_TPB="${pack}" \
        bash build_hip.sh "${TARGET}" \
        && "./${TARGET}_amd" --batch "${BATCH}" --n-ops 3 --no-correctness --pipeline --profile-pipeline
    } > "${log}" 2>&1 || true
    grep -E "Algorithm:|Pipeline profile:|Keygen:|Encaps:|Decaps:" "${log}" >> "${pipeline_log}" || true
    echo >> "${pipeline_log}"
  done
done

echo
echo "[done] summary=${summary}"
echo "[done] pipeline=${pipeline_log}"
# Columns 13/14/15 of the summary are keygen/encaps/decaps ops per second.
echo "[hint] sort by keygen: sort -t, -k13,13nr ${summary} | head"
echo "[hint] sort by encaps: sort -t, -k14,14nr ${summary} | head"
echo "[hint] sort by decaps: sort -t, -k15,15nr ${summary} | head"
#!/usr/bin/env bash
# Exercise the available ROCm tooling against a few KEM targets: record which
# tools exist, benchmark each target with its tuned launch bounds while
# sampling rocm-smi, and optionally collect rocprofv3 sys-trace and PMC data.
#
# Environment overrides: TARGETS (space separated), N_OPS, PROFILE_N_OPS,
# ENABLE_SYS_TRACE, ENABLE_PMC, ENABLE_SMI, TOOL_TIMEOUT (seconds),
# KYBER_BATCH, AIGIS_BATCH, KEM_*_TPB, OPT_LEVEL.
set -euo pipefail

cd "$(dirname "$0")"

export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}"

TARGETS=${TARGETS:-"kyber768 kyber1024 aigisenc4"}
N_OPS=${N_OPS:-20}
PROFILE_N_OPS=${PROFILE_N_OPS:-1}
ENABLE_SYS_TRACE=${ENABLE_SYS_TRACE:-1}
ENABLE_PMC=${ENABLE_PMC:-0}
ENABLE_SMI=${ENABLE_SMI:-1}
TOOL_TIMEOUT=${TOOL_TIMEOUT:-120}
KYBER_BATCH=${KYBER_BATCH:-32768}
AIGIS_BATCH=${AIGIS_BATCH:-65536}
KEM_KEYGEN_TPB=${KEM_KEYGEN_TPB:-256}
KEM_ENCAPS_TPB=${KEM_ENCAPS_TPB:-128}
KEM_DECAPS_TPB=${KEM_DECAPS_TPB:-128}
OPT_LEVEL=${OPT_LEVEL:-O2}

stamp="$(date +%Y%m%d_%H%M%S)"
out_dir="amd_results/rocm_toolbox_${stamp}"
mkdir -p "${out_dir}"

# Tool and environment snapshot. Every probe is best-effort.
tool_log="${out_dir}/tool_discovery.txt"
{
  echo "timestamp=${stamp}"
  echo "pwd=$(pwd)"
  echo "TARGETS=${TARGETS}"
  echo "N_OPS=${N_OPS}"
  echo "PROFILE_N_OPS=${PROFILE_N_OPS}"
  echo "ENABLE_SYS_TRACE=${ENABLE_SYS_TRACE}"
  echo "ENABLE_PMC=${ENABLE_PMC}"
  echo "ENABLE_SMI=${ENABLE_SMI}"
  echo "TOOL_TIMEOUT=${TOOL_TIMEOUT}"
  echo
  echo "== command availability =="
  for t in hipcc rocprofv3 rocprof-compute rocm-smi rocminfo hipconfig llvm-objdump; do
    printf "%-18s" "${t}"
    # "command -v" prints nothing (not even a newline) for a missing tool;
    # emit a marker so each tool keeps its own line in the report.
    command -v "${t}" || echo "(not found)"
  done
  echo
  echo "== hipcc --version =="
  hipcc --version || true
  echo
  echo "== hipconfig =="
  hipconfig || true
  echo
  echo "== rocminfo head =="
  rocminfo 2>/dev/null | head -120 || true
  echo
  echo "== rocm-smi static =="
  rocm-smi --showproductname --showdriverversion --showvbios --showmeminfo vram --showclocks --showmaxpower || true
} > "${tool_log}" 2>&1

# Cache the counter list once; select_counters greps it per target.
if command -v rocprofv3 >/dev/null 2>&1; then
  rocprofv3 --list-avail > "${out_dir}/rocprofv3_list_avail.txt" 2>&1 || true
fi

# run_with_timeout SECONDS CMD [ARGS...]
# Run CMD under timeout(1) when it exists, otherwise run it unbounded.
run_with_timeout() {
  local seconds="$1"
  shift
  if command -v timeout >/dev/null 2>&1; then
    timeout "${seconds}" "$@"
  else
    "$@"
  fi
}

# batch_for_target TARGET
# Print KYBER_BATCH for targets whose name starts with "kyber", else AIGIS_BATCH.
batch_for_target() {
  local target="$1"
  if [[ "${target}" == kyber* ]]; then
    echo "${KYBER_BATCH}"
  else
    echo "${AIGIS_BATCH}"
  fi
}

# tuned_bounds_for_target TARGET
# Print the three-digit launch-bounds string (keypair, encaps, decaps) for
# TARGET; unknown targets fall back to "100".
tuned_bounds_for_target() {
  case "$1" in
    kyber512) echo "001" ;;
    kyber768) echo "010" ;;
    kyber1024) echo "110" ;;
    aigisenc1) echo "101" ;;
    aigisenc2) echo "110" ;;
    aigisenc3) echo "101" ;;
    aigisenc4) echo "101" ;;
    *) echo "100" ;;
  esac
}

# Counter names to request when the installed rocprofv3 advertises them.
candidate_counters=(
  SQ_WAVES
  GRBM_GUI_ACTIVE
  GPUBusy
  VALUUtilization
  VALUBusy
  SALUBusy
  MemUnitBusy
  MemUnitStalled
  FetchSize
  WriteSize
  FETCH_SIZE
  WRITE_SIZE
  LDSBankConflict
  CU_OCCUPANCY
  MeanOccupancyPerCU
  MeanOccupancyPerActiveCU
)

# select_counters
# Print the comma-separated subset of candidate_counters that appears as a
# whole word in the cached "rocprofv3 --list-avail" output. Prints nothing when
# the cache file is missing and an empty line when no candidate matches.
select_counters() {
  local list_file="${out_dir}/rocprofv3_list_avail.txt"
  local joined=""
  local c
  [[ -f "${list_file}" ]] || return 0
  # Build the list as a string rather than an intermediate array: expanding
  # an empty array under "set -u" aborts on bash versions before 4.4.
  for c in "${candidate_counters[@]}"; do
    if grep -qw "${c}" "${list_file}"; then
      joined="${joined:+${joined},}${c}"
    fi
  done
  echo "${joined}"
}

# extract_metric LABEL LOG
# Print the integer that precedes " ops/sec" on the last line of LOG matching
# " LABEL:". Returns non-zero (pipefail) when nothing matches.
extract_metric() {
  local label="$1"
  local log="$2"
  grep -E " ${label}:" "${log}" \
    | tail -1 \
    | grep -oE '[0-9]+ ops/sec' \
    | tail -1 \
    | awk '{print $1}'
}

runs_csv="${out_dir}/toolbox_runs.csv"
echo "target,bounds,batch,n_ops,keygen_ops_s,encaps_ops_s,decaps_ops_s,status,run_dir,counters" > "${runs_csv}"

# run_target TARGET
# Build TARGET with its tuned bounds, benchmark it (optionally sampling
# rocm-smi in the background), run the enabled rocprofv3 collections, and
# append one row to runs_csv.
run_target() {
  local target="$1"
  local bounds
  bounds="$(tuned_bounds_for_target "${target}")"
  local kb="${bounds:0:1}"
  local eb="${bounds:1:1}"
  local db="${bounds:2:1}"
  local batch
  batch="$(batch_for_target "${target}")"
  local run_dir="${out_dir}/${target}_bounds${bounds}"
  mkdir -p "${run_dir}"
  local status="PASS"

  echo
  echo "[toolbox] target=${target} bounds=${bounds} batch=${batch}"

  OPT_LEVEL="${OPT_LEVEL}" \
  KEM_KEYGEN_TPB="${KEM_KEYGEN_TPB}" KEM_ENCAPS_TPB="${KEM_ENCAPS_TPB}" KEM_DECAPS_TPB="${KEM_DECAPS_TPB}" \
  KEM_KEYPAIR_LAUNCH_BOUNDS="${kb}" KEM_ENCAPS_LAUNCH_BOUNDS="${eb}" KEM_DECAPS_LAUNCH_BOUNDS="${db}" \
    bash build_hip.sh "${target}" > "${run_dir}/build.log" 2>&1 || status="FAIL"

  if [[ "${status}" == "PASS" ]]; then
    local smi_pid=""
    if [[ "${ENABLE_SMI}" == "1" ]]; then
      # Background sampler: up to 80 rocm-smi snapshots, 0.2 s apart.
      (
        for i in $(seq 1 80); do
          echo "===== sample ${i} $(date '+%H:%M:%S.%3N') ====="
          rocm-smi --showuse --showmemuse --showtemp --showpower --showclocks
          sleep 0.2
        done
      ) > "${run_dir}/rocm_smi_during.log" &
      smi_pid=$!
    else
      echo "SMI sampling disabled." > "${run_dir}/rocm_smi_during.log"
    fi

    run_with_timeout "${TOOL_TIMEOUT}" "./${target}_amd" --batch "${batch}" --n-ops "${N_OPS}" --no-correctness \
      > "${run_dir}/benchmark.log" 2>&1 || status="FAIL"
    if [[ -n "${smi_pid}" ]]; then
      wait "${smi_pid}" || true
    fi
  fi

  # Metrics stay empty for failed runs or unparsable logs.
  local kg_ops=""
  local enc_ops=""
  local dec_ops=""
  if [[ "${status}" == "PASS" ]]; then
    kg_ops="$(extract_metric Keygen "${run_dir}/benchmark.log" || true)"
    enc_ops="$(extract_metric Encaps "${run_dir}/benchmark.log" || true)"
    dec_ops="$(extract_metric Decaps "${run_dir}/benchmark.log" || true)"
  fi

  if [[ "${status}" == "PASS" && -x "./${target}_amd" && "$(command -v rocprofv3 || true)" ]]; then
    mkdir -p "${run_dir}/sys_trace"
    if [[ "${ENABLE_SYS_TRACE}" == "1" ]]; then
      run_with_timeout "${TOOL_TIMEOUT}" rocprofv3 \
        --sys-trace \
        --output-format csv \
        --output-directory "${run_dir}/sys_trace" \
        -- \
        "./${target}_amd" --batch "${batch}" --n-ops "${PROFILE_N_OPS}" --no-correctness \
        > "${run_dir}/rocprofv3_sys_trace.log" 2>&1 || true
    else
      echo "sys-trace disabled." > "${run_dir}/rocprofv3_sys_trace.log"
    fi

    local counters
    counters="$(select_counters)"
    if [[ "${ENABLE_PMC}" == "1" && -n "${counters}" ]]; then
      echo "${counters}" > "${run_dir}/selected_counters.txt"
      mkdir -p "${run_dir}/pmc"
      run_with_timeout "${TOOL_TIMEOUT}" rocprofv3 \
        --pmc "${counters}" \
        --output-format csv \
        --output-directory "${run_dir}/pmc" \
        -- \
        "./${target}_amd" --batch "${batch}" --n-ops "${PROFILE_N_OPS}" --no-correctness \
        > "${run_dir}/rocprofv3_pmc.log" 2>&1 || true
    elif [[ "${ENABLE_PMC}" != "1" ]]; then
      echo "PMC disabled because ENABLE_PMC=${ENABLE_PMC}. Enable explicitly with ENABLE_PMC=1." > "${run_dir}/selected_counters.txt"
    else
      echo "No candidate counters found in rocprofv3 --list-avail output." > "${run_dir}/selected_counters.txt"
    fi
  fi

  # The counter list is comma-separated on disk; written verbatim it would
  # spill into extra CSV columns, so the runs table uses ";" inside the field.
  local counters_field=""
  if [[ -f "${run_dir}/selected_counters.txt" ]]; then
    counters_field="$(tr ',' ';' < "${run_dir}/selected_counters.txt")"
  fi

  echo "${target},${bounds},${batch},${N_OPS},${kg_ops},${enc_ops},${dec_ops},${status},${target}_bounds${bounds},${counters_field}" | tee -a "${runs_csv}"
}

# TARGETS is intentionally unquoted: it is a space-separated list.
for target in ${TARGETS}; do
  run_target "${target}"
done

python3 summarize_rocm_pmc.py "${out_dir}" | tee "${out_dir}/summarize_rocm_pmc.log" || true

echo
echo "[done] ${out_dir}"
echo "[show] toolbox runs:"
cat "${runs_csv}"
echo
echo "[show] pmc summary:"
cat "${out_dir}/pmc_summary.csv" 2>/dev/null || true
#!/usr/bin/env python3
"""Report the best keygen/encaps/decaps throughput per algorithm.

Reads a benchmark summary CSV (columns used: ``status``, ``algorithm``,
``batch``, ``mode``, ``streams`` and ``<op>_ops_s``) and writes one CSV row per
algorithm to stdout with the highest value seen for each operation and the
configuration that produced it.
"""
import csv
import sys
from pathlib import Path

OPERATIONS = ("keygen", "encaps", "decaps")

FIELDNAMES = [
    "algorithm",
    "best_keygen_ops_s",
    "best_keygen_config",
    "best_encaps_ops_s",
    "best_encaps_config",
    "best_decaps_ops_s",
    "best_decaps_config",
]


def to_float(value: str) -> float:
    """Return *value* as a float, or 0.0 when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def best_by_algorithm(rows):
    """Return ``{algorithm: best-entry}`` for the rows whose status is PASS.

    Rows without an ``algorithm`` value are ignored. A value only replaces the
    current best when it is strictly greater, so the first configuration wins
    ties and an operation that never exceeds 0 keeps an empty config string.
    """
    best = {}
    for row in rows:
        if row.get("status") != "PASS":
            continue
        algo = row.get("algorithm", "")
        if not algo:
            continue
        entry = best.setdefault(
            algo,
            {
                "algorithm": algo,
                "best_keygen_ops_s": 0.0,
                "best_keygen_config": "",
                "best_encaps_ops_s": 0.0,
                "best_encaps_config": "",
                "best_decaps_ops_s": 0.0,
                "best_decaps_config": "",
            },
        )
        config = (
            f"batch={row.get('batch', '')} "
            f"mode={row.get('mode', '')} "
            f"streams={row.get('streams', '')}"
        )
        for op in OPERATIONS:
            value = to_float(row.get(f"{op}_ops_s", ""))
            key = f"best_{op}_ops_s"
            if value > entry[key]:
                entry[key] = value
                entry[f"best_{op}_config"] = config
    return best


def main(argv=None) -> int:
    """Command-line entry point; returns the process exit status."""
    argv = sys.argv if argv is None else argv
    if len(argv) != 2:
        print("usage: summarize_kem_best.py <summary.csv>", file=sys.stderr)
        return 2

    # Close the input file deterministically instead of leaking the handle.
    with Path(argv[1]).open(newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))

    best = best_by_algorithm(rows)

    writer = csv.DictWriter(sys.stdout, fieldnames=FIELDNAMES, lineterminator="\n")
    writer.writeheader()
    for algorithm in sorted(best):
        row = best[algorithm]
        writer.writerow(
            {
                # Throughput columns are printed as whole numbers.
                key: (f"{row[key]:.0f}" if key.endswith("_ops_s") else row[key])
                for key in FIELDNAMES
            }
        )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
#!/usr/bin/env python3
"""Summarize rocprofv3 traces produced by run_kem_all_profile_compare_amd.sh.

Given the run's output directory, reads ``profile_compare_runs.csv`` and each
run's ``rocprofv3`` trace CSVs, then writes ``kernel_summary.csv``,
``hip_api_summary.csv``, ``key_kernel_summary.csv`` and
``key_kernel_compare.csv`` (baseline vs. tuned) into that directory.
"""
import csv
import sys
from collections import defaultdict
from pathlib import Path


# Operation name -> substring identifying its kernel in the trace.
KEY_KERNELS = {
    "keypair": "batch_kem_keypair_serial_kernel",
    "encaps": "batch_kem_encaps_serial_kernel",
    "decaps": "batch_kem_decaps_serial_kernel",
}


def as_float(value):
    """Return *value* as a float, or 0.0 when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _as_number(value):
    """Parse *value* exactly as an int when it is an integer string.

    Falls back to :func:`as_float` (0.0 on failure) for anything else.
    """
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass
    return as_float(value)


def duration_ns(row):
    """Return the duration of one trace row in nanoseconds as a float.

    Uses the ``Duration`` column when present, otherwise
    ``End_Timestamp - Start_Timestamp``. The timestamps are subtracted as exact
    integers when possible: converting large nanosecond values to float first
    rounds them (floats are exact only up to 2**53), which can distort or zero
    out short durations.
    """
    if "Duration" in row:
        return as_float(row["Duration"])
    end = _as_number(row.get("End_Timestamp"))
    start = _as_number(row.get("Start_Timestamp"))
    return float(end - start)


def find_one(root, pattern):
    """Return the first path (sorted) under *root* matching *pattern*, or None."""
    files = sorted(root.rglob(pattern))
    return files[0] if files else None


def _dims(row, prefix):
    """Join ``<prefix>_X/Y/Z`` as ``XxYxZ``, dropping missing outer parts."""
    return "x".join([
        row.get(f"{prefix}_X", ""),
        row.get(f"{prefix}_Y", ""),
        row.get(f"{prefix}_Z", ""),
    ]).strip("x")


def _ms(ns):
    """Convert nanoseconds to milliseconds rounded to 3 decimals."""
    return round(ns / 1e6, 3)


def _avg_ms(entry):
    """Average time per call in ms, or 0 when the entry has no calls."""
    return _ms(entry["total_ns"] / entry["calls"]) if entry["calls"] else 0


def summarize_kernel_file(path):
    """Aggregate a kernel-trace CSV by kernel name.

    Returns a mapping ``name -> {calls, total_ns, max_ns, vgpr, sgpr, scratch,
    lds, workgroup, grid}``. Resource fields hold the values from the last row
    seen for that kernel. A falsy *path* yields an empty mapping.
    """
    agg = defaultdict(lambda: {
        "calls": 0,
        "total_ns": 0.0,
        "max_ns": 0.0,
        "vgpr": "",
        "sgpr": "",
        "scratch": "",
        "lds": "",
        "workgroup": "",
        "grid": "",
    })
    if not path:
        return agg

    with path.open(newline="", errors="replace") as f:
        for row in csv.DictReader(f):
            name = row.get("Kernel_Name") or row.get("Name") or row.get("Kernel Name") or ""
            if not name:
                continue
            # Clamp so inconsistent timestamps never produce negative time.
            ns = max(0.0, duration_ns(row))
            entry = agg[name]
            entry["calls"] += 1
            entry["total_ns"] += ns
            entry["max_ns"] = max(entry["max_ns"], ns)
            entry["vgpr"] = row.get("VGPR_Count", entry["vgpr"])
            entry["sgpr"] = row.get("SGPR_Count", entry["sgpr"])
            entry["scratch"] = row.get("Scratch_Size", entry["scratch"])
            entry["lds"] = row.get("LDS_Block_Size", entry["lds"])
            entry["workgroup"] = _dims(row, "Workgroup_Size")
            entry["grid"] = _dims(row, "Grid_Size")
    return agg


def summarize_api_file(path):
    """Aggregate a HIP API trace CSV by function name.

    Returns ``name -> {calls, total_ns, max_ns}``; empty for a falsy *path*.
    """
    agg = defaultdict(lambda: {"calls": 0, "total_ns": 0.0, "max_ns": 0.0})
    if not path:
        return agg

    with path.open(newline="", errors="replace") as f:
        for row in csv.DictReader(f):
            name = row.get("Function") or row.get("Name") or row.get("API_Name") or ""
            if not name:
                continue
            ns = max(0.0, duration_ns(row))
            entry = agg[name]
            entry["calls"] += 1
            entry["total_ns"] += ns
            entry["max_ns"] = max(entry["max_ns"], ns)
    return agg


def pct(new, old):
    """Return the change from *old* to *new* in percent (2 decimals).

    Returns an empty string when *old* is 0, so the CSV cell stays blank
    instead of dividing by zero.
    """
    if old == 0:
        return ""
    return round((new - old) * 100.0 / old, 2)


def _by_total_time(agg):
    """Return ``(name, entry)`` pairs ordered by descending total time."""
    return sorted(agg.items(), key=lambda item: item[1]["total_ns"], reverse=True)


def _resource_columns(entry):
    """Return the per-kernel resource columns shared by the kernel tables."""
    return {
        "vgpr": entry["vgpr"],
        "sgpr": entry["sgpr"],
        "scratch": entry["scratch"],
        "lds": entry["lds"],
        "workgroup": entry["workgroup"],
        "grid": entry["grid"],
    }


def write_csv(path, rows, fields):
    """Write *rows* (dicts) to *path* as CSV with header *fields*."""
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fields)
        w.writeheader()
        w.writerows(rows)


def main():
    """Command-line entry point: summarize the directory given as argv[1]."""
    if len(sys.argv) != 2:
        print("usage: summarize_profile_compare.py <out_dir>", file=sys.stderr)
        raise SystemExit(2)

    root = Path(sys.argv[1])
    runs_path = root / "profile_compare_runs.csv"
    if not runs_path.exists():
        print(f"missing {runs_path}", file=sys.stderr)
        raise SystemExit(1)

    with runs_path.open(newline="", encoding="utf-8") as f:
        runs = list(csv.DictReader(f))

    kernel_rows = []
    api_rows = []
    key_rows = []

    for run in runs:
        run_dir = root / run["run_dir"]
        kernel_path = find_one(run_dir / "rocprofv3", "*kernel_trace*.csv")
        api_path = find_one(run_dir / "rocprofv3", "*hip_api_trace*.csv")
        kernel_agg = summarize_kernel_file(kernel_path)
        api_agg = summarize_api_file(api_path)

        # Columns copied from the runs table into every output row.
        base = {
            "target": run["target"],
            "config": run["config"],
            "bounds": run["bounds"],
            "batch": run["batch"],
            "n_ops": run["n_ops"],
            "keygen_ops_s": run["keygen_ops_s"],
            "encaps_ops_s": run["encaps_ops_s"],
            "decaps_ops_s": run["decaps_ops_s"],
        }

        for name, v in _by_total_time(kernel_agg):
            row = dict(base)
            row.update({
                "kernel": name,
                "total_ms": _ms(v["total_ns"]),
                "avg_ms": _avg_ms(v),
                "max_ms": _ms(v["max_ns"]),
                "calls": v["calls"],
            })
            row.update(_resource_columns(v))
            kernel_rows.append(row)

        for name, v in _by_total_time(api_agg):
            row = dict(base)
            row.update({
                "function": name,
                "total_ms": _ms(v["total_ns"]),
                "avg_ms": _avg_ms(v),
                "max_ms": _ms(v["max_ns"]),
                "calls": v["calls"],
            })
            api_rows.append(row)

        # Pick out the three serial KEM kernels by substring match.
        for op, needle in KEY_KERNELS.items():
            for name, v in kernel_agg.items():
                if needle in name:
                    row = dict(base)
                    row.update({
                        "operation": op,
                        "kernel": name,
                        "total_ms": _ms(v["total_ns"]),
                        "avg_ms": _avg_ms(v),
                        "calls": v["calls"],
                    })
                    row.update(_resource_columns(v))
                    key_rows.append(row)

    kernel_fields = [
        "target", "config", "bounds", "batch", "n_ops", "keygen_ops_s",
        "encaps_ops_s", "decaps_ops_s", "total_ms", "avg_ms", "max_ms",
        "calls", "vgpr", "sgpr", "scratch", "lds", "workgroup", "grid", "kernel",
    ]
    api_fields = [
        "target", "config", "bounds", "batch", "n_ops", "keygen_ops_s",
        "encaps_ops_s", "decaps_ops_s", "total_ms", "avg_ms", "max_ms",
        "calls", "function",
    ]
    key_fields = [
        "target", "config", "bounds", "batch", "n_ops", "operation",
        "keygen_ops_s", "encaps_ops_s", "decaps_ops_s", "total_ms", "avg_ms",
        "calls", "vgpr", "sgpr", "scratch", "lds", "workgroup", "grid", "kernel",
    ]
    write_csv(root / "kernel_summary.csv", kernel_rows, kernel_fields)
    write_csv(root / "hip_api_summary.csv", api_rows, api_fields)
    write_csv(root / "key_kernel_summary.csv", key_rows, key_fields)

    # Pair up baseline and tuned rows per (target, operation). When several
    # kernels match one operation, the last one seen represents the config.
    by_target_op = defaultdict(dict)
    for row in key_rows:
        by_target_op[(row["target"], row["operation"])][row["config"]] = row

    compare_rows = []
    for (target, op), configs in sorted(by_target_op.items()):
        if "baseline" not in configs or "tuned" not in configs:
            continue
        b = configs["baseline"]
        t = configs["tuned"]
        b_ms = as_float(b["total_ms"])
        t_ms = as_float(t["total_ms"])
        compare_rows.append({
            "target": target,
            "operation": op,
            "baseline_bounds": b["bounds"],
            "tuned_bounds": t["bounds"],
            "baseline_total_ms": b["total_ms"],
            "tuned_total_ms": t["total_ms"],
            "kernel_time_change_pct": pct(t_ms, b_ms),
            "baseline_vgpr": b["vgpr"],
            "tuned_vgpr": t["vgpr"],
            "baseline_sgpr": b["sgpr"],
            "tuned_sgpr": t["sgpr"],
            "baseline_scratch": b["scratch"],
            "tuned_scratch": t["scratch"],
            "baseline_workgroup": b["workgroup"],
            "tuned_workgroup": t["workgroup"],
        })

    write_csv(root / "key_kernel_compare.csv", compare_rows, [
        "target", "operation", "baseline_bounds", "tuned_bounds",
        "baseline_total_ms", "tuned_total_ms", "kernel_time_change_pct",
        "baseline_vgpr", "tuned_vgpr", "baseline_sgpr", "tuned_sgpr",
        "baseline_scratch", "tuned_scratch", "baseline_workgroup", "tuned_workgroup",
    ])

    print(f"[done] {root / 'kernel_summary.csv'}")
    print(f"[done] {root / 'hip_api_summary.csv'}")
    print(f"[done] {root / 'key_kernel_summary.csv'}")
    print(f"[done] {root / 'key_kernel_compare.csv'}")


if __name__ == "__main__":
    main()
with path.open(newline="", errors="replace") as f: + reader = csv.DictReader(f) + for row in reader: + kernel = row.get("Kernel_Name") or row.get("Name") or row.get("Kernel Name") or "" + if not kernel: + continue + numeric = {} + for key, value in row.items(): + val = as_float(value) + if val is not None: + numeric[key] = val + rows.append((path, kernel, numeric)) + + if not rows: + out.write_text("status,message\nEMPTY,no counter_collection csv rows found\n", encoding="utf-8") + print(f"[warn] no PMC counter rows found under {root}") + print(f"[done] {out}") + return + + # Aggregate numeric columns by kernel name. + agg = {} + for path, kernel, numeric in rows: + entry = agg.setdefault(kernel, {"calls": 0, "source_files": set(), "sums": {}}) + entry["calls"] += 1 + entry["source_files"].add(str(path.relative_to(root))) + for key, value in numeric.items(): + entry["sums"][key] = entry["sums"].get(key, 0.0) + value + + # Prefer commonly useful counters first, then include the rest. + preferred = [ + "Duration", "Dispatch_Id", "SQ_WAVES", "GRBM_GUI_ACTIVE", "GPUBusy", + "VALUUtilization", "VALUBusy", "SALUBusy", "MemUnitBusy", + "MemUnitStalled", "FetchSize", "WriteSize", "FETCH_SIZE", "WRITE_SIZE", + "L2CacheHit", "LDSBankConflict", "CU_OCCUPANCY", + "MeanOccupancyPerCU", "MeanOccupancyPerActiveCU", + ] + all_cols = set() + for entry in agg.values(): + all_cols.update(entry["sums"].keys()) + ordered_cols = [c for c in preferred if c in all_cols] + sorted(all_cols - set(preferred)) + + with out.open("w", newline="", encoding="utf-8") as f: + fieldnames = ["kernel", "calls", "source_files"] + [f"sum_{c}" for c in ordered_cols] + [f"avg_{c}" for c in ordered_cols] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for kernel, entry in sorted(agg.items(), key=lambda item: item[1]["calls"], reverse=True): + row = { + "kernel": kernel, + "calls": entry["calls"], + "source_files": ";".join(sorted(entry["source_files"])), + } + for col in 
ordered_cols:
+                total = entry["sums"].get(col, 0.0)
+                row[f"sum_{col}"] = round(total, 3)
+                row[f"avg_{col}"] = round(total / entry["calls"], 3) if entry["calls"] else ""
+            writer.writerow(row)
+
+    print(f"[done] {out}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_rocprofv3_trace.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_rocprofv3_trace.py
new file mode 100644
index 000000000..ba3fcb618
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/kem_optimization/summarize_rocprofv3_trace.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+import csv
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+
+def as_float(value):
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return 0.0
+
+
+def kernel_duration_ns(row):
+    if "Duration" in row:
+        return as_float(row["Duration"])
+    return as_float(row.get("End_Timestamp")) - as_float(row.get("Start_Timestamp"))
+
+
+def summarize_kernel(path):
+    rows = list(csv.DictReader(path.open(newline="", errors="replace")))
+    agg = defaultdict(lambda: {
+        "count": 0,
+        "total_ns": 0.0,
+        "max_ns": 0.0,
+        "vgpr": "",
+        "sgpr": "",
+        "scratch": "",
+        "lds": "",
+        "wg": "",
+        "grid": "",
+    })
+
+    for row in rows:
+        name = row.get("Kernel_Name") or row.get("Name") or row.get("Kernel Name") or ""
+        if not name:
+            continue
+        ns = max(0.0, kernel_duration_ns(row))
+        entry = agg[name]
+        entry["count"] += 1
+        entry["total_ns"] += ns
+        entry["max_ns"] = max(entry["max_ns"], ns)
+        entry["vgpr"] = row.get("VGPR_Count", entry["vgpr"])
+        entry["sgpr"] = row.get("SGPR_Count", entry["sgpr"])
+        entry["scratch"] = row.get("Scratch_Size", entry["scratch"])
+        entry["lds"] = row.get("LDS_Block_Size", entry["lds"])
+        entry["wg"] = "x".join([
+            row.get("Workgroup_Size_X", ""),
+            row.get("Workgroup_Size_Y", ""),
+            row.get("Workgroup_Size_Z", ""),
+        ]).strip("x")
+        entry["grid"] = "x".join([
+            row.get("Grid_Size_X", ""),
+            row.get("Grid_Size_Y", ""),
+            row.get("Grid_Size_Z", ""),
+        ]).strip("x")
+
+    print(f"\n# Kernel trace: {path}")
+    print("total_ms,avg_ms,max_ms,calls,vgpr,sgpr,scratch,lds,workgroup,grid,kernel")
+    for name, v in sorted(agg.items(), key=lambda item: item[1]["total_ns"], reverse=True):
+        avg = v["total_ns"] / v["count"] if v["count"] else 0.0
+        print(
+            f"{v['total_ns']/1e6:.3f},"
+            f"{avg/1e6:.3f},"
+            f"{v['max_ns']/1e6:.3f},"
+            f"{v['count']},"
+            f"{v['vgpr']},"
+            f"{v['sgpr']},"
+            f"{v['scratch']},"
+            f"{v['lds']},"
+            f"{v['wg']},"
+            f"{v['grid']},"
+            f"{name}"
+        )
+
+
+def summarize_api(path):
+    rows = list(csv.DictReader(path.open(newline="", errors="replace")))
+    agg = defaultdict(lambda: {"count": 0, "total_ns": 0.0, "max_ns": 0.0})
+
+    for row in rows:
+        name = row.get("Function") or row.get("Name") or row.get("API_Name") or ""
+        if not name:
+            continue
+        if "Duration" in row:
+            ns = as_float(row["Duration"])
+        else:
+            ns = as_float(row.get("End_Timestamp")) - as_float(row.get("Start_Timestamp"))
+        ns = max(0.0, ns)
+        entry = agg[name]
+        entry["count"] += 1
+        entry["total_ns"] += ns
+        entry["max_ns"] = max(entry["max_ns"], ns)
+
+    print(f"\n# HIP API trace: {path}")
+    print("total_ms,avg_ms,max_ms,calls,function")
+    for name, v in sorted(agg.items(), key=lambda item: item[1]["total_ns"], reverse=True):
+        avg = v["total_ns"] / v["count"] if v["count"] else 0.0
+        print(f"{v['total_ns']/1e6:.3f},{avg/1e6:.3f},{v['max_ns']/1e6:.3f},{v['count']},{name}")
+
+
+def main():
+    if len(sys.argv) != 2:
+        print("usage: summarize_rocprofv3_trace.py <profile_dir>", file=sys.stderr)  # one required arg: directory searched recursively for trace CSVs
+        raise SystemExit(2)
+
+    root = Path(sys.argv[1])
+    kernel_files = sorted(root.rglob("*kernel_trace*.csv"))
+    api_files = sorted(root.rglob("*hip_api_trace*.csv"))
+
+    if not kernel_files and not api_files:
+        print(f"no rocprofv3 trace csv files found under {root}", file=sys.stderr)
+        raise
SystemExit(1) + + for path in kernel_files: + summarize_kernel(path) + for path in api_files: + summarize_api(path) + + +if __name__ == "__main__": + main() diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/COMPETITION_RUNBOOK.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/COMPETITION_RUNBOOK.md new file mode 100644 index 000000000..0d2f0a430 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/COMPETITION_RUNBOOK.md @@ -0,0 +1,206 @@ +# AMD ROCm Competition Runbook + +This is the AMD/JupyterLab source tree for the ROCm signature workload. +The local NVIDIA tree is only the 4090D/CUDA baseline and should not be used as +the AMD source of truth. + +## Goal + +Build a stable ROCm resource-aware signing implementation for: + +- ML-DSA-44 / 65 / 87 +- Aigis-sig-1 / 2 / 3 + +The AMD stable policy is: + +```text +decomp-pipeline=on +monolithic-precomp=off +cached-precomp=off +warp-path=off +large-strategy=off +decomp-cp-fuse=off +decomp-tail=off +yhat-copy-fuse=off +decomp-adaptive=off +``` + +Rationale: monolithic signing can trigger ROCm private segment / scratch +resource pressure. The stable competition path uses the decomp pipeline. +The fused `cp_fuse` and tail-finish paths are implemented as AMD-specific +candidates, but the current evidence is mixed, so they are measured before +being promoted into any final target build. The new `adaptive` candidate uses +the feature matrix as a runtime policy table, so one binary can select the best +measured local policy by target, benchmark mode, and batch size, with a base +fallback on cells where aggressive AMD knobs regress. 
+ +## Build + +```bash +cd /app/amd_sig_anchor_results_20260605_031411 +bash amd_tools/build_sig_amd.sh +``` + +Expected binaries: + +```bash +ls -lh mldsa44_amd mldsa65_amd mldsa87_amd aigis1_amd aigis2_amd aigis3_amd +``` + +## Policy Smoke + +```bash +bash amd_tools/run_sig_policy_smoke.sh 128 +cat amd_results/policy_smoke/policy_smoke_b128.txt +``` + +Required evidence: + +```text +ROCm sign policy: resource-aware hybrid candidates +monolithic-precomp=off +cached-precomp=off +decomp-cp-fuse=off +decomp-tail=off +yhat-copy-fuse=off +[Sign] correctness: all 128 PASS [decomp-pipeline] +``` + +An adaptive selected build may instead print: + +```text +[Sign] correctness: all 128 PASS [decomp-adaptive] +``` + +## Debug Matrix + +```bash +bash amd_tools/run_sig_debug_matrix.sh +``` + +This checks all six targets at small batches before long sweeps. + +## Large Sweep + +```bash +bash amd_tools/run_sig_large_sweep.sh +``` + +Outputs: + +```text +amd_results/large_sweep/ +amd_results/sig_large_sweep_summary.csv +amd_results/sig_large_best.csv +``` + +Use `sig_large_best.csv` for paper/PPT throughput tables. 
+ +## AMD Feature Matrix + +Run this before deciding whether the aggressive candidates should enter the +final build: + +```bash +bash amd_tools/run_sig_amd_feature_matrix.sh +``` + +Default comparison: + +```text +base stable resource-aware decomp path +adaptive runtime target/mode/batch policy with base fallback +check8/check16 fewer host-side done-count checks in the rejection loop +wave64_ctrl 64-thread hash/check control kernels for AMD wave64 testing +cp_fuse fused cp*s1/cp*s2/cp*t0 pointwise products +tail16_base small-tail finish candidate without cp_fuse +tail16_cp_fuse small-tail finish candidate with cp_fuse +yhat_dup cp_fuse plus sample-time y/y_hat copy candidate +``` + +Outputs: + +```text +amd_results/sig_amd_feature_matrix.csv +amd_results/sig_amd_feature_matrix_ranked.csv +``` + +After the matrix: + +```bash +python3 amd_tools/write_optimization_claims.py +cat amd_results/optimization_claims.md +``` + +Do not promote a candidate from one local win. Use the selector below; it only +recommends a non-base variant when that variant passes every measured cell for a +target and avoids measured regressions. The `adaptive` candidate is the main +way to turn mixed local wins into a single competition build without forcing one +global macro onto every workload. + +For a conservative target-specific recommendation: + +```bash +python3 amd_tools/select_sig_amd_variants.py +cat amd_results/sig_amd_variant_plan.md +bash amd_tools/build_sig_amd_selected.sh amd_results/sig_amd_variant_plan.env +``` + +Only use the selected build for final sweeps after policy smoke and debug matrix +pass again. 
+ +## Profiling + +```bash +bash amd_tools/profile_sig_one.sh mldsa44_amd 1024 +bash amd_tools/profile_sig_one.sh mldsa87_amd 1024 +bash amd_tools/profile_sig_one.sh aigis2_amd 1024 +``` + +Then summarize rocprof CSV output if present: + +```bash +python3 amd_tools/summarize_rocm_kernel_profile.py amd_results/profile \ + > amd_results/profile/kernel_summary.csv +``` + +If no kernel CSV is found: + +```bash +find amd_results/profile -maxdepth 4 -type f -print +sed -n '1,120p' amd_results/profile/mldsa44_amd_b1024_rocprof.log +rocprofv3 --help | head -120 +``` + +The script no longer uses `rocprofv3 --timestamp on`, because this ROCm +environment rejected that option. + +## Submission Audit + +```bash +python3 amd_tools/check_competition_evidence.py +``` + +Expected result: + +```text +[OK] sig_large_best.csv ... +[OK] policy smoke logs pass ... +[OK] large sweep logs clean ... +[OK] competition evidence audit complete +``` + +## Paper / PPT Story + +1. Post-quantum signatures are throughput-heavy GPU workloads. +2. CUDA monolithic signing does not transfer directly to AMD ROCm. +3. AMD ROCm exposes private segment / scratch pressure for monolithic signing. +4. The project adopts a resource-aware decomp pipeline. +5. AMD-specific candidates include fused `cp*secret` pointwise work and + wave64 control kernels; the matrix shows these are workload-sensitive. +6. The adaptive policy is the innovation layer: it applies local winners where + they help and keeps the resource-aware baseline elsewhere. +7. Feature-matrix scripts separate proven gains from regressions and risky + candidates. +8. Correctness and large-batch performance are reproduced by scripts. +9. ROCm profiling is used to explain kernel-level bottlenecks. 
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/README.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/README.md new file mode 100644 index 000000000..a07ce2288 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/README.md @@ -0,0 +1,175 @@ +# AMD SIG Debug Tools + +These scripts are for the already-hipified ML-DSA / Aigis-sig source tree on the AMD ROCm server. + +For the competition workflow, start from: + +```text +COMPETITION_RUNBOOK.md +``` + +## Build + +```bash +bash amd_tools/build_sig_amd.sh +``` + +The AMD build uses the ROCm resource-aware signing policy: + +- `decomp-pipeline=on` +- `monolithic-precomp=off` +- `cached-precomp=off` +- `decomp-cp-fuse=off` +- `decomp-tail=off` +- `yhat-copy-fuse=off` +- `decomp-adaptive=off` in the plain stable build + +The fused and tail paths are implemented, but they are kept out of the default +build until the AMD feature matrix proves a target-specific win. Current +evidence shows `cp_fuse` can help some ML-DSA cases while regressing other +targets and batch sizes, so the stable competition build stays with the base +resource-aware decomp pipeline. The `adaptive` matrix candidate is different: +it keeps one binary but selects measured local winners at runtime by +target/mode/batch, and falls back to base where the matrix shows a regression. 
+
+This builds:
+
+- `mldsa44_amd`
+- `mldsa65_amd`
+- `mldsa87_amd`
+- `aigis1_amd`
+- `aigis2_amd`
+- `aigis3_amd`
+
+## Sweep
+
+```bash
+bash amd_tools/run_sig_sweep.sh
+```
+
+Logs are written to `amd_results/sweep/`, and the CSV summary is written to:
+
+```text
+amd_results/sig_sweep_summary.csv
+```
+
+## Large Batch Sweep
+
+Use this before kernel-level tuning to find the AMD batch-size ceiling and the best Keygen/Sign/Verify throughput:
+
+```bash
+bash amd_tools/run_sig_large_sweep.sh
+```
+
+It runs all six signature targets with both `--bench-paper` and `--bench-independent` for batch sizes `8192, 16384, 32768`.
+
+Outputs:
+
+```text
+amd_results/large_sweep/
+amd_results/sig_large_sweep_summary.csv
+amd_results/sig_large_best.csv
+```
+
+To compare against a 4090 CSV after uploading it to the server:
+
+```bash
+python3 amd_tools/compare_amd_4090.py amd_results/sig_large_sweep_summary.csv /app/4090数据.csv > amd_results/amd_vs_4090_large.csv
+```
+
+## AMD Feature Matrix
+
+Use this after a correctness smoke test to compare the stable build against
+candidate AMD-specific signing variants:
+
+```bash
+bash amd_tools/run_sig_amd_feature_matrix.sh
+```
+
+Default matrix:
+
+```text
+targets:  mldsa44 mldsa87 aigis2
+batches:  1024 8192
+modes:    independent paper
+variants: base adaptive check8 check16 wave64_ctrl cp_fuse tail16_base tail16_cp_fuse yhat_dup
+```
+
+Candidate meanings:
+
+```text
+base            stable resource-aware decomp path
+adaptive        runtime target/mode/batch policy with base fallback
+check8/check16  fewer host-side done-count checks in the rejection loop
+wave64_ctrl     64-thread hash/check control kernels for AMD wave64 testing
+cp_fuse         fused cp*s1/cp*s2/cp*t0 pointwise products
+tail16_base     small-tail finish without cp_fuse
+tail16_cp_fuse  small-tail finish with cp_fuse
+yhat_dup        cp_fuse plus sample-time y/y_hat copy candidate
+```
+
+Outputs:
+
+```text
+amd_results/sig_amd_feature_matrix.csv
+amd_results/sig_amd_feature_matrix_ranked.csv
+```
+
+For a wider
pass: + +```bash +FEATURE_TARGETS="mldsa44 mldsa65 mldsa87 aigis1 aigis2 aigis3" \ +FEATURE_BATCHES="1024 8192 16384" \ +bash amd_tools/run_sig_amd_feature_matrix.sh +``` + +Generate a report-ready summary: + +```bash +python3 amd_tools/write_optimization_claims.py +``` + +After the matrix, generate a conservative per-target build plan: + +```bash +python3 amd_tools/select_sig_amd_variants.py +cat amd_results/sig_amd_variant_plan.md +``` + +If `adaptive` appears as a local winner or selected variant, rerun policy smoke +and debug matrix before using it for final large sweeps. + +## Debug Matrix + +After every source change, run a quick correctness/resource smoke test first: + +```bash +bash amd_tools/run_sig_debug_matrix.sh +``` + +This runs all six signature targets with batch sizes `1, 8, 32, 128` and writes: + +```text +amd_results/debug/ +amd_results/sig_debug_summary.csv +``` + +## Profile One Target + +```bash +bash amd_tools/profile_sig_one.sh mldsa44_amd 1024 +``` + +The script runs the executable with `--profile`. If `rocprofv3` is installed, it also records a ROCm profile under `amd_results/profile/`. 
+ +Summarize any ROCm kernel CSV data: + +```bash +python3 amd_tools/summarize_rocm_kernel_profile.py amd_results/profile > amd_results/profile/kernel_summary.csv +``` + +## Submission Evidence Audit + +```bash +python3 amd_tools/check_competition_evidence.py +``` diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/SIG_DEBUG_PLAN.md b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/SIG_DEBUG_PLAN.md new file mode 100644 index 000000000..1b7bd45e0 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/SIG_DEBUG_PLAN.md @@ -0,0 +1,112 @@ +# ML-DSA / Aigis-sig AMD 调试路线 + +当前阶段应先把签名部分调试扎实,再考虑 Kyber / Aigis-enc。KEM/ENC 可以作为独立子项目推进,不必现在强行同构;比赛材料里只需要把二者组织成同一套“后量子密码 GPU 加速系统”的两个模块即可。 + +## 目标顺序 + +1. 正确性稳定:六个参数集在 AMD W7900 上 `keygen + sign + verify + tamper-reject` 全部通过。 +2. 性能基线完整:记录 batch=128/512/1024/2048/4096 的 latency、throughput。 +3. 资源边界清楚:确认哪些 batch 或 kernel 会触发 HSA out of resources、private segment 过大、显存不足。 +4. 瓶颈定位可解释:用程序内 `--profile` 和 `rocprofv3` 找到主要耗时 kernel。 +5. 针对性优化:每次只改一个开关或一个 kernel,保留优化前后日志。 + +## 推荐执行流程 + +### 1. 构建 + +```bash +tar -xzf amd_upload.tar.gz +bash amd_tools/build_sig_amd.sh +``` + +构建失败时优先查: + +- 是否设置 `--offload-arch=gfx1100` +- 是否设置 ROCm runtime 的 `LD_LIBRARY_PATH` +- 是否清除了 UTF-8 BOM +- 是否仍有 CUDA warp mask 的 32-bit 写法 + +### 2. 快速正确性矩阵 + +```bash +bash amd_tools/run_sig_debug_matrix.sh +``` + +这个脚本跑六个参数集的 batch=1/8/32/128。它比完整 sweep 快,适合每次源码变动后先跑一遍。输出: + +```text +amd_results/debug/*.log +amd_results/sig_debug_summary.csv +``` + +若某个参数集失败,先固定一个最小失败 batch 单独跑: + +```bash +./mldsa44_amd --batch 8 --quiet --skip-keygen-oracle +``` + +### 3. 
完整 batch 曲线 + +```bash +bash amd_tools/run_sig_sweep.sh +``` + +输出: + +```text +amd_results/sweep/*.log +amd_results/sig_sweep_summary.csv +``` + +论文/PPT 里至少整理这几列: + +- 算法与参数集 +- batch size +- Keygen latency / throughput +- Sign latency / throughput +- Verify latency / throughput +- Sign/Verify correctness +- 当前 sign path,例如 decomp pipeline + +### 4. 单点 profiling + +先选一个代表性组合,例如: + +```bash +bash amd_tools/profile_sig_one.sh mldsa44_amd 1024 +bash amd_tools/profile_sig_one.sh mldsa87_amd 1024 +bash amd_tools/profile_sig_one.sh aigis2_amd 1024 +``` + +重点看: + +- `batch_sign_sample_y_kernel` +- `launch_batch_ntt` / `launch_batch_invntt` +- `batch_verify_matvec_kernel` +- `batch_sign_pointwise_cp_shared_kernel` +- pack/check 类 kernel +- 是否存在 private segment 明显偏大的 kernel + +### 5. 优化实验顺序 + +建议按风险从低到高做: + +1. 调 batch size:找每个参数集的吞吐峰值点和资源崩溃点。 +2. 调 `BLOCK_SIZE`:现在 AMD 安全值是 `1`,可尝试 `2/4/8/16`,每次完整记录正确性和 HSA 错误。 +3. 调 sign decomp 检查频率:尝试 `-DBATCH_SIGN_DECOMP_CHECK_INTERVAL=1/2/4/8`。 +4. 只在小参数集上测试 tail fallback:大参数集先保持 `-DBATCH_SIGN_DECOMP_TAIL_ENABLE=0`。 +5. 优化 matvec / NTT:这是论文里最容易讲清楚的核心 GPU 算子优化。 +6. 
再考虑 keygen sample split / pack fusion 等较细优化。 + +## 结论路线 + +现阶段不要把 Kyber / Aigis-enc 强行塞进签名框架。建议项目结构上保持: + +```text +sig_amd/ ML-DSA + Aigis-sig AMD/HIP 同构签名框架 +kem_enc_amd/ Kyber + Aigis-enc AMD/HIP 独立批处理框架 +results/ 统一放性能日志、CSV、profiling 结果 +docs/ 统一写 README、论文、PPT 图表 +``` + +这样工程风险低,也更符合比赛评审关注点:真实 ROCm 使用、可复现数据、清晰 profiling 和针对性优化。 diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd.sh new file mode 100644 index 000000000..0d82b299f --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +COMMON=( + -O2 + -std=c++17 + -x hip + --offload-arch=gfx1100 + -DBLOCK_SIZE=1 + -DBATCH_KEYGEN_INTERNAL_MATERIAL=1 + -DBATCH_SIGN_WARP_ENABLE=0 + -DBATCH_SIGN_MONO_ENABLE=0 + -DBATCH_SIGN_PRECOMP_REUSE=0 + -DBATCH_SIGN_LARGE_STRATEGY_ENABLE=0 + -DBATCH_SIGN_DECOMP_ENABLE=1 + -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 + -DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 + -DBATCH_SIGN_CP_FUSE_ENABLE=0 + -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 + -DBATCH_KEYGEN_SAMPLE_SPLIT_FAST=1 +) + +mkdir -p amd_results/build + +build_one() { + local alg="$1" + local mode="$2" + local out="$3" + echo "[build] ${out}" + hipcc "${COMMON[@]}" -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" main.cu -o "${out}" \ + 2>&1 | tee "amd_results/build/${out}.log" +} + +build_one 1 2 mldsa44_amd +build_one 1 3 mldsa65_amd +build_one 1 5 mldsa87_amd +build_one 2 1 aigis1_amd +build_one 2 2 aigis2_amd +build_one 2 3 aigis3_amd + +echo "[build] done" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd_selected.sh 
b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd_selected.sh new file mode 100644 index 000000000..b811c1751 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/build_sig_amd_selected.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +plan="${1:-amd_results/sig_amd_variant_plan.env}" +if [[ -f "${plan}" ]]; then + # shellcheck disable=SC1090 + source "${plan}" + echo "[select] loaded ${plan}" +else + echo "[select] ${plan} not found; using base for all targets" +fi + +COMMON=( + -O2 + -std=c++17 + -x hip + --offload-arch=gfx1100 + -DBLOCK_SIZE=1 + -DBATCH_KEYGEN_INTERNAL_MATERIAL=1 + -DBATCH_SIGN_WARP_ENABLE=0 + -DBATCH_SIGN_MONO_ENABLE=0 + -DBATCH_SIGN_PRECOMP_REUSE=0 + -DBATCH_SIGN_LARGE_STRATEGY_ENABLE=0 + -DBATCH_SIGN_DECOMP_ENABLE=1 + -DBATCH_KEYGEN_SAMPLE_SPLIT_FAST=1 +) + +variant_flags() { + case "$1" in + base) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + adaptive) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=1 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + cp_fuse) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + check8) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=8" + ;; + check16) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 
-DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=16" + ;; + wave64_ctrl) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4 -DBATCH_SIGN_HASH_TPB=64 -DBATCH_SIGN_CHECK_TPB=64" + ;; + wave64_check8) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=8 -DBATCH_SIGN_HASH_TPB=64 -DBATCH_SIGN_CHECK_TPB=64" + ;; + tail16_base) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_AFTER=16 -DBATCH_SIGN_DECOMP_TAIL_PENDING_DIV=256 -DBATCH_SIGN_DECOMP_TAIL_PENDING_MIN=8 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + tail16_cp_fuse|tail16) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_AFTER=16 -DBATCH_SIGN_DECOMP_TAIL_PENDING_DIV=256 -DBATCH_SIGN_DECOMP_TAIL_PENDING_MIN=8 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + yhat_dup) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=1 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + *) + echo "unknown variant: $1" >&2 + return 1 + ;; + esac +} + +target_variant() { + case "$1" in + mldsa44) echo "${SIG_AMD_VARIANT_MLDSA44:-base}" ;; + mldsa65) echo "${SIG_AMD_VARIANT_MLDSA65:-base}" ;; + mldsa87) echo "${SIG_AMD_VARIANT_MLDSA87:-base}" ;; + aigis1) echo "${SIG_AMD_VARIANT_AIGIS1:-base}" ;; + aigis2) echo "${SIG_AMD_VARIANT_AIGIS2:-base}" ;; + aigis3) echo "${SIG_AMD_VARIANT_AIGIS3:-base}" ;; + *) + echo "unknown target: $1" >&2 + return 1 + ;; + esac +} + +mkdir -p amd_results/build + +build_one() { + local alg="$1" + local mode="$2" + local target="$3" + 
local out="$4" + local variant extra + variant="$(target_variant "${target}")" + extra="$(variant_flags "${variant}")" + + echo "[build] ${out} target=${target} variant=${variant}" + # shellcheck disable=SC2086 + hipcc "${COMMON[@]}" ${extra} -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" main.cu -o "${out}" \ + 2>&1 | tee "amd_results/build/${out}.log" +} + +build_one 1 2 mldsa44 mldsa44_amd +build_one 1 3 mldsa65 mldsa65_amd +build_one 1 5 mldsa87 mldsa87_amd +build_one 2 1 aigis1 aigis1_amd +build_one 2 2 aigis2 aigis2_amd +build_one 2 3 aigis3 aigis3_amd + +echo "[build] selected variants done" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/check_competition_evidence.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/check_competition_evidence.py new file mode 100644 index 000000000..6b6b1c166 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/check_competition_evidence.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +import csv +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +RESULTS = ROOT / "amd_results" +VALID_DECOMP_SIGN_PATHS = {"decomp-pipeline", "decomp-adaptive"} + + +def fail(msg): + print(f"[FAIL] {msg}", file=sys.stderr) + return 1 + + +def read_text(path): + return path.read_text(errors="replace") if path.exists() else "" + + +def has_decomp_pass(text): + return any(f"PASS [{path}]" in text for path in VALID_DECOMP_SIGN_PATHS) + + +def check_large_best(): + path = RESULTS / "sig_large_best.csv" + if not path.exists(): + return fail(f"missing {path}") + + rows = list(csv.DictReader(path.open(newline="", errors="replace"))) + if not rows: + return fail(f"empty {path}") + + bad_sign = [ + r for r in rows + if r.get("operation") == "Sign" and r.get("path") not in VALID_DECOMP_SIGN_PATHS + ] + if bad_sign: + return 
fail(f"non-decomp sign paths in sig_large_best.csv: {bad_sign[:3]}")
+
+    print(f"[OK] {path} has {len(rows)} best-result rows; sign path is decomp-based")
+    return 0
+
+
+def check_policy_smoke():
+    smoke_dir = RESULTS / "policy_smoke"
+    logs = sorted(smoke_dir.glob("*.log"))
+    if not logs:
+        print(f"[WARN] no policy smoke logs under {smoke_dir}; run amd_tools/run_sig_policy_smoke.sh")
+        return 0
+
+    bad = []
+    for path in logs:
+        text = read_text(path)
+        if "monolithic-precomp=off" not in text:
+            bad.append((path.name, "missing monolithic-precomp=off"))
+        if "cached-precomp=off" not in text:
+            bad.append((path.name, "missing cached-precomp=off"))
+        if "[Sign] correctness: all" not in text or not has_decomp_pass(text):
+            bad.append((path.name, "missing decomp PASS"))
+        if "FAIL" in text:
+            bad.append((path.name, "contains FAIL"))
+    if bad:
+        return fail(f"policy smoke evidence failed: {bad[:5]}")
+
+    print(f"[OK] policy smoke logs pass: {len(logs)} files")
+    return 0
+
+
+def check_feature_matrix():
+    ranked = RESULTS / "sig_amd_feature_matrix_ranked.csv"
+    if not ranked.exists():
+        print(f"[WARN] no AMD feature matrix summary at {ranked}; run amd_tools/run_sig_amd_feature_matrix.sh")
+        return 0
+
+    rows = list(csv.DictReader(ranked.open(newline="", errors="replace")))
+    if not rows:
+        return fail(f"empty {ranked}")
+
+    bad = [
+        r for r in rows
+        if r.get("status") == "PASS" and r.get("sign_path") not in VALID_DECOMP_SIGN_PATHS
+    ]
+    if bad:
+        return fail(f"feature matrix has non-decomp sign paths: {bad[:3]}")
+
+    print(f"[OK] AMD feature matrix summary present: {len(rows)} rows")
+    return 0
+
+
+def check_large_sweep_logs():
+    sweep_dir = RESULTS / "large_sweep"
+    logs = sorted(sweep_dir.glob("*.log"))
+    if not logs:
+        print(f"[WARN] no large sweep logs under {sweep_dir}; run amd_tools/run_sig_large_sweep.sh")
+        return 0
+
+    bad = []
+    for path in logs:
+        text = read_text(path)
+        if "monolithic-precomp=on" in text:
+            bad.append((path.name, "monolithic-precomp=on"))
+        if "FAIL" in text:
+            bad.append((path.name, "contains FAIL"))
+        if "HSA_STATUS_ERROR_OUT_OF_RESOURCES" in text:
+            bad.append((path.name, "contains HSA out of resources"))
+    if bad:
+        return fail(f"large sweep log check failed: {bad[:5]}")
+
+    print(f"[OK] large sweep logs clean: {len(logs)} files")
+    return 0
+
+
+def check_profile_status():
+    profile_dir = RESULTS / "profile"
+    app_logs = sorted(profile_dir.glob("*_profile.log"))
+    if not app_logs:
+        print(f"[WARN] no app-level profile logs under {profile_dir}; run amd_tools/profile_sig_one.sh")
+        return 0
+
+    print(f"[OK] app-level profile logs present: {len(app_logs)} files")
+    stale = []
+    for path in profile_dir.glob("*_rocprof.log"):
+        if "unrecognized arguments" in read_text(path):
+            stale.append(path.name)
+    if stale:
+        print(f"[WARN] stale rocprof logs need rerun: {stale}")
+    return 0
+
+
+def main():
+    rc = 0
+    for check in (
+        check_large_best,
+        check_policy_smoke,
+        check_feature_matrix,
+        check_large_sweep_logs,
+        check_profile_status,
+    ):
+        rc |= check()
+    if rc == 0:
+        print("[OK] competition evidence audit complete")
+    return rc
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/compare_amd_4090.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/compare_amd_4090.py
new file mode 100644
index 000000000..798a1327b
--- /dev/null
+++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/compare_amd_4090.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+import csv
+import sys
+from pathlib import Path
+
+if len(sys.argv) != 3:
+    print("usage: compare_amd_4090.py <amd_csv> <4090_csv>", file=sys.stderr)  # argv[1] = AMD sweep summary CSV, argv[2] = 4090 CSV
+    raise SystemExit(2)
+
+amd_path = Path(sys.argv[1])
+nv_path = Path(sys.argv[2])
+
+def amd_target(row):
+    scheme = row.get("scheme", "")
+    mode =
row.get("mode", "")
+    if scheme == "ML-DSA":
+        return {"2": "mldsa44", "3": "mldsa65", "5": "mldsa87"}.get(mode, f"mldsa_mode{mode}")
+    if scheme == "Aigis-sig":
+        return {"1": "aigis1", "2": "aigis2", "3": "aigis3"}.get(mode, f"aigis_mode{mode}")
+    return f"{scheme}_mode{mode}"
+
+def amd_bench_mode(log_name):
+    if "_independent_" in log_name:
+        return "independent"
+    if "_paper_" in log_name:
+        return "paper"
+    return "paper"
+
+def as_float(value):
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return 0.0
+
+amd_best = {}
+with amd_path.open(newline="", errors="replace") as f:
+    for row in csv.DictReader(f):
+        if row.get("status") != "PASS":
+            continue
+        target = amd_target(row)
+        bench_mode = amd_bench_mode(row.get("log", ""))
+        for op, field in (
+            ("Keygen", "keygen_ops_s"),
+            ("Sign", "sign_ops_s"),
+            ("Verify", "verify_ops_s"),
+        ):
+            ops = as_float(row.get(field))
+            key = (target, bench_mode, op)
+            if ops > amd_best.get(key, {}).get("ops_s", -1):
+                amd_best[key] = {
+                    "target": target,
+                    "benchmark_mode": bench_mode,
+                    "operation": op,
+                    "batch": row.get("batch", ""),
+                    "ms": row.get(f"{op.lower()}_ms", ""),
+                    "ops_s": ops,
+                    "path": row.get(f"{op.lower()}_path", ""),
+                    "log": row.get("log", ""),
+                }
+
+nv_rows = []
+with nv_path.open(newline="", encoding="utf-8-sig", errors="replace") as f:  # "replace-sig" is not a valid error handler; utf-8-sig drops a BOM so the first header matches
+    reader = csv.DictReader(f)
+    for row in reader:
+        target = row.get("目标") or row.get("target") or ""
+        bench_mode = row.get("模式") or row.get("mode") or row.get("benchmark_mode") or ""
+        if not target or not bench_mode:
+            continue
+        nv_rows.append(row)
+
+nv_best = {}
+for row in nv_rows:
+    target = row.get("目标") or row.get("target")
+    bench_mode = row.get("模式") or row.get("mode") or row.get("benchmark_mode")
+    candidates = (
+        ("Keygen", row.get("Keygen_ops_s") or row.get("keygen_ops_s"), row.get("Keygen_ms") or row.get("keygen_ms"), row.get("Keygen路径") or row.get("keygen_path")),
+        ("Sign", row.get("Sign_ops_s") or row.get("sign_ops_s"), row.get("Sign_ms") or
row.get("sign_ms"), row.get("Sign路径") or row.get("sign_path")), + ("Verify", row.get("Verify_ops_s") or row.get("verify_ops_s"), row.get("Verify_ms") or row.get("verify_ms"), ""), + ) + for op, ops_text, ms_text, path_text in candidates: + ops = as_float(ops_text) + key = (target, bench_mode, op) + if ops > nv_best.get(key, {}).get("ops_s", -1): + nv_best[key] = { + "target": target, + "benchmark_mode": bench_mode, + "operation": op, + "batch": row.get("批量N") or row.get("batch") or "", + "ms": ms_text or "", + "ops_s": ops, + "path": path_text or "", + } + +fieldnames = [ + "target", + "benchmark_mode", + "operation", + "amd_best_batch", + "amd_ms", + "amd_ops_s", + "amd_path", + "rtx4090_batch", + "rtx4090_ms", + "rtx4090_ops_s", + "rtx4090_path", + "amd_vs_4090_ratio", + "amd_log", +] +writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames, lineterminator="\n") +writer.writeheader() + +for key in sorted(set(amd_best) | set(nv_best)): + amd = amd_best.get(key, {}) + nv = nv_best.get(key, {}) + amd_ops = amd.get("ops_s", 0.0) + nv_ops = nv.get("ops_s", 0.0) + ratio = amd_ops / nv_ops if nv_ops > 0 else 0.0 + writer.writerow({ + "target": key[0], + "benchmark_mode": key[1], + "operation": key[2], + "amd_best_batch": amd.get("batch", ""), + "amd_ms": amd.get("ms", ""), + "amd_ops_s": f"{amd_ops:.0f}" if amd else "", + "amd_path": amd.get("path", ""), + "rtx4090_batch": nv.get("batch", ""), + "rtx4090_ms": nv.get("ms", ""), + "rtx4090_ops_s": f"{nv_ops:.0f}" if nv else "", + "rtx4090_path": nv.get("path", ""), + "amd_vs_4090_ratio": f"{ratio:.3f}" if amd and nv else "", + "amd_log": amd.get("log", ""), + }) diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/parse_sig_results.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/parse_sig_results.py new file mode 100644 index 000000000..0daa95e7a --- /dev/null +++ 
b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/parse_sig_results.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +import csv +import re +import sys +from pathlib import Path + +if len(sys.argv) != 2: + print("usage: parse_sig_results.py <log_dir>", file=sys.stderr) + raise SystemExit(2) + +log_dir = Path(sys.argv[1]) +rows = [] + +stage_re = re.compile( + r"^\s*(Keygen|Sign|Verify)\s+(\d+)\s+([0-9.]+)\s+ms\s+([0-9.]+)\s+ops/s(?:\s+\[([^\]]+)\])?" +) +header_re = re.compile(r"^===\s+(.+?)\s+\(Mode=(\d+)\)\s+\|\s+Batch=(\d+)") + +for log_path in sorted(log_dir.glob("*.log")): + row = { + "log": log_path.name, + "scheme": "", + "mode": "", + "batch": "", + "keygen_ms": "", + "keygen_ops_s": "", + "keygen_path": "", + "sign_ms": "", + "sign_ops_s": "", + "sign_path": "", + "verify_ms": "", + "verify_ops_s": "", + "sign_pass": "NO", + "verify_pass": "NO", + "status": "FAIL", + } + text = log_path.read_text(errors="replace") + for line in text.splitlines(): + m = header_re.search(line) + if m: + row["scheme"] = m.group(1) + row["mode"] = m.group(2) + row["batch"] = m.group(3) + continue + m = stage_re.search(line) + if m: + stage = m.group(1).lower() + row[f"{stage}_ms"] = m.group(3) + row[f"{stage}_ops_s"] = m.group(4) + if stage in ("keygen", "sign"): + row[f"{stage}_path"] = m.group(5) or "" + continue + if "[Sign] correctness: all" in line and "PASS" in line: + row["sign_pass"] = "YES" + if "[Verify] correctness: all" in line and "PASS" in line: + row["verify_pass"] = "YES" + if row["sign_pass"] == "YES" and row["verify_pass"] == "YES": + row["status"] = "PASS" + rows.append(row) + +fieldnames = [ + "scheme", + "mode", + "batch", + "keygen_ms", + "keygen_ops_s", + "keygen_path", + "sign_ms", + "sign_ops_s", + "sign_path", + "verify_ms", + "verify_ops_s", + "sign_pass", + "verify_pass", + "status", + "log", +] + +writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames, lineterminator="\n") +writer.writeheader() +for row
in rows: + writer.writerow({key: row.get(key, "") for key in fieldnames}) diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/profile_sig_one.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/profile_sig_one.sh new file mode 100644 index 000000000..957e90363 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/profile_sig_one.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +exe="${1:-mldsa44_amd}" +batch="${2:-1024}" +mkdir -p amd_results/profile + +if [[ ! -x "./${exe}" ]]; then + echo "error: ./${exe} not found or not executable" >&2 + exit 1 +fi + +plain_log="amd_results/profile/${exe}_b${batch}_profile.log" +rocprof_dir="amd_results/profile/${exe}_b${batch}_rocprof" + +echo "[profile] app-level profile: ${exe} batch=${batch}" +stdbuf -oL -eL "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle --profile \ + 2>&1 | tee "${plain_log}" + +if command -v rocprofv3 >/dev/null 2>&1; then + echo "[profile] rocprofv3 output: ${rocprof_dir}" + rm -rf "${rocprof_dir}" + mkdir -p "${rocprof_dir}" + rocprof_log="amd_results/profile/${exe}_b${batch}_rocprof.log" + set +e + rocprofv3 --output-directory "${rocprof_dir}" -- \ + "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${rocprof_log}" + rc=${PIPESTATUS[0]} + set -e + if [[ "${rc}" -ne 0 ]]; then + echo "[profile] rocprofv3 failed with exit_code=${rc}; see ${rocprof_log}" >&2 + exit "${rc}" + fi +else + echo "[profile] rocprofv3 not found; skipped ROCm trace" +fi diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_amd_feature_matrix.sh 
b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_amd_feature_matrix.sh new file mode 100644 index 000000000..34d0740d3 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_amd_feature_matrix.sh @@ -0,0 +1,153 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +out_dir="amd_results/amd_feature_matrix" +mkdir -p "${out_dir}" "amd_results/build" + +COMMON=( + -O2 + -std=c++17 + -x hip + --offload-arch=gfx1100 + -DBLOCK_SIZE=1 + -DBATCH_KEYGEN_INTERNAL_MATERIAL=1 + -DBATCH_SIGN_WARP_ENABLE=0 + -DBATCH_SIGN_MONO_ENABLE=0 + -DBATCH_SIGN_PRECOMP_REUSE=0 + -DBATCH_SIGN_LARGE_STRATEGY_ENABLE=0 + -DBATCH_SIGN_DECOMP_ENABLE=1 + -DBATCH_KEYGEN_SAMPLE_SPLIT_FAST=1 +) + +# Representative default targets. Override with: +# FEATURE_TARGETS="mldsa44 mldsa87 aigis2" +# FEATURE_BATCHES="1024 8192 16384" +# FEATURE_MODES="independent paper" +read -r -a target_names <<< "${FEATURE_TARGETS:-mldsa44 mldsa87 aigis2}" +read -r -a batches <<< "${FEATURE_BATCHES:-1024 8192}" +read -r -a modes <<< "${FEATURE_MODES:-independent paper}" +read -r -a variants <<< "${FEATURE_VARIANTS:-base adaptive check8 check16 wave64_ctrl cp_fuse tail16_base tail16_cp_fuse yhat_dup}" +repeats="${FEATURE_REPEATS:-1}" + +target_alg_mode() { + case "$1" in + mldsa44) echo "1 2" ;; + mldsa65) echo "1 3" ;; + mldsa87) echo "1 5" ;; + aigis1) echo "2 1" ;; + aigis2) echo "2 2" ;; + aigis3) echo "2 3" ;; + *) + echo "unknown target: $1" >&2 + return 1 + ;; + esac +} + +variant_flags() { + case "$1" in + base) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + adaptive) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=1 -DBATCH_SIGN_CP_FUSE_ENABLE=0 
-DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + cp_fuse) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + check8) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=8" + ;; + check16) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=16" + ;; + wave64_ctrl) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4 -DBATCH_SIGN_HASH_TPB=64 -DBATCH_SIGN_CHECK_TPB=64" + ;; + wave64_check8) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=8 -DBATCH_SIGN_HASH_TPB=64 -DBATCH_SIGN_CHECK_TPB=64" + ;; + tail16_base) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=0 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_AFTER=16 -DBATCH_SIGN_DECOMP_TAIL_PENDING_DIV=256 -DBATCH_SIGN_DECOMP_TAIL_PENDING_MIN=8 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + tail16_cp_fuse|tail16) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_AFTER=16 -DBATCH_SIGN_DECOMP_TAIL_PENDING_DIV=256 -DBATCH_SIGN_DECOMP_TAIL_PENDING_MIN=8 -DBATCH_SIGN_SAMPLE_DUP_YHAT=0 -DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + yhat_dup) + echo "-DBATCH_SIGN_DECOMP_ADAPTIVE_ENABLE=0 -DBATCH_SIGN_CP_FUSE_ENABLE=1 -DBATCH_SIGN_DECOMP_TAIL_ENABLE=0 -DBATCH_SIGN_SAMPLE_DUP_YHAT=1 
-DBATCH_SIGN_DECOMP_CHECK_INTERVAL=4" + ;; + *) + echo "unknown variant: $1" >&2 + return 1 + ;; + esac +} + +build_one() { + local target="$1" + local variant="$2" + local alg mode extra exe + read -r alg mode <<< "$(target_alg_mode "${target}")" + extra="$(variant_flags "${variant}")" + exe="${target}_${variant}_amd" + echo "[build] ${exe}" + # shellcheck disable=SC2086 + hipcc "${COMMON[@]}" ${extra} -DALGORITHM="${alg}" -DPARAM_MODE="${mode}" main.cu -o "${exe}" \ + 2>&1 | tee "amd_results/build/${exe}.log" +} + +run_one() { + local target="$1" + local variant="$2" + local mode="$3" + local batch="$4" + local repeat="$5" + local exe="${target}_${variant}_amd" + local mode_flag="--bench-${mode}" + local log + if [[ "${repeats}" -gt 1 ]]; then + log="${out_dir}/${target}_${variant}_${mode}_b${batch}_r${repeat}.log" + else + log="${out_dir}/${target}_${variant}_${mode}_b${batch}.log" + fi + + if [[ ! -x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" + return 0 + fi + + echo "[feature] ${exe} mode=${mode} batch=${batch} repeat=${repeat}/${repeats}" + set +e + stdbuf -oL -eL "./${exe}" "${mode_flag}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${log}" + rc=${PIPESTATUS[0]} + set -e + echo "[feature] exit_code=${rc}" | tee -a "${log}" +} + +for target in "${target_names[@]}"; do + for variant in "${variants[@]}"; do + build_one "${target}" "${variant}" + done +done + +for target in "${target_names[@]}"; do + for mode in "${modes[@]}"; do + for batch in "${batches[@]}"; do + for repeat in $(seq 1 "${repeats}"); do + for variant in "${variants[@]}"; do + run_one "${target}" "${variant}" "${mode}" "${batch}" "${repeat}" + done + done + done + done +done + +python3 amd_tools/parse_sig_results.py "${out_dir}" > amd_results/sig_amd_feature_matrix.csv +python3 amd_tools/summarize_amd_feature_matrix.py amd_results/sig_amd_feature_matrix.csv \ + > amd_results/sig_amd_feature_matrix_ranked.csv + +echo "[summary] 
amd_results/sig_amd_feature_matrix.csv" +echo "[summary] amd_results/sig_amd_feature_matrix_ranked.csv" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_debug_matrix.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_debug_matrix.sh new file mode 100644 index 000000000..dcab0ca5a --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_debug_matrix.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +mkdir -p amd_results/debug + +targets=( + mldsa44_amd + mldsa65_amd + mldsa87_amd + aigis1_amd + aigis2_amd + aigis3_amd +) + +# Small batches catch correctness and resource issues quickly before a long sweep. +batches=(1 8 32 128) + +for exe in "${targets[@]}"; do + if [[ ! 
-x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" + continue + fi + + for batch in "${batches[@]}"; do + log="amd_results/debug/${exe}_b${batch}.log" + echo "[debug] ${exe} batch=${batch}" + stdbuf -oL -eL "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${log}" + + if grep -q "FAIL" "${log}"; then + echo "[debug] FAIL detected in ${log}" >&2 + exit 1 + fi + done +done + +python3 amd_tools/parse_sig_results.py amd_results/debug > amd_results/sig_debug_summary.csv +echo "[summary] amd_results/sig_debug_summary.csv" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_large_sweep.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_large_sweep.sh new file mode 100644 index 000000000..d816a1f8c --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_large_sweep.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +mkdir -p amd_results/large_sweep + +targets=( + mldsa44_amd + mldsa65_amd + mldsa87_amd + aigis1_amd + aigis2_amd + aigis3_amd +) + +modes=( + paper + independent +) + +# First boundary pass. Add 65536 only after these are stable. +batches=(8192 16384 32768) + +for exe in "${targets[@]}"; do + if [[ ! 
-x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" + continue + fi + + for mode in "${modes[@]}"; do + mode_flag="--bench-paper" + if [[ "${mode}" == "independent" ]]; then + mode_flag="--bench-independent" + fi + + for batch in "${batches[@]}"; do + log="amd_results/large_sweep/${exe}_${mode}_b${batch}.log" + echo "[large] ${exe} mode=${mode} batch=${batch}" + set +e + stdbuf -oL -eL "./${exe}" "${mode_flag}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${log}" + rc=${PIPESTATUS[0]} + set -e + echo "[large] exit_code=${rc}" | tee -a "${log}" + done + done +done + +python3 amd_tools/parse_sig_results.py amd_results/large_sweep > amd_results/sig_large_sweep_summary.csv +python3 amd_tools/summarize_sig_best.py amd_results/sig_large_sweep_summary.csv > amd_results/sig_large_best.csv + +echo "[summary] amd_results/sig_large_sweep_summary.csv" +echo "[summary] amd_results/sig_large_best.csv" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_policy_smoke.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_policy_smoke.sh new file mode 100644 index 000000000..2da2df206 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_policy_smoke.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +out_dir="amd_results/policy_smoke" +mkdir -p "${out_dir}" + +targets=( + mldsa44_amd + mldsa65_amd + mldsa87_amd + aigis1_amd + aigis2_amd + aigis3_amd +) + +batch="${1:-128}" +summary="${out_dir}/policy_smoke_b${batch}.txt" +: > "${summary}" + +echo "[policy-smoke] batch=${batch}" | tee -a "${summary}" + +for exe in "${targets[@]}"; do + if [[ ! 
-x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" | tee -a "${summary}" + continue + fi + + log="${out_dir}/${exe}_b${batch}.log" + echo "[run] ${exe} batch=${batch}" | tee -a "${summary}" + stdbuf -oL -eL "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${log}" + + grep -E "ROCm sign policy|monolithic-precomp|decomp-cp-fuse|decomp-tail|yhat-copy-fuse|decomp-adaptive|rationale|\\[Sign\\] correctness| Sign[[:space:]]+" "${log}" \ + | sed "s/^/[${exe}] /" | tee -a "${summary}" + + if grep -q "FAIL" "${log}"; then + echo "[policy-smoke] FAIL detected in ${log}" | tee -a "${summary}" >&2 + exit 1 + fi +done + +echo "[policy-smoke] PASS; summary=${summary}" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_sweep.sh b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_sweep.sh new file mode 100644 index 000000000..e92ae56a0 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/run_sig_sweep.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -euo pipefail + +export LD_LIBRARY_PATH="/opt/python/lib/python3.12/site-packages/_rocm_sdk_devel/lib:${LD_LIBRARY_PATH:-}" + +mkdir -p amd_results/sweep + +targets=( + mldsa44_amd + mldsa65_amd + mldsa87_amd + aigis1_amd + aigis2_amd + aigis3_amd +) + +batches=(128 512 1024 2048 4096) + +for exe in "${targets[@]}"; do + if [[ ! 
-x "./${exe}" ]]; then + echo "[skip] ./${exe} not found or not executable" + continue + fi + + for batch in "${batches[@]}"; do + log="amd_results/sweep/${exe}_b${batch}.log" + echo "[run] ${exe} batch=${batch}" + stdbuf -oL -eL "./${exe}" --batch "${batch}" --quiet --skip-keygen-oracle \ + 2>&1 | tee "${log}" + done +done + +python3 amd_tools/parse_sig_results.py amd_results/sweep > amd_results/sig_sweep_summary.csv +echo "[summary] amd_results/sig_sweep_summary.csv" diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/select_sig_amd_variants.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/select_sig_amd_variants.py new file mode 100644 index 000000000..71d0010cc --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/select_sig_amd_variants.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +import csv +import math +import os +from collections import defaultdict +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +RESULTS = ROOT / "amd_results" +RANKED = RESULTS / "sig_amd_feature_matrix_ranked.csv" +OUT_MD = RESULTS / "sig_amd_variant_plan.md" +OUT_ENV = RESULTS / "sig_amd_variant_plan.env" + +TARGETS = ("mldsa44", "mldsa65", "mldsa87", "aigis1", "aigis2", "aigis3") +DEFAULT_MIN_SPEEDUP = 1.0000 +DEFAULT_GEOMEAN = 1.0300 +VALID_DECOMP_SIGN_PATHS = {"decomp-pipeline", "decomp-adaptive"} + + +def env_float(name, default): + raw = os.environ.get(name, "") + if not raw: + return default + try: + return float(raw) + except ValueError: + return default + + +MIN_SPEEDUP = env_float("SIG_AMD_SELECT_MIN_SPEEDUP", DEFAULT_MIN_SPEEDUP) +GEOMEAN_SPEEDUP = env_float("SIG_AMD_SELECT_GEOMEAN_SPEEDUP", DEFAULT_GEOMEAN) + + +def read_csv(path): + if not path.exists(): + return [] + with path.open(newline="", errors="replace") as f: + return list(csv.DictReader(f)) + + +def 
as_float(value, default=0.0): + try: + return float(value or 0) + except ValueError: + return default + + +def pass_row(row): + return ( + row.get("status") == "PASS" + and row.get("sign_pass") == "YES" + and row.get("verify_pass") == "YES" + and row.get("sign_path") in VALID_DECOMP_SIGN_PATHS + ) + + +def geomean(values): + vals = [v for v in values if v > 0] + if not vals: + return 0.0 + return math.exp(sum(math.log(v) for v in vals) / len(vals)) + + +def variant_env_name(target): + return f"SIG_AMD_VARIANT_{target.upper()}" + + +def build_cells(rows): + cells = defaultdict(lambda: defaultdict(dict)) + for row in rows: + target = row.get("target", "") + mode = row.get("benchmark_mode", "") + batch = row.get("batch", "") + variant = row.get("variant", "") + if not target or not mode or not batch or not variant: + continue + cells[target][(mode, batch)][variant] = row + return cells + + +def evaluate_target(target, target_cells): + tested = { + key: variants + for key, variants in target_cells.items() + if "base" in variants and pass_row(variants["base"]) + } + if not tested: + return "base", [], "no passing base rows" + + variants = sorted({ + variant + for rows in tested.values() + for variant in rows + if variant != "base" + }) + + diagnostics = [] + best = None + for variant in variants: + speedups = [] + missing = [] + failed = [] + for key, rows in sorted(tested.items()): + row = rows.get(variant) + if row is None: + missing.append(key) + continue + if not pass_row(row): + failed.append(key) + continue + speedups.append(as_float(row.get("speedup_vs_base"))) + + combo_count = len(tested) + min_sp = min(speedups) if speedups else 0.0 + gm = geomean(speedups) + mean = sum(speedups) / len(speedups) if speedups else 0.0 + wins = sum(1 for v in speedups if v > 1.0) + losses = sum(1 for v in speedups if v < 1.0) + + ok = ( + len(speedups) == combo_count + and not missing + and not failed + and min_sp >= MIN_SPEEDUP + and gm >= GEOMEAN_SPEEDUP + ) + if missing: + 
reason = f"missing {len(missing)} cells" + elif failed: + reason = f"failed {len(failed)} cells" + elif min_sp < MIN_SPEEDUP: + reason = f"min speedup {min_sp:.4f} below {MIN_SPEEDUP:.4f}" + elif gm < GEOMEAN_SPEEDUP: + reason = f"geomean {gm:.4f} below {GEOMEAN_SPEEDUP:.4f}" + else: + reason = "selected" + + diag = { + "target": target, + "variant": variant, + "combos": combo_count, + "passed": len(speedups), + "min": min_sp, + "geomean": gm, + "mean": mean, + "wins": wins, + "losses": losses, + "ok": ok, + "reason": reason, + } + diagnostics.append(diag) + if ok and (best is None or (gm, min_sp, mean) > (best["geomean"], best["min"], best["mean"])): + best = diag + + if best is None: + return "base", diagnostics, "no conservative non-base winner" + return best["variant"], diagnostics, "promoted by conservative matrix rule" + + +def local_winners(rows): + winners = [] + grouped = defaultdict(list) + for row in rows: + if pass_row(row): + key = (row.get("target", ""), row.get("benchmark_mode", ""), row.get("batch", "")) + grouped[key].append(row) + for key, group in sorted(grouped.items()): + non_base = [ + r for r in group + if r.get("variant") != "base" and as_float(r.get("speedup_vs_base")) > 1.0 + ] + if not non_base: + continue + best = max(non_base, key=lambda r: as_float(r.get("speedup_vs_base"))) + winners.append((key, best)) + return winners + + +def main(): + rows = read_csv(RANKED) + if not rows: + raise SystemExit(f"missing or empty {RANKED}; run amd_tools/run_sig_amd_feature_matrix.sh first") + + RESULTS.mkdir(parents=True, exist_ok=True) + cells = build_cells(rows) + + selections = {} + all_diagnostics = [] + reasons = {} + for target in TARGETS: + selected, diagnostics, reason = evaluate_target(target, cells.get(target, {})) + selections[target] = selected + all_diagnostics.extend(diagnostics) + reasons[target] = reason + + env_lines = [ + "# Generated by amd_tools/select_sig_amd_variants.py", + "# Source this file with 
amd_tools/build_sig_amd_selected.sh.", + f"# Rule: min_speedup>={MIN_SPEEDUP:.4f}, geomean>={GEOMEAN_SPEEDUP:.4f}.", + ] + for target in TARGETS: + env_lines.append(f"{variant_env_name(target)}={selections[target]}") + OUT_ENV.write_text("\n".join(env_lines) + "\n", encoding="utf-8") + + md = [] + md.append("# AMD SIG Variant Plan") + md.append("") + md.append( + f"Conservative rule: a non-base variant must pass every measured cell, " + f"keep min speedup >= {MIN_SPEEDUP:.4f}, and reach geomean >= {GEOMEAN_SPEEDUP:.4f}." + ) + md.append("") + md.append("## Selected Build") + md.append("") + md.append("| target | selected variant | reason |") + md.append("| --- | --- | --- |") + for target in TARGETS: + md.append(f"| {target} | {selections[target]} | {reasons[target]} |") + md.append("") + md.append("## Candidate Diagnostics") + md.append("") + md.append("| target | variant | passed / combos | min | geomean | mean | wins | losses | decision |") + md.append("| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |") + for d in all_diagnostics: + md.append( + f"| {d['target']} | {d['variant']} | {d['passed']} / {d['combos']} | " + f"{d['min']:.4f} | {d['geomean']:.4f} | {d['mean']:.4f} | " + f"{d['wins']} | {d['losses']} | {d['reason']} |" + ) + md.append("") + md.append("## Local Winners") + md.append("") + md.append("These rows are useful for the writeup, but are not promoted unless the target-level rule above passes.") + md.append("") + md.append("| target | mode | batch | variant | speedup vs base | log |") + md.append("| --- | --- | ---: | --- | ---: | --- |") + for (target, mode, batch), row in local_winners(rows): + md.append( + f"| {target} | {mode} | {batch} | {row.get('variant','')} | " + f"{row.get('speedup_vs_base','')} | {row.get('log','')} |" + ) + md.append("") + md.append("## Build Command") + md.append("") + md.append("```bash") + md.append("bash amd_tools/build_sig_amd_selected.sh amd_results/sig_amd_variant_plan.env") + md.append("```") + 
md.append("") + + OUT_MD.write_text("\n".join(md), encoding="utf-8") + print(f"[OK] wrote {OUT_ENV}") + print(f"[OK] wrote {OUT_MD}") + + +if __name__ == "__main__": + main() diff --git a/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_amd_feature_matrix.py b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_amd_feature_matrix.py new file mode 100644 index 000000000..cd984e9f1 --- /dev/null +++ b/Applications/pqc_trustflow_rocm/02_performance_bottleneck_rocm_optimization/sig_optimization/amd_tools/summarize_amd_feature_matrix.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +import csv +import re +import sys +from pathlib import Path + + +if len(sys.argv) != 2: + print("usage: summarize_amd_feature_matrix.py <feature_matrix_csv>", file=sys.stderr) + raise SystemExit(2) + + +name_re = re.compile( + r"^(mldsa44|mldsa65|mldsa87|aigis1|aigis2|aigis3)_" + r"(.+)_(paper|independent)_b(\d+)\.log$" +) +repeat_name_re = re.compile( + r"^(mldsa44|mldsa65|mldsa87|aigis1|aigis2|aigis3)_" + r"(.+)_(paper|independent)_b(\d+)_r(\d+)\.log$" +) + + +def parse_log_name(log_name): + m = repeat_name_re.match(log_name) + if m: + return m.group(1), m.group(2), m.group(3), m.group(4), m.group(5) + m = name_re.match(log_name) + if not m: + return "", "", "", "", "" + return m.group(1), m.group(2), m.group(3), m.group(4), "" + + +def as_float(value): + try: + return float(value or 0) + except ValueError: + return 0.0 + + +def median(values): + vals = sorted(v for v in values if v > 0) + if not vals: + return 0.0 + mid = len(vals) // 2 + if len(vals) % 2: + return vals[mid] + return (vals[mid - 1] + vals[mid]) / 2.0 + + +def median_row(group): + row = dict(group[0]) + pass_rows = [ + r for r in group + if r.get("status") == "PASS" + and r.get("sign_pass") == "YES" + and r.get("verify_pass") == "YES" + ] + source = pass_rows if pass_rows else group + med_sign =
median([as_float(r.get("sign_ops_s")) for r in source]) + med_sign_ms = median([as_float(r.get("sign_ms")) for r in source]) + med_keygen = median([as_float(r.get("keygen_ops_s")) for r in source]) + med_verify = median([as_float(r.get("verify_ops_s")) for r in source]) + row["status"] = "PASS" if pass_rows else "FAIL" + row["sign_pass"] = "YES" if pass_rows else "NO" + row["verify_pass"] = "YES" if pass_rows else "NO" + row["sign_ops_s"] = f"{med_sign:.0f}" if med_sign else "" + row["sign_ms"] = f"{med_sign_ms:.3f}" if med_sign_ms else "" + row["keygen_ops_s"] = f"{med_keygen:.0f}" if med_keygen else "" + row["verify_ops_s"] = f"{med_verify:.0f}" if med_verify else "" + row["log"] = ";".join(r.get("log", "") for r in group) + return row + + +rows = [] +with Path(sys.argv[1]).open(newline="", errors="replace") as f: + for row in csv.DictReader(f): + target, variant, bench_mode, batch, repeat = parse_log_name(row.get("log", "")) + if not target: + continue + row["target"] = target + row["variant"] = variant + row["benchmark_mode"] = bench_mode + row["batch"] = batch + row["repeat"] = repeat + rows.append(row) + +raw_rows = rows + +grouped = {} +for row in rows: + key = (row["target"], row["variant"], row["benchmark_mode"], row["batch"]) + grouped.setdefault(key, []).append(row) +rows = [median_row(group) for key, group in sorted(grouped.items())] + +base_repeat_sign = {} +for row in raw_rows: + if ( + row.get("variant") == "base" + and row.get("status") == "PASS" + and row.get("sign_pass") == "YES" + and row.get("verify_pass") == "YES" + ): + base_repeat_sign[ + (row["target"], row["benchmark_mode"], row["batch"], row.get("repeat", "")) + ] = as_float(row.get("sign_ops_s")) + +paired_speedup = {} +for row in raw_rows: + if ( + row.get("status") != "PASS" + or row.get("sign_pass") != "YES" + or row.get("verify_pass") != "YES" + ): + continue + key = (row["target"], row["benchmark_mode"], row["batch"]) + repeat_key = (row["target"], row["benchmark_mode"], 
row["batch"], row.get("repeat", "")) + base_ops = base_repeat_sign.get(repeat_key, 0.0) + ops = as_float(row.get("sign_ops_s")) + if base_ops > 0 and ops > 0: + paired_speedup.setdefault((key, row["variant"]), []).append(ops / base_ops) + +base_sign = {} +for row in rows: + if row.get("status") != "PASS": + continue + try: + ops = float(row.get("sign_ops_s") or 0) + except ValueError: + ops = 0.0 + if row["variant"] == "base": + base_sign[(row["target"], row["benchmark_mode"], row["batch"])] = ops + +out_rows = [] +for row in rows: + try: + sign_ops = float(row.get("sign_ops_s") or 0) + except ValueError: + sign_ops = 0.0 + key = (row["target"], row["benchmark_mode"], row["batch"]) + base_ops = base_sign.get(key, 0.0) + speedup = sign_ops / base_ops if base_ops > 0 and sign_ops > 0 else 0.0 + if (key, row["variant"]) in paired_speedup: + speedup = median(paired_speedup[(key, row["variant"])]) + out_rows.append({ + "target": row["target"], + "benchmark_mode": row["benchmark_mode"], + "batch": row["batch"], + "variant": row["variant"], + "status": row.get("status", ""), + "sign_pass": row.get("sign_pass", ""), + "verify_pass": row.get("verify_pass", ""), + "sign_ms": row.get("sign_ms", ""), + "sign_ops_s": f"{sign_ops:.0f}" if sign_ops else row.get("sign_ops_s", ""), + "speedup_vs_base": f"{speedup:.4f}" if speedup else "", + "keygen_ops_s": row.get("keygen_ops_s", ""), + "verify_ops_s": row.get("verify_ops_s", ""), + "sign_path": row.get("sign_path", ""), + "log": row.get("log", ""), + }) + +out_rows.sort(key=lambda r: ( + r["target"], + r["benchmark_mode"], + int(r["batch"] or 0), + -float(r["speedup_vs_base"] or 0), + r["variant"], +)) + +fieldnames = [ + "target", + "benchmark_mode", + "batch", + "variant", + "status", + "sign_pass", + "verify_pass", + "sign_ms", + "sign_ops_s", + "speedup_vs_base", + "keygen_ops_s", + "verify_ops_s", + "sign_path", + "log", +] + +writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames, lineterminator="\n") 
#!/usr/bin/env python3
"""Summarize rocprof kernel CSV files by signing phase and by kernel name.

Every ``*.csv`` file below the root directory is scanned for a kernel-name
column and a duration column.  Durations are normalized to microseconds and
accumulated per phase (see ``PHASE_RULES``) and per kernel.  The report is
written to stdout as CSV.

Usage: summarize_rocm_kernel_profile.py [root]   (default: amd_results/profile)
"""
import csv
import re
import sys
from collections import defaultdict
from pathlib import Path


# (phase, substrings).  Rules are tried in order against the lower-cased
# kernel name; the first rule with any matching substring wins.
PHASE_RULES = [
    ("setup_mu_nonce", ("setup", "compute_mu", "rhoprime", "init_kernel")),
    ("sample_y", ("sample_y", "uniform_gamma", "gamma1")),
    ("ntt_invntt", ("ntt", "invntt")),
    ("matvec", ("matvec",)),
    ("reduce_normalize", ("reduce", "caddq", "freeze")),
    ("decompose", ("decompose",)),
    ("hash_challenge", ("hash_cp", "challenge", "cbuf")),
    ("pointwise_z_cs2_ct0", ("pointwise", "cp_shared", "add_y")),
    ("check_pack", ("check_pack", "pack")),
]


# Accepted spellings of the kernel-name column, in priority order.
NAME_KEYS = (
    "KernelName",
    "Kernel Name",
    "kernel_name",
    "Name",
    "name",
    "Function",
    "function",
)

# Accepted spellings of the duration column, in priority order.
DURATION_KEYS = (
    "DurationNs",
    "Duration_ns",
    "duration_ns",
    "Duration (ns)",
    "Duration",
    "duration",
    "KernelDuration",
)

# First decimal number (optionally signed / with exponent) inside a string.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?")

# Rows whose name starts with one of these are treated as host API calls.
_HOST_API_PREFIXES = ("hip", "hsa_")


def classify(name: str) -> str:
    """Return the phase label for kernel *name*, or ``"other"`` if no rule matches."""
    low = name.lower()
    for phase, keys in PHASE_RULES:
        if any(k in low for k in keys):
            return phase
    return "other"


def parse_float(value: str):
    """Return the first number found in *value* as a float, else ``None``.

    ``None``, blank strings and strings without digits yield ``None``.
    Thousands separators (``,``) are removed before matching.
    """
    if value is None:
        return None
    text = str(value).strip().replace(",", "")
    if not text:
        return None
    match = _NUMBER_RE.search(text)
    return float(match.group(0)) if match else None


def duration_to_us(key: str, value: str):
    """Convert the duration *value* of column *key* to microseconds.

    The unit is taken from the column name when it mentions one.  Returns
    ``None`` when *value* holds no number.
    """
    num = parse_float(value)
    if num is None:
        return None
    low = key.lower()
    if "ns" in low:
        return num / 1000.0
    if "us" in low or "µs" in low:
        return num
    if "ms" in low:
        return num * 1000.0
    # rocprof CSVs commonly store raw durations in ns even when the column is
    # simply named "Duration". Treat large values as ns, small as us.
    if num > 100000.0:
        return num / 1000.0
    return num


def find_key(row, candidates):
    """Return the key of *row* matching the first usable name in *candidates*.

    Exact matches are preferred; otherwise the comparison is case-insensitive.
    Returns ``None`` when nothing matches.
    """
    for key in candidates:
        if key in row:
            return key
    # csv.DictReader stores surplus fields of a ragged row under the key None,
    # so only string keys can take part in the case-insensitive lookup.
    lowered = {k.lower(): k for k in row if isinstance(k, str)}
    for key in candidates:
        if key.lower() in lowered:
            return lowered[key.lower()]
    return None


def parse_csv(path: Path):
    """Return ``[(kernel_name, duration_us), ...]`` read from the CSV at *path*.

    Files without a recognizable name/duration column, unreadable files and
    files the csv module rejects contribute whatever was parsed before the
    problem (possibly nothing), so one bad file never aborts a directory scan.
    """
    rows = []
    try:
        with path.open("r", encoding="utf-8", errors="replace", newline="") as f:
            reader = csv.DictReader(f)
            if not reader.fieldnames:
                return rows
            # Row keys equal the header, so the columns are resolved once.
            header = dict.fromkeys(reader.fieldnames)
            name_key = find_key(header, NAME_KEYS)
            dur_key = find_key(header, DURATION_KEYS)
            if not name_key or not dur_key:
                return rows
            for row in reader:
                name = (row.get(name_key) or "").strip()
                if not name:
                    continue
                dur_us = duration_to_us(dur_key, row.get(dur_key))
                if dur_us is None:
                    continue
                # Keep likely GPU kernels and skip obvious host API rows if mixed.
                if name.startswith(_HOST_API_PREFIXES):
                    continue
                rows.append((name, dur_us))
    except (OSError, csv.Error):
        pass
    return rows


def summarize(root: Path):
    """Accumulate durations of every CSV below *root*.

    Returns ``(parsed_files, phase_us, phase_count, kernel_us, kernel_count)``
    where the four mappings are keyed by phase label or kernel name.
    """
    phase_us = defaultdict(float)
    phase_count = defaultdict(int)
    kernel_us = defaultdict(float)
    kernel_count = defaultdict(int)
    parsed_files = []

    # Sorted so that equal totals are reported in a reproducible order.
    for path in sorted(root.rglob("*.csv")):
        rows = parse_csv(path)
        if not rows:
            continue
        parsed_files.append(path)
        for name, dur_us in rows:
            phase = classify(name)
            phase_us[phase] += dur_us
            phase_count[phase] += 1
            kernel_us[name] += dur_us
            kernel_count[name] += 1

    return parsed_files, phase_us, phase_count, kernel_us, kernel_count


def main(argv):
    """Print the phase and top-30 kernel summary; return the process exit code.

    Returns 2 when no usable CSV data was found under the root, 0 otherwise.
    """
    root = Path(argv[1]) if len(argv) > 1 else Path("amd_results/profile")
    parsed_files, phase_us, phase_count, kernel_us, kernel_count = summarize(root)

    if not parsed_files:
        print(f"No rocprof CSV kernel data found under {root}.")
        print("Run: bash amd_tools/profile_sig_one.sh mldsa44_amd 1024")
        print("Then inspect amd_results/profile/*_rocprof*/ for CSV output.")
        return 2

    # Kernel names frequently contain commas (template arguments, parameter
    # lists), so the rows go through csv.writer to be quoted correctly.
    out = csv.writer(sys.stdout, lineterminator="\n")
    total_us = sum(phase_us.values())

    print("# ROCm kernel profile summary")
    out.writerow(["root", str(root)])
    out.writerow(["parsed_csv_files", len(parsed_files)])
    print()
    out.writerow(["phase", "kernel_count", "total_us", "total_ms", "percent"])
    for phase, total in sorted(phase_us.items(), key=lambda kv: kv[1], reverse=True):
        pct = (total / total_us * 100.0) if total_us else 0.0
        out.writerow([
            phase,
            phase_count[phase],
            f"{total:.3f}",
            f"{total / 1000.0:.3f}",
            f"{pct:.2f}",
        ])

    print()
    out.writerow(["top_kernel", "kernel_count", "total_us", "total_ms", "phase"])
    ranked = sorted(kernel_us.items(), key=lambda kv: kv[1], reverse=True)
    for name, total in ranked[:30]:
        out.writerow([
            name,
            kernel_count[name],
            f"{total:.3f}",
            f"{total / 1000.0:.3f}",
            classify(name),
        ])
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))
#!/usr/bin/env python3
"""Report the best keygen/sign/verify throughput per target and benchmark mode.

Reads a signature benchmark results CSV (a file path, or ``-`` for stdin),
keeps only rows whose ``status`` is ``PASS`` and, for every
(target, benchmark mode, operation), prints the row with the highest
``*_ops_s`` value as CSV on stdout.
"""
import csv
import sys
from pathlib import Path

USAGE = "usage: summarize_sig_best.py <sig_results.csv|->"

# (operation label, throughput column); "<label lower-cased>_ms" and
# "<label lower-cased>_path" are the matching latency and path columns.
OPERATIONS = (
    ("Keygen", "keygen_ops_s"),
    ("Sign", "sign_ops_s"),
    ("Verify", "verify_ops_s"),
)

FIELDNAMES = ["target", "benchmark_mode", "operation", "batch", "ms", "ops_s", "path", "log"]

ML_DSA_TARGETS = {"2": "mldsa44", "3": "mldsa65", "5": "mldsa87"}
AIGIS_TARGETS = {"1": "aigis1", "2": "aigis2", "3": "aigis3"}


def target_from_row(row):
    """Return the target name derived from the row's ``scheme`` and ``mode``."""
    scheme = row.get("scheme", "")
    mode = row.get("mode", "")
    if scheme == "ML-DSA":
        return ML_DSA_TARGETS.get(mode, f"mldsa_mode{mode}")
    if scheme == "Aigis-sig":
        return AIGIS_TARGETS.get(mode, f"aigis_mode{mode}")
    return f"{scheme}_mode{mode}"


def bench_mode_from_log(log_name):
    """Return the benchmark mode encoded in *log_name*.

    ``None`` (the value csv.DictReader fills in for a short row) and names
    without a mode marker map to ``"default"``.
    """
    log_name = log_name or ""
    if "_independent_" in log_name:
        return "independent"
    if "_paper_" in log_name:
        return "paper"
    return "default"


def load_rows(stream):
    """Read CSV rows from *stream*, adding ``target`` and ``benchmark_mode``."""
    rows = []
    for row in csv.DictReader(stream):
        row["target"] = target_from_row(row)
        row["benchmark_mode"] = bench_mode_from_log(row.get("log", ""))
        rows.append(row)
    return rows


def best_rows(rows):
    """Return ``{(target, benchmark_mode, operation): summary}`` for PASS rows.

    A missing or unparsable throughput counts as 0.0.  On equal throughput the
    row seen first is kept.
    """
    best = {}
    for row in rows:
        if row.get("status") != "PASS":
            continue
        key = (row["target"], row["benchmark_mode"])
        for op, field in OPERATIONS:
            try:
                ops = float(row.get(field) or 0)
            except ValueError:
                ops = 0.0
            bkey = key + (op,)
            if ops > best.get(bkey, {}).get("ops_s", -1):
                prefix = op.lower()
                best[bkey] = {
                    "target": row["target"],
                    "benchmark_mode": row["benchmark_mode"],
                    "operation": op,
                    "batch": row.get("batch", ""),
                    "ms": row.get(f"{prefix}_ms", ""),
                    "ops_s": ops,
                    "path": row.get(f"{prefix}_path", ""),
                    "log": row.get("log", ""),
                }
    return best


def main(argv):
    """Run the summary for ``argv[1]``; return the process exit code."""
    if len(argv) != 2:
        print(USAGE, file=sys.stderr)
        return 2

    source = argv[1]
    if source == "-":
        rows = load_rows(sys.stdin)
    else:
        with Path(source).open(newline="", encoding="utf-8", errors="replace") as f:
            rows = load_rows(f)

    best = best_rows(rows)
    writer = csv.DictWriter(sys.stdout, fieldnames=FIELDNAMES, lineterminator="\n")
    writer.writeheader()
    for key in sorted(best):
        out = dict(best[key])
        out["ops_s"] = f"{best[key]['ops_s']:.0f}"
        writer.writerow(out)
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))
#!/usr/bin/env python3
"""Write ``amd_results/optimization_claims.md`` from the benchmark CSV summaries.

Inputs are read from ``amd_results`` next to the ``amd_tools`` directory;
every input is optional and a missing one only shortens the report.
"""
import csv
from pathlib import Path


ROOT = Path(__file__).resolve().parents[1]
RESULTS = ROOT / "amd_results"

# Fixed description of the candidates, emitted verbatim at the top of the report.
IMPLEMENTED_CANDIDATES = (
    "- Stable signing remains the resource-aware `decomp-pipeline` path.",
    "- `adaptive` is a runtime policy candidate: one binary selects the measured local winner by target, benchmark mode, and batch size, while falling back to base on cells where the matrix shows regressions.",
    "- `check8` and `check16` measure whether fewer host-side done-count checks reduce ROCm synchronization overhead.",
    "- `wave64_ctrl` measures whether 64-thread hash/check control kernels behave better on AMD wave64 hardware than 32-thread control kernels.",
    "- `BATCH_SIGN_CP_FUSE_ENABLE` is implemented as a measured AMD candidate: one ROCm kernel computes `cp*s1`, `cp*s2`, and `cp*t0` products for each rejection round.",
    "- `tail16_base` and `tail16_cp_fuse` separate small-tail finish behavior from the fused pointwise candidate.",
    "- `yhat_dup` measures whether duplicating `y` at sample time beats the explicit device-to-device copy.",
    "- The default build keeps these candidates off until the matrix proves a conservative target-specific gain.",
)


def read_csv(path):
    """Return the rows of the CSV at *path* as dicts, or ``[]`` if it is missing."""
    if not path.exists():
        return []
    with path.open(newline="", errors="replace") as f:
        return list(csv.DictReader(f))


def to_float(value):
    """Return *value* as a float; missing or unparsable values become 0.0."""
    try:
        return float(value or 0)
    except (TypeError, ValueError):
        return 0.0


def batch_order(batch):
    """Return *batch* as an int for sorting; non-numeric values sort as 0."""
    try:
        return int(batch or 0)
    except (TypeError, ValueError):
        return 0


def md_cell(value):
    """Return *value* as text that is safe inside a Markdown table cell."""
    # A literal '|' would start a new column, so it is replaced.
    return ("" if value is None else str(value)).replace("|", "/")


def best_feature_rows(rows):
    """Return the winning PASS row for each (target, benchmark mode, batch).

    Rows are ranked by ``speedup_vs_base`` and then ``sign_ops_s``; the two
    numbers are parsed independently so that one bad field does not discard
    the other.  The result is ordered by target, mode and numeric batch size.
    """
    grouped = {}
    for row in rows:
        if row.get("status") != "PASS":
            continue
        key = (
            row.get("target") or "",
            row.get("benchmark_mode") or "",
            row.get("batch") or "",
        )
        score = (to_float(row.get("speedup_vs_base")), to_float(row.get("sign_ops_s")))
        cur = grouped.get(key)
        if cur is None or score > cur[0]:
            grouped[key] = (score, row)
    ordered = sorted(grouped, key=lambda k: (k[0], k[1], batch_order(k[2]), k[2]))
    return [grouped[k][1] for k in ordered]


def is_resource_failure(row):
    """Return True when a resource-summary row records a failed run."""
    hint = row.get("error_hint") or ""
    exit_code = row.get("exit_code")
    return (
        bool(exit_code and exit_code != "0")
        or "FAIL" in hint
        or "out of resources" in hint.lower()
    )


def build_claims(large_best, feature_ranked, resource):
    """Return the report as a list of Markdown lines.

    *large_best*, *feature_ranked* and *resource* are lists of CSV row dicts;
    any of them may be empty.
    """
    lines = ["# AMD ROCm Optimization Claims", "", "## Implemented Candidates", ""]
    lines.extend(IMPLEMENTED_CANDIDATES)
    lines.append("")

    sign_best = [r for r in large_best if r.get("operation") == "Sign"]
    if sign_best:
        lines.append("## Current Large-Sweep Sign Best")
        lines.append("")
        lines.append("| target | mode | batch | sign ops/s | path | log |")
        lines.append("| --- | --- | ---: | ---: | --- | --- |")
        for r in sign_best:
            lines.append(
                f"| {md_cell(r.get('target', ''))} | {md_cell(r.get('benchmark_mode', ''))} | "
                f"{md_cell(r.get('batch', ''))} | {md_cell(r.get('ops_s', ''))} | "
                f"{md_cell(r.get('path', ''))} | {md_cell(r.get('log', ''))} |"
            )
        lines.append("")

    if feature_ranked:
        lines.append("## AMD Feature Matrix Winners")
        lines.append("")
        lines.append("| target | mode | batch | best variant | speedup vs base | sign ops/s | log |")
        lines.append("| --- | --- | ---: | --- | ---: | ---: | --- |")
        for r in best_feature_rows(feature_ranked):
            lines.append(
                f"| {md_cell(r.get('target', ''))} | {md_cell(r.get('benchmark_mode', ''))} | "
                f"{md_cell(r.get('batch', ''))} | {md_cell(r.get('variant', ''))} | "
                f"{md_cell(r.get('speedup_vs_base', ''))} | {md_cell(r.get('sign_ops_s', ''))} | "
                f"{md_cell(r.get('log', ''))} |"
            )
        lines.append("")
        lines.append("Matrix interpretation: local wins are useful evidence. The `adaptive` row tests whether those wins can be captured in one target/mode/batch-aware build without promoting a globally regressing macro.")
        lines.append("")
    else:
        lines.append("## AMD Feature Matrix")
        lines.append("")
        lines.append("Run `bash amd_tools/run_sig_amd_feature_matrix.sh` to generate `sig_amd_feature_matrix_ranked.csv` with per-variant speedups.")
        lines.append("")

    failures = [r for r in resource if is_resource_failure(r)]
    lines.append("## AMD Limitation Evidence")
    lines.append("")
    if failures:
        lines.append("The monolithic/cached-style signing candidates are retained as negative evidence; representative failures:")
        lines.append("")
        lines.append("| target/variant | exit | hint |")
        lines.append("| --- | ---: | --- |")
        # Only the first 12 failures are shown to keep the table readable.
        for r in failures[:12]:
            lines.append(
                f"| {md_cell(r.get('target', ''))} | {md_cell(r.get('exit_code', ''))} | "
                f"{md_cell(r.get('error_hint'))} |"
            )
        lines.append("")
    else:
        lines.append("Run `bash anchor03_resource_attribution.sh` to regenerate monolithic-vs-decomp failure evidence.")
        lines.append("")

    lines.append("## Next Tuning Step")
    lines.append("")
    lines.append("Run `python3 amd_tools/select_sig_amd_variants.py`, inspect `amd_results/sig_amd_variant_plan.md`, then build selected variants. If `adaptive` is promoted, rerun smoke/debug/large-sweep to collect final evidence.")
    lines.append("")
    return lines


def main():
    """Read the result CSVs and write ``optimization_claims.md``."""
    large_best = read_csv(RESULTS / "sig_large_best.csv")
    feature_ranked = read_csv(RESULTS / "sig_amd_feature_matrix_ranked.csv")
    resource = read_csv(RESULTS / "anchor03_resource" / "resource_summary.csv")

    lines = build_claims(large_best, feature_ranked, resource)

    out = RESULTS / "optimization_claims.md"
    # Missing inputs are tolerated above, so the output directory may not
    # exist yet either.
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text("\n".join(lines), encoding="utf-8")
    print(f"[OK] wrote {out}")


if __name__ == "__main__":
    main()
+ +## Excluded From PR + +Generated binaries, caches, secret files, large logs, temporary outputs, and competition-only documents are intentionally excluded.