namespace at {
namespace native {

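+// Fallback used for large slices: fully sort `self` along `dim`
+// (descending when `largest` is true) and copy the first k sorted
+// values/indices into the output tensors.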
+void topk_out_with_sort(
+    const Tensor& self,
+    int64_t k, int64_t dim, bool largest,
+    const Tensor& values,
+    const Tensor& indices
+) {
+  Tensor sorted_values, sorted_indices;
+  std::tie(sorted_values, sorted_indices) = at::native::sort_cuda(self, dim, largest);
+  values.copy_(sorted_values.narrow(dim, 0, k));
+  indices.copy_(sorted_indices.narrow(dim, 0, k));
+}
+
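+// Heuristic gate for the sort-based path: with at most 16 slices of at
+// least 100,000 elements each, sorting the whole slice was measured to
+// beat the dedicated topk kernel (see the PR linked below). E.g. a
+// [4, 200000] input reduced over dim 1 takes the sort path, while a
+// [1000, 1000] input does not.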
+bool should_use_sort(const Tensor& self, int64_t dim) {
+  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
+  if (self.dim() == 0) return false;
+  if (self.dtype() == kBool) return false; // Bool is not supported by topk
+  int64_t slice_size = self.size(dim);
+  if (slice_size == 0) return false;
+  int64_t num_slices = self.numel() / slice_size;
+  return num_slices <= 16 && slice_size >= 100000;
+}
+
TORCH_IMPL_FUNC(topk_out_cuda)
(const Tensor& self,
 int64_t k, int64_t dim, bool largest, bool sorted,
 const Tensor& values,
 const Tensor& indices) {
  TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3};
  checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});
+
  dim = at::maybe_wrap_dim(dim, self);

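+  // Take the sort-based path when the heuristic predicts it will
+  // outperform the topk kernel.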
+  if (should_use_sort(self, dim)) {
+    topk_out_with_sort(self, k, dim, largest, values, indices);
+    return;
+  }
+
  // If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
  if (k == 0) {
    return;