Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions csrc/renorm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ using tvm::ffi::Optional;

void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
Optional<TensorView> maybe_top_p_arr, double top_p_val) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs->shape[0];
unsigned int vocab_size = probs->shape[1];
Expand All @@ -34,14 +35,15 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
cudaError_t status = sampling::TopPRenormProb<float>(
static_cast<float*>(probs->data), static_cast<float*>(renorm_probs->data),
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr, batch_size,
top_p_val, vocab_size, stream);
top_p_val, vocab_size, probs->strides[0], stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopPRenormProb failed with error code " << cudaGetErrorString(status);
}

void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = probs->shape[0];
unsigned int vocab_size = probs->shape[1];
Expand All @@ -52,15 +54,16 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
cudaError_t status = sampling::TopKRenormProb<float>(
static_cast<float*>(probs->data), static_cast<float*>(renorm_probs->data),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
top_k_val, vocab_size, stream);
top_k_val, vocab_size, probs->strides[0], stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKRenormProb failed with error code " << cudaGetErrorString(status);
}

void top_k_mask_logits(TensorView logits, TensorView mask_logits,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
CHECK_INPUT(logits);
CHECK_CUDA(logits);
CHECK_LAST_DIM_CONTIGUOUS(logits);
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
unsigned int batch_size = logits->shape[0];
unsigned int vocab_size = logits->shape[1];
Expand All @@ -71,7 +74,7 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits,
cudaError_t status = sampling::TopKMaskLogits<float>(
static_cast<float*>(logits->data), static_cast<float*>(mask_logits->data),
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
top_k_val, vocab_size, stream);
top_k_val, vocab_size, logits->strides[0], stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKMaskLogits failed with error code " << cudaGetErrorString(status);
Expand Down
34 changes: 21 additions & 13 deletions csrc/sampling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output,

void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
CHECK_INPUT(logits);
CHECK_CUDA(logits);
CHECK_LAST_DIM_CONTIGUOUS(logits);
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
unsigned int batch_size = output->shape[0];
unsigned int vocab_size = logits->shape[1];
Expand All @@ -55,14 +56,16 @@ void sampling_from_logits(TensorView logits, TensorView output, Optional<TensorV
cudaError_t status = sampling::SamplingFromLogits(
static_cast<float*>(logits->data), static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, vocab_size, logits->strides[0], deterministic, philox_seed, philox_offset,
stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "SamplingFromLogits failed with error code " << cudaGetErrorString(status);
}

void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorView> maybe_indices,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = output->shape[0];
unsigned int vocab_size = probs->shape[1];
Expand All @@ -72,7 +75,7 @@ void sampling_from_probs(TensorView probs, TensorView output, Optional<TensorVie
cudaError_t status = sampling::SamplingFromProb(
static_cast<float*>(probs->data), static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "SamplingFromProbs failed with error code " << cudaGetErrorString(status);
}
Expand All @@ -81,7 +84,8 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
unsigned int batch_size = output->shape[0];
unsigned int vocab_size = probs->shape[1];
Expand All @@ -93,7 +97,7 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output,
static_cast<float*>(probs->data), static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr, batch_size,
top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
}
Expand All @@ -102,7 +106,8 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
Expand All @@ -117,7 +122,7 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output,
static_cast<float*>(probs->data), static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
has_top_k_arr ? static_cast<float*>(maybe_top_k_arr.value()->data) : nullptr, batch_size,
top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
top_k_val, vocab_size, probs->strides[0], deterministic, philox_seed, philox_offset, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status);
}
Expand All @@ -126,7 +131,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_indices,
Optional<TensorView> maybe_min_p_arr, double min_p_val,
bool deterministic, uint64_t philox_seed, uint64_t philox_offset) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
Expand All @@ -142,7 +148,8 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output,
has_min_p_arr ? static_cast<float*>(maybe_min_p_arr.value()->data) : nullptr,
static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream);
batch_size, min_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
philox_offset, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status);
}
Expand All @@ -153,7 +160,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
Optional<TensorView> maybe_top_p_arr, double top_p_val,
bool deterministic, uint64_t philox_seed,
uint64_t philox_offset) {
CHECK_INPUT(probs);
CHECK_CUDA(probs);
CHECK_LAST_DIM_CONTIGUOUS(probs);
CHECK_INPUT(output);
CHECK_DEVICE(output, probs);
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
Expand All @@ -171,8 +179,8 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output,
has_top_p_arr ? static_cast<float*>(maybe_top_p_arr.value()->data) : nullptr,
static_cast<int*>(output->data),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices.value()->data) : nullptr,
batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset,
stream);
batch_size, top_k_val, top_p_val, vocab_size, probs->strides[0], deterministic, philox_seed,
philox_offset, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKTopPSamplingFromProbs failed with error code " << cudaGetErrorString(status);
}
Expand Down
Loading