From 7588733a1aae0d115e7e0e7bc4bc6108dad59d15 Mon Sep 17 00:00:00 2001 From: ikkoham Date: Tue, 16 Jun 2026 16:35:16 +0900 Subject: [PATCH] [nvqir] Reuse the cuStateVec handle across kernel executions Signed-off-by: ikkoham --- .../custatevec/CuStateVecCircuitSimulator.cpp | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp index 1deb1f53347..9dcfd7a5e0b 100644 --- a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp +++ b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp @@ -50,6 +50,22 @@ class CuStateVecCircuitSimulator /// @brief The cuStateVec handle custatevecHandle_t handle = nullptr; + /// @brief CUDA device the handle was created on (reuse guard) + int handleDevice = -1; + + /// @brief Create the cuStateVec handle once per device and reuse + /// it across kernel executions; the handle is a device context, + /// not tied to a particular state vector. + void ensureHandle() { + int dev; + HANDLE_CUDA_ERROR(cudaGetDevice(&dev)); + if (handle && handleDevice == dev) + return; + if (handle) + HANDLE_ERROR(custatevecDestroy(handle)); + HANDLE_ERROR(custatevecCreate(&handle)); + handleDevice = dev; + } /// @brief Pointer to potentially needed extra memory void *extraWorkspace = nullptr; @@ -179,7 +195,7 @@ class CuStateVecCircuitSimulator // Create the memory and the handle HANDLE_CUDA_ERROR(cudaMalloc((void **)&deviceStateVector, stateDimension * sizeof(CudaDataType))); - HANDLE_ERROR(custatevecCreate(&handle)); + ensureHandle(); ownsDeviceVector = true; // If no state provided, initialize to the zero state if (state == nullptr) { @@ -269,7 +285,7 @@ class CuStateVecCircuitSimulator HANDLE_CUDA_ERROR(cudaMalloc((void **)&deviceStateVector, stateDimension * sizeof(CudaDataType))); ownsDeviceVector = true; - HANDLE_ERROR(custatevecCreate(&handle)); + ensureHandle(); ScopedTraceWithContext( "CuStateVecCircuitSimulator::addQubitsToState cudaMemcpy"); // First allocation, so just copy the user provided data (device mem) here @@ -317,7 +333,7 @@ class CuStateVecCircuitSimulator (stateDimension + threads_per_block - 1) / threads_per_block; nvqir::initializeDeviceStateVector( n_blocks, threads_per_block, deviceStateVector, stateDimension); - HANDLE_ERROR(custatevecCreate(&handle)); + ensureHandle(); } else { // Allocate new state.. void *newDeviceStateVector; @@ -336,8 +352,6 @@ class CuStateVecCircuitSimulator /// @brief Reset the qubit state. void deallocateStateImpl() override { - if (deviceStateVector) - HANDLE_ERROR(custatevecDestroy(handle)); if (deviceStateVector && ownsDeviceVector) { HANDLE_CUDA_ERROR(cudaFree(deviceStateVector)); } @@ -416,7 +430,10 @@ class CuStateVecCircuitSimulator } /// The destructor - virtual ~CuStateVecCircuitSimulator() = default; + virtual ~CuStateVecCircuitSimulator() { + if (handle) + custatevecDestroy(handle); + } void setRandomSeed(std::size_t randomSeed) override { randomEngine = std::mt19937(randomSeed);