-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompact_topk_cuda.cu
More file actions
149 lines (126 loc) · 4.41 KB
/
compact_topk_cuda.cu
File metadata and controls
149 lines (126 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>
namespace {

// Single selection predicate shared by both passes.
//
// The count pass and the scatter pass MUST agree exactly on which entries are
// kept: the scatter pass writes each kept entry into a slot reserved for it by
// the count pass, so an entry kept by scatter but not counted would be written
// past its row's reserved range (and, for the last row, past the end of the
// outputs). Expressing the test once as `>=` means NaN values, for which every
// ordered comparison is false, are rejected by both passes.
template <typename scalar_t>
__device__ __forceinline__ bool passes_threshold(scalar_t value, double threshold) {
  return static_cast<double>(value) >= threshold;
}

// Pass 1: counts[row] = number of entries of `row` that pass the threshold.
//
// Launch layout: 1-D grid with one thread per row; threads with row >= rows
// exit immediately. Expects `values` to hold rows * top_k elements with row r
// occupying [r * top_k, (r + 1) * top_k), and `counts` to hold >= rows elements.
template <typename scalar_t>
__global__ void count_valid_kernel(
    const scalar_t* __restrict__ values,
    int64_t rows,
    int64_t top_k,
    double threshold,
    int64_t* __restrict__ counts) {
  const int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row >= rows) {
    return;
  }
  const int64_t base = row * top_k;
  int64_t count = 0;
  for (int64_t k = 0; k < top_k; ++k) {
    if (passes_threshold(values[base + k], threshold)) {
      ++count;
    }
  }
  counts[row] = count;
}

// Pass 2: writes every entry that passes the threshold into the flat outputs.
//
// Launch layout: same as count_valid_kernel (one thread per row). Row r writes
// its kept entries to consecutive slots starting at offsets[r], in increasing
// k order, so `offsets` must be the exclusive prefix sum of the pass-1 counts
// computed with the same threshold. The row is decoded as
// (batch_idx, pos_idx) = (r / seq_len, r % seq_len), so seq_len must be > 0.
template <typename scalar_t>
__global__ void scatter_valid_kernel(
    const scalar_t* __restrict__ values,
    const int64_t* __restrict__ indices,
    int64_t rows,
    int64_t seq_len,
    int64_t top_k,
    double threshold,
    const int64_t* __restrict__ offsets,
    int64_t* __restrict__ batch_out,
    int64_t* __restrict__ pos_out,
    int64_t* __restrict__ feat_out,
    scalar_t* __restrict__ value_out) {
  const int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row >= rows) {
    return;
  }
  const int64_t base = row * top_k;
  const int64_t batch_idx = row / seq_len;
  const int64_t pos_idx = row % seq_len;
  int64_t out = offsets[row];
  for (int64_t k = 0; k < top_k; ++k) {
    const scalar_t value = values[base + k];
    // Must be the exact predicate used by count_valid_kernel (see
    // passes_threshold); a looser test here overruns the reserved slots.
    if (!passes_threshold(value, threshold)) {
      continue;
    }
    batch_out[out] = batch_idx;
    pos_out[out] = pos_idx;
    feat_out[out] = indices[base + k];
    value_out[out] = value;
    ++out;
  }
}

