Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions adet/layers/csrc/ml_nms/ml_nms.cu
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>

//#include <THC/THC.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/ceil_div.h>
#include <ATen/cuda/ThrustAllocator.h>
#include <vector>
#include <iostream>

Expand Down Expand Up @@ -65,7 +66,7 @@ __global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
t |= 1ULL << i;
}
}
const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
Expand All @@ -82,28 +83,28 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {

int boxes_num = boxes.size(0);

const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
// THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

unsigned long long* mask_dev = NULL;
//THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
// boxes_num * col_blocks * sizeof(unsigned long long)));

mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
ml_nms_kernel<<<blocks, threads>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);

std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpy(&mask_host[0],
c10::cuda::CUDACachingAllocator::raw_alloc(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost));
Expand All @@ -128,7 +129,7 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
}
}

THCudaFree(state, mask_dev);
c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
// TODO improve this part
return std::get<0>(order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
Expand Down