Commit 95f4cd0

zasdfgbnm authored and facebook-github-bot committed
Implement topk with sort for some cases (pytorch#68632)
Summary: Benchmark comparing the original implementation with the sort-based implementation (this code should be run on a branch without this patch):

```python
import torch
import timeit


def tune_dtype(f):
    def ret(*args, **kwargs):
        for dtype in [torch.int8, torch.half, torch.float, torch.double]:
            f(*args, **kwargs, dtype=dtype)
    return ret


def tune_slice(f):
    def ret(*args, **kwargs):
        slice = 1
        while slice <= 256:
            f(*args, **kwargs, slice=slice)
            slice *= 2
    return ret


def tune_slice_size(f):
    def ret(*args, **kwargs):
        slice_size = 1
        while slice_size <= 1_000_000:
            f(*args, **kwargs, slice_size=slice_size)
            slice_size *= 10
    return ret


def tune_k(f):
    def ret(*args, slice_size, **kwargs):
        k = 1
        while k <= slice_size:
            f(*args, **kwargs, k=k, slice_size=slice_size)
            k *= 10
    return ret


def topk_with_sort(tensor, k, dim=-1, largest=True):
    values, indices = tensor.sort(dim=dim, descending=largest)
    return values.narrow(dim, 0, k), indices.narrow(dim, 0, k)


def run50sync(f):
    for _ in range(50):
        f()
    torch.cuda.synchronize()


def warmup():
    N = 1000000
    for i in range(1, N // 10000):
        torch.randn(i, device='cuda')


def benchmark_one(slice, slice_size, k, dtype):
    input_ = torch.empty((slice, slice_size), dtype=dtype, device="cuda").random_()
    torch.cuda.synchronize()
    time = timeit.timeit(lambda: run50sync(lambda: torch.topk(input_, k, dim=1)), number=1)
    torch.cuda.synchronize()
    time_sort = timeit.timeit(lambda: run50sync(lambda: topk_with_sort(input_, k, dim=1)), number=1)
    method = "orig" if time < time_sort else "sort"
    speedup = time / time_sort
    print(f"(dtype={dtype}, slice={slice}, slice_size={slice_size}, k={k}) -> (method={method}, speedup={speedup})")


if __name__ == "__main__":
    warmup()
    tune_dtype(tune_slice(tune_slice_size(tune_k(benchmark_one))))()
```

Benchmark results: see the next comment.

Pull Request resolved: pytorch#68632

Reviewed By: dagitses

Differential Revision: D32566233

Pulled By: ngimel

fbshipit-source-id: f7a508176ef3685b491048c4a6562121c60b8b2a
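Editorial note, not part of the original commit: the patch relies on the fact that top-k can be obtained by sorting along a dimension and narrowing to the first k entries. The small sketch below checks that equivalence on CPU for the values only; indices are not compared because they can legitimately differ when values tie.

```python
# Hedged sketch (not from the commit): sort + narrow matches torch.topk values
# when sorted=True. Ties may yield different index orderings, so only values
# are asserted equal here.
import torch

x = torch.randn(4, 1000)
k = 10

topk_vals, _ = torch.topk(x, k, dim=1, largest=True, sorted=True)
sorted_vals, sorted_idx = x.sort(dim=1, descending=True)

assert torch.equal(topk_vals, sorted_vals.narrow(1, 0, k))
print("sort + narrow reproduces torch.topk values")
```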
1 parent e554d8b · commit 95f4cd0

File tree: 2 files changed (+33, -0 lines)

aten/src/ATen/native/cuda/TensorTopK.cpp

Lines changed: 28 additions & 0 deletions
```diff
@@ -7,15 +7,43 @@
 namespace at {
 namespace native {
 
+void topk_out_with_sort(
+  const Tensor& self,
+  int64_t k, int64_t dim, bool largest,
+  const Tensor& values,
+  const Tensor& indices
+) {
+  Tensor sorted_values, sorted_indices;
+  std::tie(sorted_values, sorted_indices) = at::native::sort_cuda(self, dim, largest);
+  values.copy_(sorted_values.narrow(dim, 0, k));
+  indices.copy_(sorted_indices.narrow(dim, 0, k));
+}
+
+bool should_use_sort(const Tensor& self, int64_t dim) {
+  // This heuristic is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
+  if (self.dim() == 0) return false;
+  if (self.dtype() == kBool) return false; // Bool is not supported by topk
+  int64_t slice_size = self.size(dim);
+  if (slice_size == 0) return false;
+  int64_t num_slices = self.numel() / slice_size;
+  return num_slices <= 16 && slice_size >= 100000;
+}
+
 TORCH_IMPL_FUNC(topk_out_cuda)
 (const Tensor& self,
  int64_t k, int64_t dim, bool largest, bool sorted,
  const Tensor& values,
  const Tensor& indices) {
   TensorArg topK_arg{values, "topK", 1}, indices_arg{indices, "indices", 2}, input_arg{self, "self", 3};
   checkAllSameGPU(__func__, {topK_arg, indices_arg, input_arg});
+
   dim = at::maybe_wrap_dim(dim, self);
 
+  if (should_use_sort(self, dim)) {
+    topk_out_with_sort(self, k, dim, largest, values, indices);
+    return;
+  }
+
   // If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
   if (k == 0) {
     return;
```
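For readers who want to reason about which inputs take the new path, here is a hedged Python transcription of the should_use_sort heuristic above: the sort path is chosen only for a small number of slices (at most 16) that are each very long (at least 100000 elements). The helper name would_use_sort is made up for illustration and is not part of PyTorch or of this commit.

```python
# Illustrative only: Python mirror of the C++ should_use_sort heuristic.
# would_use_sort is a hypothetical helper, not a PyTorch API.
import torch

def would_use_sort(t: torch.Tensor, dim: int) -> bool:
    if t.dim() == 0:
        return False
    if t.dtype == torch.bool:  # bool is not supported by topk
        return False
    slice_size = t.size(dim)
    if slice_size == 0:
        return False
    num_slices = t.numel() // slice_size
    return num_slices <= 16 and slice_size >= 100000

# The shape added in the new test (2 slices of 100000 elements) hits the sort path:
print(would_use_sort(torch.empty(2, 100000), dim=1))  # True
```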

test/test_sort_and_select.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -370,6 +370,11 @@ def compare(t, k, dim, dir):
                         k = random.randint(1, testTensor.size(dim))
                         compare(testTensor, k, dim, dir)
 
+        # This tests the code path where on CUDA, topk is implemented with sort.
+        t = torch.randn((2, 100000), device=device)
+        compare(t, 2000, 1, True)
+        compare(t, 2000, 1, False)
+
     def test_topk_arguments(self, device):
         q = torch.randn(10, 2, 10, device=device)
         # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
```
