diff --git a/csrc/renorm.cu b/csrc/renorm.cu
index d186b40eae..14eb7d541c 100644
--- a/csrc/renorm.cu
+++ b/csrc/renorm.cu
@@ -23,7 +23,8 @@ using tvm::ffi::Optional;
 
 void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
                         Optional maybe_top_p_arr, double top_p_val) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = probs->shape[0];
   unsigned int vocab_size = probs->shape[1];
@@ -34,14 +35,15 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
   cudaError_t status = sampling::TopPRenormProb(
       static_cast(probs->data), static_cast(renorm_probs->data),
       has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr, batch_size,
-      top_p_val, vocab_size, stream);
+      top_p_val, vocab_size, probs->strides[0], stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopPRenormProb failed with error code " << cudaGetErrorString(status);
 }
 
 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
                         Optional maybe_top_k_arr, int64_t top_k_val) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = probs->shape[0];
   unsigned int vocab_size = probs->shape[1];
@@ -52,7 +54,7 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
   cudaError_t status = sampling::TopKRenormProb(
       static_cast(probs->data), static_cast(renorm_probs->data),
       has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+      top_k_val, vocab_size, probs->strides[0], stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKRenormProb failed with error code " << cudaGetErrorString(status);
@@ -60,7 +62,8 @@
 
 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
                        Optional maybe_top_k_arr, int64_t top_k_val) {
-  CHECK_INPUT(logits);
+  CHECK_CUDA(logits);
+  CHECK_LAST_DIM_CONTIGUOUS(logits);
   CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
   unsigned int batch_size = logits->shape[0];
   unsigned int vocab_size = logits->shape[1];
@@ -71,7 +74,7 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits,
   cudaError_t status = sampling::TopKMaskLogits(
       static_cast(logits->data), static_cast(mask_logits->data),
       has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+      top_k_val, vocab_size, logits->strides[0], stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKMaskLogits failed with error code " << cudaGetErrorString(status);
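The renorm.cu hunks above all follow one pattern: the old `CHECK_INPUT` (CUDA device plus full contiguity) is split into `CHECK_CUDA` and `CHECK_LAST_DIM_CONTIGUOUS`, and the row stride `probs->strides[0]` (counted in elements) is forwarded next to `vocab_size`. Only the vocab dimension must now be contiguous; rows may be spaced arbitrarily in storage. A minimal sketch of the kind of tensor this admits, with hypothetical shapes mirroring the test added at the bottom of this patch:

```python
import torch

# A row-strided view: storage holds 2 * batch_size rows; the kernel sees every other one.
batch_size, vocab_size = 4, 128
probs = torch.rand(batch_size * 2, vocab_size, device="cuda")[::2]

assert not probs.is_contiguous()          # would have failed the old CHECK_INPUT
assert probs.stride(-1) == 1              # passes CHECK_LAST_DIM_CONTIGUOUS
assert probs.stride(0) == 2 * vocab_size  # the value carried by probs->strides[0]
```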
diff --git a/csrc/sampling.cu b/csrc/sampling.cu
index d17295d091..55a7cca8b5 100644
--- a/csrc/sampling.cu
+++ b/csrc/sampling.cu
@@ -45,7 +45,8 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,
 
 void sampling_from_logits(TensorView logits, TensorView output, Optional maybe_indices,
                           bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
-  CHECK_INPUT(logits);
+  CHECK_CUDA(logits);
+  CHECK_LAST_DIM_CONTIGUOUS(logits);
   CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
   unsigned int batch_size = output->shape[0];
   unsigned int vocab_size = logits->shape[1];
@@ -55,14 +56,16 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional
   cudaError_t status = sampling::SamplingFromLogits(
       static_cast(logits->data), static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
-      batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
+      batch_size, vocab_size, logits->strides[0], deterministic, philox_seed, philox_offset,
+      stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
 }
 
 void sampling_from_probs(TensorView probs, TensorView output, Optional maybe_indices,
                          bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = output->shape[0];
   unsigned int vocab_size = probs->shape[1];
@@ -72,7 +75,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional
   cudaError_t status = sampling::SamplingFromProb(
       static_cast(probs->data), static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
-      batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
+      batch_size, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
 }
 
@@ -81,7 +84,8 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
                                Optional maybe_indices, Optional maybe_top_p_arr,
                                double top_p_val, bool deterministic, uint64_t philox_seed,
                                uint64_t philox_offset) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = output->shape[0];
   unsigned int vocab_size = probs->shape[1];
@@ -93,7 +97,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
   cudaError_t status = sampling::TopPSamplingFromProb(
       static_cast(probs->data), static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
       has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr, batch_size,
-      top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+      top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
 }
 
@@ -102,7 +106,8 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
                                Optional maybe_indices, Optional maybe_top_k_arr,
                                int64_t top_k_val, bool deterministic, uint64_t philox_seed,
                                uint64_t philox_offset) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_INPUT(output);
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
@@ -117,7 +122,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
   cudaError_t status = sampling::TopKSamplingFromProb(
       static_cast(probs->data), static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
       has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size,
-      top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+      top_k_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
 }
 
@@ -126,7 +131,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
                                Optional maybe_indices, Optional maybe_min_p_arr,
                                double min_p_val, bool deterministic, uint64_t philox_seed,
                                uint64_t philox_offset) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_INPUT(output);
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
@@ -142,7 +148,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
   cudaError_t status = sampling::MinPSamplingFromProb(
       static_cast(probs->data),
       has_min_p_arr ? static_cast(maybe_min_p_arr.value()->data) : nullptr,
       static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
-      batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
+      batch_size, min_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
+      philox_offset, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
 }
 
@@ -153,7 +160,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
                                      Optional maybe_top_p_arr, double top_p_val,
                                      bool deterministic, uint64_t philox_seed,
                                      uint64_t philox_offset) {
-  CHECK_INPUT(probs);
+  CHECK_CUDA(probs);
+  CHECK_LAST_DIM_CONTIGUOUS(probs);
   CHECK_INPUT(output);
   CHECK_DEVICE(output, probs);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
@@ -171,8 +179,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
   cudaError_t status = sampling::TopKTopPSamplingFromProb(
       static_cast(probs->data),
       has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr,
       has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr,
       static_cast(output->data),
       maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr,
-      batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
-      stream);
+      batch_size, top_k_val, top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
+      philox_offset, stream);
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
 }
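The sampling.cu entry points get the same treatment: `strides[0]` travels alongside `vocab_size` into each launcher. Note that the stride is an element count, not a byte count, which is what the kernels' `probs + row_idx * stride` pointer arithmetic expects. A short sketch of that offset computation, checked against PyTorch's own strides (hypothetical shapes, not part of the patch):

```python
import torch

storage = torch.rand(8, 16, device="cuda")  # backing storage, 8 rows
probs = storage[::2]                        # 4 x 16 strided view, stride(0) == 32

row, col = 1, 5
# Element offset the kernels compute for probs[row, col]: row * stride + col.
offset = row * probs.stride(0) + col
assert probs[row, col] == storage.flatten()[offset]
```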
diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 6b134630cf..d06fa35bc4 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -245,7 +245,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
 template
 __device__ __forceinline__ std::tuple GetMinMaxValue(float* in_data, uint32_t row_idx,
-                                                     uint32_t d,
+                                                     uint32_t d, uint32_t stride,
                                                      TempStorage& temp_storage) {
   const uint32_t tx = threadIdx.x;
   vec_t in_data_vec;
@@ -254,7 +254,8 @@
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      in_data_vec.cast_load(in_data + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                            tx * VEC_SIZE);
     }
     float in_data_[VEC_SIZE];
 #pragma unroll
@@ -284,7 +285,7 @@ GetMinMaxValue(float* in_dat
 template
 __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
-                                             TempStorage& temp_storage) {
+                                             uint32_t stride, TempStorage& temp_storage) {
   const uint32_t tx = threadIdx.x;
   vec_t in_data_vec;
@@ -292,7 +293,7 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      in_data_vec.cast_load(in_data + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     float in_data_[VEC_SIZE];
 #pragma unroll
@@ -745,7 +746,8 @@ template
 __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices,
                                          uint32_t d,
-                                         uint64_t philox_seed, uint64_t philox_offset) {
+                                         uint32_t stride, uint64_t philox_seed,
+                                         uint64_t philox_offset) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = indices == nullptr ? bx : indices[bx];
   using SharedMem = typename BlockReduce, BLOCK_THREADS,
@@ -758,7 +760,8 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     logits_vec.fill(-cuda::std::numeric_limits::infinity());
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      logits_vec.cast_load(logits + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                           tx * VEC_SIZE);
     }
     vec_t gumbel_noise = GenerateGumbelNoise(
@@ -786,7 +789,8 @@ template
 __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices,
                                        uint32_t d,
-                                       uint64_t philox_seed, uint64_t philox_offset) {
+                                       uint32_t stride, uint64_t philox_seed,
+                                       uint64_t philox_offset) {
   curandStatePhilox4_32_10_t state;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curand_init(philox_seed, bx, philox_offset, &state);
@@ -809,7 +813,7 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
     DeviceSamplingFromProb
 
 __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices,
                                            IdType* top_k_arr, uint32_t top_k_val, uint32_t d,
-                                           uint64_t philox_seed, uint64_t philox_offset) {
+                                           uint32_t stride, uint64_t philox_seed,
+                                           uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
@@ -865,7 +870,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     DeviceSamplingFromProb aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
@@ -891,7 +896,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
@@ -949,7 +954,8 @@ template
 __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices,
                                            float* top_p_arr, float top_p_val, uint32_t d,
-                                           uint64_t philox_seed, uint64_t philox_offset) {
+                                           uint32_t stride, uint64_t philox_seed,
+                                           uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
@@ -978,7 +984,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     DeviceSamplingFromProb
 
 __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdType* output,
                                            IdType* indices, float min_p_val, uint32_t d,
-                                           uint64_t philox_seed, uint64_t philox_offset) {
+                                           uint32_t stride, uint64_t philox_seed,
+                                           uint64_t philox_offset) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   float p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
   curandStatePhilox4_32_10_t state;
@@ -1072,7 +1079,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   float max_val = GetMaxValue>(
-      probs, row_idx, d, temp_storage);
+      probs, row_idx, d, stride, temp_storage);
   float pivot = max_val * p;
   vec_t probs_vec;
@@ -1081,7 +1088,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     float probs_gt_pivot[VEC_SIZE];
@@ -1109,7 +1116,7 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     DeviceSamplingFromProb
 
 __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, float* top_p_arr,
                                                IdType* output, IdType* indices, IdType top_k_val,
-                                               float top_p_val, uint32_t d, uint64_t philox_seed,
-                                               uint64_t philox_offset) {
+                                               float top_p_val, uint32_t d, uint32_t stride,
+                                               uint64_t philox_seed, uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
@@ -1165,7 +1172,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     DeviceSamplingFromProb aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
@@ -1191,7 +1198,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
     ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
@@ -1389,15 +1396,16 @@ cudaError_t OnlineSoftmax(DType* logits, DType* output, uint32_t batch_size, uin
 template
 cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size,
-                               uint32_t d, bool deterministic, uint64_t philox_seed,
-                               uint64_t philox_offset, cudaStream_t stream = 0) {
+                               uint32_t d, uint32_t stride, bool deterministic,
+                               uint64_t philox_seed, uint64_t philox_offset,
+                               cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
   auto compute_capacity = GetCudaComputeCapability();
   DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset};
+    void* args[] = {&logits, &output, &indices, &d, &stride, &philox_seed, &philox_offset};
     const uint32_t smem_size = sizeof(
         typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage);
@@ -1414,7 +1422,7 @@ cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint3
 
 template
 cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size,
-                             uint32_t d, bool deterministic, uint64_t philox_seed,
+                             uint32_t d, uint32_t stride, bool deterministic, uint64_t philox_seed,
                              uint64_t philox_offset, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -1422,7 +1430,7 @@ cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t
   DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &output, &indices, &d, &stride, &philox_seed, &philox_offset};
     const uint32_t smem_size = sizeof(SamplingTempStorage);
     DISPATCH_ALIGNED_VEC_SIZE(
@@ -1439,8 +1447,8 @@ cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t
 
 template
 cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
                                  uint32_t batch_size, uint32_t top_k_val, uint32_t d,
-                                 bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
-                                 cudaStream_t stream = 0) {
+                                 uint32_t stride, bool deterministic, uint64_t philox_seed,
+                                 uint64_t philox_offset, cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
   auto compute_capacity = GetCudaComputeCapability();
@@ -1448,8 +1456,8 @@
     const uint32_t smem_size = sizeof(SamplingTempStorage);
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&probs, &output, &indices, &top_k_arr,
-                    &top_k_val, &d, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &output, &indices, &top_k_arr, &top_k_val,
+                    &d, &stride, &philox_seed, &philox_offset};
     DISPATCH_ALIGNED_VEC_SIZE(
         vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1466,8 +1474,8 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t
 
 template
 cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr,
-                                 uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic,
-                                 uint64_t philox_seed, uint64_t philox_offset,
+                                 uint32_t batch_size, T top_p_val, uint32_t d, uint32_t stride,
+                                 bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
                                  cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -1476,8 +1484,8 @@
     const uint32_t smem_size = sizeof(SamplingTempStorage);
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&probs, &output, &indices, &top_p_arr,
-                    &top_p_val, &d, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &output, &indices, &top_p_arr, &top_p_val,
+                    &d, &stride, &philox_seed, &philox_offset};
     DISPATCH_ALIGNED_VEC_SIZE(
         vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1494,7 +1502,7 @@ cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t
 
 template
 cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType* indices,
-                                 uint32_t batch_size, float min_p_val, uint32_t d,
+                                 uint32_t batch_size, float min_p_val, uint32_t d, uint32_t stride,
                                  bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
                                  cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -1504,8 +1512,8 @@
     const uint32_t smem_size = sizeof(SamplingTempStorage);
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&probs, &min_p_arr, &output, &indices,
-                    &min_p_val, &d, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &min_p_arr, &output, &indices, &min_p_val,
+                    &d, &stride, &philox_seed, &philox_offset};
     DISPATCH_ALIGNED_VEC_SIZE(
         vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1523,7 +1531,7 @@ cudaError_t MinPSamplingFromProb(T* probs, T* min_p_arr, IdType* output, IdType*
 
 template
 cudaError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output,
                                      IdType* indices, uint32_t batch_size, IdType top_k_val,
-                                     T top_p_val, uint32_t d, bool deterministic,
+                                     T top_p_val, uint32_t d, uint32_t stride, bool deterministic,
                                      uint64_t philox_seed, uint64_t philox_offset,
                                      cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -1533,8 +1541,8 @@
     const uint32_t smem_size = sizeof(SamplingTempStorage);
     dim3 nblks(batch_size);
     dim3 nthrs(BLOCK_THREADS);
-    void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices,
-                    &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset};
+    void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, &top_k_val,
+                    &top_p_val, &d, &stride, &philox_seed, &philox_offset};
     DISPATCH_ALIGNED_VEC_SIZE(
         vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
@@ -1578,7 +1586,7 @@ struct RenormTempStorage {
 template
 __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* top_p_arr,
-                                     float top_p_val, uint32_t d) {
+                                     float top_p_val, uint32_t d, uint32_t stride) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
@@ -1598,7 +1606,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
     probs_vec.fill(0.0f);
     const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
     if (base_idx < d) {
-      probs_vec.cast_load(probs + row_idx * d + base_idx);
+      probs_vec.cast_load(probs + row_idx * stride + base_idx);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
@@ -1625,7 +1633,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
     probs_vec.fill(0.0f);
     const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
     if (base_idx < d) {
-      probs_vec.cast_load(probs + row_idx * d + base_idx);
+      probs_vec.cast_load(probs + row_idx * stride + base_idx);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
@@ -1643,8 +1651,8 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   // Original Top-P renormalization logic
   temp_storage.max_val = 0;
   float max_val = GetMaxValue>(probs, row_idx, d,
-                               temp_storage);
+                               RenormTempStorage>(
+      probs, row_idx, d, stride, temp_storage);
   double low = 0, high = max_val;
   float min_gt_low, max_le_high;
@@ -1667,7 +1675,8 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                          tx * VEC_SIZE);
     }
     float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
@@ -1731,7 +1740,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
@@ -1747,7 +1756,7 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
 
 template
 __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr,
-                                     uint32_t top_k_val, uint32_t d) {
+                                     uint32_t top_k_val, uint32_t d, uint32_t stride) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1762,7 +1771,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
   auto [min_val, max_val] = GetMinMaxValue>(
-      logits, row_idx, d, temp_storage);
+      logits, row_idx, d, stride, temp_storage);
   double low = (min_val == -cuda::std::numeric_limits::infinity())
                    ? cuda::std::numeric_limits::lowest()
@@ -1787,7 +1796,8 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     logits_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      logits_vec.cast_load(logits + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                           tx * VEC_SIZE);
     }
     int probs_gt_pivot_0_count[VEC_SIZE], probs_gt_pivot_1_count[VEC_SIZE];
 #pragma unroll
@@ -1851,7 +1861,8 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     logits_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      logits_vec.cast_load(logits + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                           tx * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
@@ -1867,7 +1878,7 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
 
 template
 __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr,
-                                     uint32_t top_k_val, uint32_t d) {
+                                     uint32_t top_k_val, uint32_t d, uint32_t stride) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1882,7 +1893,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   float max_val = GetMaxValue>(
-      probs, row_idx, d, temp_storage);
+      probs, row_idx, d, stride, temp_storage);
   double low = 0, high = max_val;
   float min_gt_low, max_le_high;
@@ -1905,7 +1916,8 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE +
+                          tx * VEC_SIZE);
     }
     ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
 #pragma unroll
@@ -1975,7 +1987,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
-      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+      probs_vec.cast_load(probs + row_idx * stride + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
@@ -1989,7 +2001,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
 
 template
 cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr,
-                           uint32_t batch_size, float top_p_val, uint32_t d,
+                           uint32_t batch_size, float top_p_val, uint32_t d, uint32_t stride,
                            cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
@@ -1998,7 +2010,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr,
   const uint32_t smem_size = sizeof(RenormTempStorage);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
+  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d, &stride};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopPRenormProbKernel;
     FLASHINFER_CUDA_CALL(
@@ -2011,7 +2023,7 @@ cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr,
 
 template
 cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
-                           uint32_t batch_size, uint32_t top_k_val, uint32_t d,
+                           uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t stride,
                            cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
@@ -2020,7 +2032,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
   const uint32_t smem_size = sizeof(RenormTempStorage);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
+  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d, &stride};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopKRenormProbKernel;
     FLASHINFER_CUDA_CALL(
@@ -2033,7 +2045,7 @@ cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr
 
 template
 cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr,
-                           uint32_t batch_size, uint32_t top_k_val, uint32_t d,
+                           uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t stride,
                            cudaStream_t stream = 0) {
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
@@ -2042,7 +2054,7 @@ cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_ar
   const uint32_t smem_size = sizeof(RenormTempStorage);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
+  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d, &stride};
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = TopKMaskLogitsKernel;
     FLASHINFER_CUDA_CALL(
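On the device side the sampling.cuh changes are mechanical: every kernel gains a `uint32_t stride` parameter, each launcher appends `&stride` to its `args[]`, and every load of an input row rebases from `row_idx * d` to `row_idx * stride`. The hunks shown only touch loads; reductions still run over `d` elements per row. A Python model of the vectorized load loop the kernels use, with hypothetical sizes standing in for the template/runtime values:

```python
import torch

BLOCK_THREADS, VEC_SIZE = 4, 2
d, stride = 8, 16                      # vocab size vs. row stride in elements
flat = torch.arange(2 * stride)        # flat storage holding two strided rows

def load_row(row_idx: int) -> torch.Tensor:
    chunks = []
    for i in range(-(-d // (BLOCK_THREADS * VEC_SIZE))):  # ceil_div(d, BLOCK_THREADS * VEC_SIZE)
        for tx in range(BLOCK_THREADS):                   # one iteration per thread
            base = (i * BLOCK_THREADS + tx) * VEC_SIZE
            if base < d:                                  # same bounds guard as the kernels
                start = row_idx * stride + base           # was row_idx * d before this patch
                chunks.append(flat[start : start + VEC_SIZE])
    return torch.cat(chunks)

assert torch.equal(load_row(1), flat[stride : stride + d])
```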
diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py
index 20df72b55d..90a7854627 100644
--- a/tests/utils/test_sampling.py
+++ b/tests/utils/test_sampling.py
@@ -478,11 +478,18 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k):
 @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("k", [10, 100, 500])
 @pytest.mark.parametrize("neginf_input", [False, True])
-def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input):
+@pytest.mark.parametrize("contiguous", [False, True])
+def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input, contiguous):
     if k > vocab_size:
         pytest.skip("k should be less than vocab_size")
     torch.manual_seed(42)
-    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
+
+    if contiguous:
+        logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
+    else:
+        logits = torch.randn(batch_size * 2, vocab_size, device="cuda:0") * 5
+        logits = logits[::2, :]
+
     if neginf_input:
         num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item()
         idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf]
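In this excerpt only `test_top_k_mask_logits` picks up the `contiguous` parametrization; the non-contiguous case over-allocates twice the rows and takes every second one, which keeps the last dimension contiguous while doubling the row stride. Assuming the module-level `flashinfer.sampling.top_k_mask_logits` helper these tests exercise, the caller-visible effect looks like this (sketch, not part of the patch):

```python
import torch
import flashinfer

batch_size, vocab_size, k = 4, 128, 10
logits = (torch.randn(batch_size * 2, vocab_size, device="cuda:0") * 5)[::2, :]
assert not logits.is_contiguous() and logits.stride(-1) == 1

# With this patch the strided view is accepted directly; previously callers had
# to materialize it first, e.g. via logits.contiguous().
masked = flashinfer.sampling.top_k_mask_logits(logits, k)
```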