[FEA]: Possible improvement of deviceHistogram::HistogramEven #4219

Open
mfranzreb opened this issue Mar 20, 2025 · 0 comments
Labels
feature request New feature or request.

Is this a duplicate?

Area

CUB

Is your feature request related to a problem? Please describe.

For my application I had to write a custom histogram kernel, since I could not use CUB's directly (I need to do other work besides the histogramming). Out of curiosity, I benchmarked it against the CUB implementation, and it seems to be considerably faster. Here is a plot of the results on an RTX 2080 Ti:

[Image: benchmark results on an RTX 2080 Ti, CUB HistogramEven vs. custom kernel]

My use case is computing the histogram of a text over a fixed alphabet. For the tests I used uniformly random texts drawn from the alphabet [0, alphabet_size). Also, for an alphabet size of 64'000 the CUB histogram entry for character 33'555 is 0, which should not be the case, and for an alphabet size of 100'000 the results differ for character 0. I have tested my implementation thoroughly with many random inputs, so I am fairly confident that the CUB result is wrong.
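
In case it helps with triage, here is a minimal sketch that isolates the check for the 64'000 case. I have not run this exact snippet; it just mirrors the HistogramEven call from the full benchmark below, with every sample set to 33'555 so that the corresponding bin should equal the number of samples:

#include <cub/device/device_histogram.cuh>

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  size_t const alphabet_size = 64'000;
  size_t const data_size = 1'000'000;
  // Every sample is 33'555, so the histogram entry for bin 33'555 should be data_size.
  std::vector<uint16_t> data(data_size, uint16_t(33'555));

  uint16_t* d_data;
  unsigned long long* d_hist;
  cudaMalloc(&d_data, data_size * sizeof(uint16_t));
  cudaMemcpy(d_data, data.data(), data_size * sizeof(uint16_t), cudaMemcpyHostToDevice);
  cudaMalloc(&d_hist, alphabet_size * sizeof(unsigned long long));
  cudaMemset(d_hist, 0, alphabet_size * sizeof(unsigned long long));

  // Same call as in the benchmark: alphabet_size + 1 evenly spaced levels over [0, alphabet_size).
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, d_data, d_hist,
                                      int(alphabet_size + 1), uint16_t(0),
                                      uint16_t(alphabet_size), data_size);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceHistogram::HistogramEven(d_temp_storage, temp_storage_bytes, d_data, d_hist,
                                      int(alphabet_size + 1), uint16_t(0),
                                      uint16_t(alphabet_size), data_size);

  unsigned long long count = 0;
  cudaMemcpy(&count, d_hist + 33'555, sizeof(count), cudaMemcpyDeviceToHost);
  std::printf("bin 33'555: %llu (expected %zu)\n", count, data_size);
  return 0;
}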

Here's the script I used for benchmarking:

#include <omp.h>

#include <cub/device/device_histogram.cuh>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <random>
#include <string>
#include <vector>

typedef unsigned long long cu_size_t;

static uint32_t kMaxTPB = 0;
static cudaDeviceProp prop;

#define WS 32

#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ > 800 && __CUDA_ARCH__ < 900
#define MAX_TPB 768
#define MIN_BPM 2
#elif __CUDA_ARCH__ == 750
#define MAX_TPB 1024
#define MIN_BPM 1
#else
#define MAX_TPB 1024
#define MIN_BPM 2
#endif

#define LB(x, y) __launch_bounds__(x, y)
#else
#define LB(x, y)
#endif

#define gpuErrchk(ans)                    \
  {                                       \
    gpuAssert((ans), __FILE__, __LINE__); \
  }
__host__ __device__ inline void gpuAssert(cudaError_t code, const char* file,
                                          int line, bool abort = true) {
  if (code != cudaSuccess) {
    printf("GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
    if (abort) {
#ifdef __CUDA_ARCH__
      asm("trap;");
#else
      exit(EXIT_FAILURE);
#endif
    }
  }
}

__host__ cudaDeviceProp& getDeviceProperties() {
  assert(prop.totalGlobalMem != 0);
  return prop;
}

__host__ void checkWarpSize(uint32_t const GPU_index) {
  if (prop.totalGlobalMem == 0) {
    gpuErrchk(cudaSetDevice(GPU_index));
    cudaGetDeviceProperties(&prop, GPU_index);
    auto const threads_per_sm = prop.maxThreadsPerMultiProcessor;
    kMaxTPB = prop.maxThreadsPerBlock;
    // find max block size that can still fully load an SM
    while (threads_per_sm % kMaxTPB != 0) {
      kMaxTPB -= WS;
    }
    assert(kMaxTPB > WS);
  }
  if (prop.warpSize != WS) {
    fprintf(stderr, "Warp size must be 32, but is %d\n", prop.warpSize);
    exit(EXIT_FAILURE);
  }
}

template <typename T>
std::vector<T> generateRandomData(std::vector<T> const& alphabet,
                                  size_t const data_size) {
  std::vector<T> data(data_size);
#pragma omp parallel
  {
    // Create a thread-local random number generator
    std::random_device rd;
    std::mt19937 gen(rd() + omp_get_thread_num());  // Add thread number to seed
                                                    // for better randomness
    std::uniform_int_distribution<size_t> dis(0, alphabet.size() - 1);

#pragma omp for
    for (size_t i = 0; i < data_size; i++) {
      data[i] = alphabet[dis(gen)];
    }
  }

  return data;
}

template <typename T>
__host__ T findLargestDivisor(T const n, T const divisor) {
  if (divisor == 0) return 1;
  return divisor - n % divisor;
}

template <typename T, bool UseShmem>
__global__ LB(MAX_TPB, MIN_BPM) void computeGlobalHistogramKernel(
    T* data, size_t const data_size, size_t* counts, T* const alphabet,
    size_t const alphabet_size, uint16_t const hists_per_block) {
  assert(blockDim.x % WS == 0);
  extern __shared__ size_t shared_hist[];
  size_t offset;
  if constexpr (UseShmem) {
    offset = (threadIdx.x % hists_per_block) * alphabet_size;
    for (size_t i = threadIdx.x; i < alphabet_size * hists_per_block;
         i += blockDim.x) {
      shared_hist[i] = 0;
    }
    __syncthreads();
  }

  size_t const total_threads = blockDim.x * gridDim.x;
  size_t const global_t_id = blockIdx.x * blockDim.x + threadIdx.x;
  T char_data;
  for (size_t i = global_t_id; i < data_size; i += total_threads) {
    char_data = data[i];
    if constexpr (UseShmem) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 520
      atomicAdd((cu_size_t*)&shared_hist[offset + char_data], size_t(1));
#else
      atomicAdd_block((cu_size_t*)&shared_hist[offset + char_data], size_t(1));
#endif
    } else {
      atomicAdd((cu_size_t*)&counts[char_data], size_t(1));
    }
  }

  if constexpr (UseShmem) {
    __syncthreads();
    // Reduce shared histograms to first one
    for (size_t i = threadIdx.x; i < alphabet_size; i += blockDim.x) {
      size_t sum = shared_hist[i];
      for (size_t j = 1; j < hists_per_block; ++j) {
        sum += shared_hist[j * alphabet_size + i];
      }
      atomicAdd((cu_size_t*)&counts[i], sum);
    }
  }
}

template <typename T>
__host__ void computeGlobalHistogram(size_t const data_size, T* d_data,
                                     T* d_alphabet, size_t* d_histogram,
                                     size_t const alphabet_size) {
  struct cudaFuncAttributes funcAttrib;
  gpuErrchk(cudaFuncGetAttributes(&funcAttrib,
                                  computeGlobalHistogramKernel<T, true>));

  struct cudaDeviceProp& prop = getDeviceProperties();

  uint32_t max_TPB =
      std::min(kMaxTPB, static_cast<uint32_t>(funcAttrib.maxThreadsPerBlock));

  max_TPB = findLargestDivisor(kMaxTPB, max_TPB);

  auto const max_shmem_per_SM = prop.sharedMemPerMultiprocessor;
  auto const max_threads_per_SM = prop.maxThreadsPerMultiProcessor;
  size_t const hist_size = sizeof(size_t) * alphabet_size;

  // Compute global_histogram and change text to min_alphabet
  size_t const total_threads = std::min(
      (data_size / WS) * WS,
      static_cast<size_t>(max_threads_per_SM * prop.multiProcessorCount));

  auto const threads_per_block = max_TPB;
  auto const num_blocks = total_threads / threads_per_block;

  uint16_t const blocks_per_SM = max_threads_per_SM / threads_per_block;

  // Use as much shared memory per block as possible without reducing occupancy.
  size_t const used_shmem =
      std::min(max_shmem_per_SM / blocks_per_SM, prop.sharedMemPerBlock);

  // Number of local histogram copies that fit in the block's shared memory,
  // capped at one copy per thread.
  uint16_t const hists_per_block =
      std::min(static_cast<size_t>(threads_per_block), used_shmem / hist_size);
  if (hists_per_block > 0) {
    computeGlobalHistogramKernel<T, true>
        <<<num_blocks, threads_per_block, used_shmem>>>(
            d_data, data_size, d_histogram, d_alphabet, alphabet_size,
            hists_per_block);
  } else {
    computeGlobalHistogramKernel<T, false><<<num_blocks, threads_per_block>>>(
        d_data, data_size, d_histogram, d_alphabet, alphabet_size,
        hists_per_block);
  }
}

template <typename T>
void BM_HistComputation(size_t const data_size, size_t const alphabet_size,
                        size_t const num_iters, std::string const& output) {
  auto alphabet = std::vector<T>(alphabet_size);
  std::iota(alphabet.begin(), alphabet.end(), 0ULL);
  auto data = generateRandomData<T>(alphabet, data_size);

  T* d_data;
  T* d_alphabet;
  size_t* d_histogram;
  gpuErrchk(cudaMalloc(&d_data, data_size * sizeof(T)));
  gpuErrchk(cudaMemcpy(d_data, data.data(), data_size * sizeof(T),
                       cudaMemcpyHostToDevice));
  gpuErrchk(cudaMalloc(&d_alphabet, alphabet_size * sizeof(T)));
  gpuErrchk(cudaMemcpy(d_alphabet, alphabet.data(), alphabet_size * sizeof(T),
                       cudaMemcpyHostToDevice));
  gpuErrchk(cudaMalloc(&d_histogram, alphabet_size * sizeof(size_t)));

  cudaEvent_t start, stop;
  gpuErrchk(cudaEventCreate(&start));
  gpuErrchk(cudaEventCreate(&stop));
  std::vector<float> times(num_iters);

  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceHistogram::HistogramEven(
      d_temp_storage, temp_storage_bytes, d_data, (cu_size_t*)d_histogram,
      alphabet_size + 1, T(0), T(alphabet_size), data_size);
  gpuErrchk(cudaMalloc(&d_temp_storage, temp_storage_bytes));

  // Warmup
  for (int i = 0; i < 5; ++i) {
    gpuErrchk(cudaMemset(d_histogram, 0, alphabet_size * sizeof(size_t)));
    cub::DeviceHistogram::HistogramEven(
        d_temp_storage, temp_storage_bytes, d_data, (cu_size_t*)d_histogram,
        alphabet_size + 1, T(0), T(alphabet_size), data_size);
  }

  float median_cub = 0;
  std::vector<size_t> cub_hist(alphabet_size);
  for (size_t i = 0; i < num_iters; ++i) {
    gpuErrchk(cudaMemset(d_histogram, 0, alphabet_size * sizeof(size_t)));
    gpuErrchk(cudaEventRecord(start));
    cub::DeviceHistogram::HistogramEven(
        d_temp_storage, temp_storage_bytes, d_data, (cu_size_t*)d_histogram,
        alphabet_size + 1, T(0), T(alphabet_size), data_size);
    gpuErrchk(cudaEventRecord(stop));
    gpuErrchk(cudaEventSynchronize(stop));
    float milliseconds = 0;
    gpuErrchk(cudaEventElapsedTime(&milliseconds, start, stop));
    times[i] = milliseconds;
  }
  // Get median time
  std::nth_element(times.begin(), times.begin() + times.size() / 2,
                   times.end());
  median_cub = times[times.size() / 2];
  gpuErrchk(cudaFree(d_temp_storage));
  gpuErrchk(cudaMemcpy(cub_hist.data(), d_histogram,
                       alphabet_size * sizeof(size_t), cudaMemcpyDeviceToHost));

  // Warmup
  for (int i = 0; i < 5; ++i) {
    gpuErrchk(cudaMemset(d_histogram, 0, alphabet_size * sizeof(size_t)));
    computeGlobalHistogram(data_size, d_data, d_alphabet, d_histogram,
                           alphabet_size);
  }

  float median_custom = 0;
  std::vector<size_t> custom_hist(alphabet_size);
  for (size_t i = 0; i < num_iters; ++i) {
    gpuErrchk(cudaMemset(d_histogram, 0, alphabet_size * sizeof(size_t)));
    gpuErrchk(cudaEventRecord(start));
    computeGlobalHistogram(data_size, d_data, d_alphabet, d_histogram,
                           alphabet_size);
    gpuErrchk(cudaEventRecord(stop));
    gpuErrchk(cudaEventSynchronize(stop));
    float milliseconds = 0;
    gpuErrchk(cudaEventElapsedTime(&milliseconds, start, stop));
    times[i] = milliseconds;
  }
  // Get median time
  std::nth_element(times.begin(), times.begin() + times.size() / 2,
                   times.end());
  median_custom = times[times.size() / 2];
  std::ofstream out(output, std::ios::app);
  out << data_size << "," << alphabet_size << "," << median_cub << ","
      << median_custom << std::endl;
  out.close();

  gpuErrchk(cudaMemcpy(custom_hist.data(), d_histogram,
                       alphabet_size * sizeof(size_t), cudaMemcpyDeviceToHost));
  for (size_t i = 0; i < alphabet_size; ++i) {
    if (cub_hist[i] != custom_hist[i]) {
      std::cerr << "Mismatch at index " << i << " cub_hist: " << cub_hist[i]
                << " custom_hist: " << custom_hist[i] << std::endl;
      break;
    }
  }
  gpuErrchk(cudaFree(d_data));
  gpuErrchk(cudaFree(d_alphabet));
  gpuErrchk(cudaFree(d_histogram));
}

int main(int argc, char** argv) {
  if (argc != 4) {
    std::cerr << "Usage: " << argv[0] << " <GPU_index> <num_iters> <output_csv>"
              << std::endl;
    return EXIT_FAILURE;
  }

  uint32_t const GPU_index = std::stoi(argv[1]);
  uint32_t const num_iters = std::stoi(argv[2]);
  std::string const output = argv[3];

  std::ofstream out(output);
  out << "data_size,alphabet_size,median_cub(ms),median_custom(ms)"
      << std::endl;
  out.close();

  checkWarpSize(GPU_index);
  std::vector<size_t> const data_sizes = {500'000'000, 1'000'000'000,
                                          2'000'000'000};
  std::vector<size_t> const alphabet_sizes = {
      4,    8,    16,   32,   64,    128,   256,   512,
      1024, 2048, 4096, 8192, 16384, 32768, 64000, 100000};

  for (auto const data_size : data_sizes) {
    for (auto const alphabet_size : alphabet_sizes) {
      if (alphabet_size >= std::numeric_limits<uint16_t>::max()) {
        BM_HistComputation<uint32_t>(data_size, alphabet_size, num_iters,
                                     output);
      } else if (alphabet_size >= std::numeric_limits<uint8_t>::max()) {
        BM_HistComputation<uint16_t>(data_size, alphabet_size, num_iters,
                                     output);
      } else {
        BM_HistComputation<uint8_t>(data_size, alphabet_size, num_iters,
                                    output);
      }
    }
  }
  return EXIT_SUCCESS;
}

I compiled with nvcc -O3 -Xcompiler -fopenmp -arch=sm_75 hist_comp.cu -o hist_comp using CUDA 12.0, and ran it for 20 iterations.

My implementation packs as many copies of the histogram as possible into shared memory and assigns a local copy to each thread in the block in a round-robin fashion, to minimize how many threads within a warp share the same local histogram. If a single histogram does not fit in shared memory, the kernel falls back to atomic updates directly in global memory.
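
For example, on the 2080 Ti (assuming the usual Turing limits of 64 KB of shared memory per SM, 48 KB per block, and 1024 resident threads per SM), an alphabet of 256 symbols with 8-byte counters needs 2 KB per local histogram, so each block keeps 24 copies in shared memory and thread i updates copy i % 24. Once a single histogram exceeds 48 KB (more than 6'144 symbols at 8 bytes each), the kernel switches to the global-memory path.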

Here is also the CSV of the results, with more data sizes: 2080Ti.csv.

Is your implementation optimized for other specific use cases? I'm just leaving this here in case you're interested in improving the performance for this one.

Describe the solution you'd like

Improve the performance of HistogramEven for creating histograms of texts.

Describe alternatives you've considered

No response

Additional context

No response

mfranzreb added the feature request label Mar 20, 2025
github-project-automation bot moved this to Todo in CCCL Mar 20, 2025