
Commit 9dc74c1

modified cpu and cuda argmax
1 parent d641bd5 commit 9dc74c1


2 files changed: +48 −27 lines


src/ops/random_sample/cpu/random_sample.cc

Lines changed: 3 additions & 8 deletions
@@ -135,16 +135,11 @@ void random_sample_cpu_f16(RandomSampleCpuDescriptor_t desc,
     auto index_ = reinterpret_cast<uint64_t *>(result);
     auto source = reinterpret_cast<const uint16_t *>(probs);
 
-    char *origin = reinterpret_cast<char *>(workspace);
-    uint16_t *logits_ = (uint16_t *) origin;
-
-    std::copy(source, source + voc, logits_);
-
-    float M = f16_to_f32(logits_[0]);
+    float M = f16_to_f32(source[0]);
     int index = 0;
     for (int j = 1; j < voc; j++) {
-        if (M < f16_to_f32(logits_[j])) {
-            M = f16_to_f32(logits_[j]);
+        if (M < f16_to_f32(source[j])) {
+            M = f16_to_f32(source[j]);
             index = j;
         }
     }
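
The CPU change removes a redundant staging copy: the old code duplicated `probs` into the workspace as `logits_` via `std::copy` and scanned the copy, while the new code scans `source` in place, so the f16 argmax no longer needs the workspace at all. A minimal standalone sketch of the resulting scan (the `f16_to_f32` below is a stand-in for the repo's helper and, for brevity, handles only zeros and normal values):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for the repo's f16_to_f32 (illustration only: no
// subnormal/inf/NaN handling, sign of zero dropped).
static float f16_to_f32(uint16_t h) {
    if ((h & 0x7FFF) == 0) return 0.0f;                         // zero
    uint32_t bits = ((uint32_t) (h & 0x8000) << 16)             // sign
                  | ((((h >> 10) & 0x1Fu) + (127 - 15)) << 23)  // rebased exponent
                  | ((uint32_t) (h & 0x3FF) << 13);             // mantissa
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// The patched scan: argmax directly over the fp16 source buffer.
static int argmax_f16(const uint16_t *source, int voc) {
    float M = f16_to_f32(source[0]);
    int index = 0;
    for (int j = 1; j < voc; j++) {
        if (M < f16_to_f32(source[j])) {
            M = f16_to_f32(source[j]);
            index = j;
        }
    }
    return index;
}

int main() {
    uint16_t probs[] = {0x3400 /* 0.25 */, 0x3C00 /* 1.0 */, 0x3800 /* 0.5 */};
    std::printf("argmax = %d\n", argmax_f16(probs, 3));  // prints: argmax = 1
    return 0;
}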

src/ops/random_sample/cuda/random_sample.cu

Lines changed: 45 additions & 19 deletions
@@ -3,7 +3,31 @@
 #include "random_sample.cuh"
 #include <cub/block/block_reduce.cuh>
 #include <cub/cub.cuh>
-
+template<class T, int BLOCK_DIM>
+__global__ void argmaxKernel(T *val_out, int voc, uint64_t *result) {
+    float localM = -__FLT_MAX__;
+    uint64_t index = threadIdx.x;
+    for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) {
+        if (localM < static_cast<float>(val_out[i])) {
+            localM = static_cast<float>(val_out[i]);
+            index = i;
+        }
+    }
+    __shared__ uint64_t globalInd[BLOCK_DIM];
+    __shared__ float globalM[BLOCK_DIM];
+    globalInd[threadIdx.x] = index;
+    globalM[threadIdx.x] = localM;
+    for (int strip = BLOCK_DIM / 2; strip > 0; strip /= 2) {
+        if (threadIdx.x < strip) {
+            if (globalM[threadIdx.x] < globalM[threadIdx.x + strip]) {
+                globalM[threadIdx.x] = globalM[threadIdx.x + strip];
+                globalInd[threadIdx.x] = globalInd[threadIdx.x + strip];
+            }
+        }
+        __syncthreads();
+    }
+    result[0] = globalInd[0];
+}
 template<class T, int BLOCK_DIM>
 __global__ void softmax(
     T *val_out,
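
The new `argmaxKernel` is a two-phase block reduction: each thread strides over the vocabulary with step `BLOCK_DIM`, keeping a running (max, index) pair, then a shared-memory tree halves the active range until thread 0 holds the block-wide winner. Two observations on the committed version. First, the strided loop ignores `blockIdx`, so with the multi-block launch used below every block scans the whole vocabulary and writes the same value to `result[0]`; a single-block launch would do that work once. Second, the first reduction step reads `globalM[threadIdx.x + strip]`, written by another thread, before any barrier; a `__syncthreads()` between the initial shared-memory stores and the loop would make that ordering explicit. A sketch of the same reduction with the barriers spelled out (an adjusted variant, not the committed code):

#include <cfloat>
#include <cstdint>

template<class T, int BLOCK_DIM>
__global__ void argmaxReduceSketch(const T *val, int voc, uint64_t *result) {
    // Phase 1: per-thread running max over a strided slice of val.
    float localM = -FLT_MAX;
    uint64_t index = threadIdx.x;
    for (int i = threadIdx.x; i < voc; i += BLOCK_DIM) {
        float v = static_cast<float>(val[i]);
        if (localM < v) { localM = v; index = i; }
    }
    // Phase 2: shared-memory tree reduction of the (max, index) pairs.
    __shared__ float m[BLOCK_DIM];
    __shared__ uint64_t ind[BLOCK_DIM];
    m[threadIdx.x] = localM;
    ind[threadIdx.x] = index;
    __syncthreads();  // all stores must land before any thread reads a neighbor's slot
    for (int strip = BLOCK_DIM / 2; strip > 0; strip /= 2) {
        if (threadIdx.x < strip && m[threadIdx.x] < m[threadIdx.x + strip]) {
            m[threadIdx.x] = m[threadIdx.x + strip];
            ind[threadIdx.x] = ind[threadIdx.x + strip];
        }
        __syncthreads();  // finish this level before halving the range again
    }
    if (threadIdx.x == 0) result[0] = ind[0];
}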
@@ -132,25 +156,26 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                               void *stream) {
     int voc = desc->voc;
     // the code below does the sorting
-    char *origin = reinterpret_cast<char *>(workspace);
-    char *keyTmp = origin + voc * sizeof(half);
-    half *val_out = (half *) origin;
 
-    uint64_t *key_in = (uint64_t *) keyTmp;
-    uint64_t *key_out = key_in + voc;
+    if (topp > 0 && topk > 1) {
+        char *origin = reinterpret_cast<char *>(workspace);
+        char *keyTmp = origin + voc * sizeof(half);
+        half *val_out = (half *) origin;
 
-    index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
-    // compute the extra workspace region below
+        uint64_t *key_in = (uint64_t *) keyTmp;
+        uint64_t *key_out = key_in + voc;
 
-    void *workspace_extra = reinterpret_cast<char *>(workspace) + desc->step;
-    uint64_t workspace_len = workspace_size - desc->step;
-    sort_pairs_descending<half, uint64_t>(
-        workspace_extra, workspace_len,
-        (half *) probs, val_out,
-        key_in, key_out,
-        voc, (cudaStream_t) stream);// stores the sorted result and the matching indices in val_out and key_out
-    // sorting is done; now apply the softmax transform
-    if (topp > 0 && topk > 1) {
+        index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
+        // compute the extra workspace region below
+
+        void *workspace_extra = reinterpret_cast<char *>(workspace) + desc->step;
+        uint64_t workspace_len = workspace_size - desc->step;
+        sort_pairs_descending<half, uint64_t>(
+            workspace_extra, workspace_len,
+            (half *) probs, val_out,
+            key_in, key_out,
+            voc, (cudaStream_t) stream);// stores the sorted result and the matching indices in val_out and key_out
+        // sorting is done; now apply the softmax transform
         int BLOCK_DIM = 1024;
         int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
         softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
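
After this restructuring, the index fill, the CUB descending sort, and the softmax run only when top-p/top-k sampling is actually requested; the greedy branch below never touches the workspace. For reference, the pointer arithmetic above implies the workspace carve-up sketched here (a reconstruction; that `desc->step` marks exactly the end of the three arrays is an assumption this diff does not confirm):

#include <cstdint>
#include <cuda_fp16.h>

// Hypothetical helper mirroring the branch's pointer arithmetic: carve
// val_out/key_in/key_out plus CUB's scratch region out of one workspace
// allocation. `step` plays the role of desc->step.
struct SortBuffers {
    half *val_out;      // sorted probabilities
    uint64_t *key_in;   // 0..voc-1, filled by the index<<<>>> kernel
    uint64_t *key_out;  // indices after the descending sort
    void *cub_scratch;  // CUB temp storage
    uint64_t cub_len;   // bytes available to CUB
};

static SortBuffers carve_workspace(void *workspace, uint64_t workspace_size,
                                   int voc, uint64_t step) {
    char *origin = reinterpret_cast<char *>(workspace);
    SortBuffers b;
    b.val_out = (half *) origin;
    b.key_in = (uint64_t *) (origin + voc * sizeof(half));
    b.key_out = b.key_in + voc;
    b.cub_scratch = origin + step;
    b.cub_len = workspace_size - step;
    return b;
}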
@@ -169,8 +194,9 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                 key_out);
 
     } else {
-        random_sample_kernel<<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result,
-                                                                 key_out);
+        int BLOCK_DIM = 1024;
+        int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
+        argmaxKernel<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>((half *) probs, voc, (uint64_t *) result);
     }
 }
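
A throwaway harness (not part of the commit) to sanity-check the new greedy path: fill an fp16 buffer with one known maximum, launch `argmaxKernel` from this file with a single block (which, as noted above, already covers the whole vocabulary), and read the index back.

#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int voc = 5000;
    std::vector<half> h_probs(voc);
    for (int i = 0; i < voc; i++)
        h_probs[i] = __float2half((i * 37 % voc) / (float) voc);  // values in [0, 1)
    h_probs[1234] = __float2half(2.0f);                           // the known maximum

    half *d_probs;
    uint64_t *d_result;
    cudaMalloc(&d_probs, voc * sizeof(half));
    cudaMalloc(&d_result, sizeof(uint64_t));
    cudaMemcpy(d_probs, h_probs.data(), voc * sizeof(half), cudaMemcpyHostToDevice);

    // One block suffices: the kernel's strided loop already spans voc.
    argmaxKernel<half, 1024><<<1, 1024>>>(d_probs, voc, d_result);

    uint64_t result;
    cudaMemcpy(&result, d_result, sizeof result, cudaMemcpyDeviceToHost);
    std::printf("argmax = %llu (expected 1234)\n", (unsigned long long) result);

    cudaFree(d_probs);
    cudaFree(d_result);
    return 0;
}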