// Raises (via TORCH_CHECK) if the CUDA runtime has recorded an error, e.g. a
// bad launch configuration for the kernel just launched. cudaGetLastError()
// also clears that error. Faults raised while the kernel is executing are
// asynchronous and surface later, at the next synchronizing call.
void check_cuda_launch(const char* kernel_name) {
  const cudaError_t error = cudaGetLastError();
  TORCH_CHECK(error == cudaSuccess, kernel_name, " failed: ", cudaGetErrorString(error));
}

}  // namespace
// Compacts per-position top-k results into flat, COO-style arrays that keep
// only the entries whose value is >= threshold.
//
// Args:
//   top_vals:  CUDA tensor of shape (batch, seq_len, top_k); float, double,
//              half or bfloat16. Need not be contiguous.
//   top_idx:   CUDA int64 tensor with the same shape and device as top_vals;
//              top_idx[b, s, k] is the index paired with top_vals[b, s, k]
//              (it is copied to feat_out).
//   threshold: entries with value >= threshold are kept. NaN values never
//              compare >= anything, so they are always dropped.
//
// Returns {batch_out, pos_out, feat_out, value_out}: four 1-D tensors of the
// same length N (the number of kept entries). For each i there is a k with
//   value_out[i] == top_vals[batch_out[i], pos_out[i], k]
//   feat_out[i]  == top_idx [batch_out[i], pos_out[i], k].
// The three index outputs are int64; value_out has top_vals' dtype. Entries
// are ordered by (batch, pos) and, within one position, by k.
//
// Raises (TORCH_CHECK) on non-CUDA inputs, shape/dtype/device mismatches,
// unsupported value dtypes, or a failed kernel launch. Blocks the host once,
// to read N back before the outputs can be allocated.
std::vector<torch::Tensor> compact_topk_threshold_cuda(
    torch::Tensor top_vals,
    torch::Tensor top_idx,
    double threshold) {
  TORCH_CHECK(top_vals.is_cuda(), "compact_topk_threshold_cuda expects CUDA tensors");
  TORCH_CHECK(top_idx.is_cuda(), "compact_topk_threshold_cuda expects CUDA tensors");
  // The kernels treat both buffers as rows * top_k contiguous elements, so the
  // rank and shapes must match that layout exactly; a smaller top_idx would
  // otherwise be read out of bounds, and extra dims would be silently ignored.
  TORCH_CHECK(top_vals.dim() == 3,
              "compact_topk_threshold_cuda expects top_vals of shape (batch, seq_len, top_k), got ",
              top_vals.dim(), " dimensions");
  TORCH_CHECK(top_idx.sizes() == top_vals.sizes(),
              "compact_topk_threshold_cuda expects top_idx to have the same shape as top_vals, got ",
              top_idx.sizes(), " vs ", top_vals.sizes());
  TORCH_CHECK(top_idx.scalar_type() == torch::kLong,
              "compact_topk_threshold_cuda expects int64 top_idx, got ", top_idx.scalar_type());
  TORCH_CHECK(top_idx.device() == top_vals.device(),
              "compact_topk_threshold_cuda expects top_vals and top_idx on the same device");

  // Launch on the inputs' device and on PyTorch's current stream for it. A bare
  // <<<blocks, threads>>> launch uses the legacy default stream, which is not
  // ordered against non-blocking PyTorch streams, so the kernels could race
  // with the cumsum / item() / allocations below.
  const c10::cuda::CUDAGuard device_guard(top_vals.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto values = top_vals.contiguous();
  auto indices = top_idx.contiguous();
  const int64_t batch = values.size(0);
  const int64_t seq_len = values.size(1);
  const int64_t top_k = values.size(2);
  // One "row" per (batch, position) pair, each holding top_k candidates.
  const int64_t rows = batch * seq_len;
  auto long_options = indices.options().dtype(torch::kLong);

  if (rows == 0) {
    // Nothing to launch; this also keeps seq_len == 0 away from the scatter
    // kernel's `row / seq_len`.
    auto empty_long = torch::empty({0}, long_options);
    auto empty_values = torch::empty({0}, values.options());
    return {empty_long, empty_long.clone(), empty_long.clone(), empty_values};
  }

  constexpr int threads = 256;
  const int blocks = static_cast<int>((rows + threads - 1) / threads);

  // Pass 1: per-row count of kept entries. count_valid_kernel writes every
  // element of `counts`, so no zero-fill is needed.
  auto counts = torch::empty({rows}, long_options);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(),
      "compact_topk_threshold_cuda_count", [&] {
        count_valid_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            values.data_ptr<scalar_t>(),
            rows,
            top_k,
            threshold,
            counts.data_ptr<int64_t>());
      });
  check_cuda_launch("count_valid_kernel");

  // Exclusive prefix sum = inclusive prefix sum minus each row's own count;
  // offsets[r] is the first output slot owned by row r.
  auto inclusive = counts.cumsum(0, torch::kLong);
  auto offsets = inclusive - counts;
  // Host sync: the total must be known before the outputs can be allocated.
  const int64_t valid_count = inclusive[rows - 1].item<int64_t>();

  auto batch_out = torch::empty({valid_count}, long_options);
  auto pos_out = torch::empty({valid_count}, long_options);
  auto feat_out = torch::empty({valid_count}, long_options);
  auto value_out = torch::empty({valid_count}, values.options());
  if (valid_count == 0) {
    return {batch_out, pos_out, feat_out, value_out};
  }

  // Pass 2: scatter kept entries into their reserved slots.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(),
      "compact_topk_threshold_cuda_scatter", [&] {
        scatter_valid_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            values.data_ptr<scalar_t>(),
            indices.data_ptr<int64_t>(),
            rows,
            seq_len,
            top_k,
            threshold,
            offsets.data_ptr<int64_t>(),
            batch_out.data_ptr<int64_t>(),
            pos_out.data_ptr<int64_t>(),
            feat_out.data_ptr<int64_t>(),
            value_out.data_ptr<scalar_t>());
      });
  check_cuda_launch("scatter_valid_kernel");
  return {batch_out, pos_out, feat_out, value_out};
}