Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ class CuStateVecCircuitSimulator

/// @brief The cuStateVec handle
custatevecHandle_t handle = nullptr;
/// @brief CUDA device the handle was created on (reuse guard)
int handleDevice = -1;

/// @brief Create the cuStateVec handle once per device and reuse
/// it across kernel executions; the handle is a device context,
/// not tied to a particular state vector.
void ensureHandle() {
int dev;
HANDLE_CUDA_ERROR(cudaGetDevice(&dev));
if (handle && handleDevice == dev)
return;
if (handle)
HANDLE_ERROR(custatevecDestroy(handle));
HANDLE_ERROR(custatevecCreate(&handle));
handleDevice = dev;
}

/// @brief Pointer to potentially needed extra memory
void *extraWorkspace = nullptr;
Expand Down Expand Up @@ -179,7 +195,7 @@ class CuStateVecCircuitSimulator
// Create the memory and the handle
HANDLE_CUDA_ERROR(cudaMalloc((void **)&deviceStateVector,
stateDimension * sizeof(CudaDataType)));
HANDLE_ERROR(custatevecCreate(&handle));
ensureHandle();
ownsDeviceVector = true;
// If no state provided, initialize to the zero state
if (state == nullptr) {
Expand Down Expand Up @@ -269,7 +285,7 @@ class CuStateVecCircuitSimulator
HANDLE_CUDA_ERROR(cudaMalloc((void **)&deviceStateVector,
stateDimension * sizeof(CudaDataType)));
ownsDeviceVector = true;
HANDLE_ERROR(custatevecCreate(&handle));
ensureHandle();
ScopedTraceWithContext(
"CuStateVecCircuitSimulator::addQubitsToState cudaMemcpy");
// First allocation, so just copy the user provided data (device mem) here
Expand Down Expand Up @@ -317,7 +333,7 @@ class CuStateVecCircuitSimulator
(stateDimension + threads_per_block - 1) / threads_per_block;
nvqir::initializeDeviceStateVector<CudaDataType>(
n_blocks, threads_per_block, deviceStateVector, stateDimension);
HANDLE_ERROR(custatevecCreate(&handle));
ensureHandle();
} else {
// Allocate new state..
void *newDeviceStateVector;
Expand All @@ -336,8 +352,6 @@ class CuStateVecCircuitSimulator

/// @brief Reset the qubit state.
void deallocateStateImpl() override {
if (deviceStateVector)
HANDLE_ERROR(custatevecDestroy(handle));
if (deviceStateVector && ownsDeviceVector) {
HANDLE_CUDA_ERROR(cudaFree(deviceStateVector));
}
Expand Down Expand Up @@ -416,7 +430,10 @@ class CuStateVecCircuitSimulator
}

/// The destructor
virtual ~CuStateVecCircuitSimulator() = default;
virtual ~CuStateVecCircuitSimulator() {
if (handle)
custatevecDestroy(handle);
}

void setRandomSeed(std::size_t randomSeed) override {
randomEngine = std::mt19937(randomSeed);
Expand Down
Loading