Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rasterize indices only so that the alpha composition can be done in python with more flexibility. #120

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import warnings
from typing import Any

import torch

from .project_gaussians import project_gaussians
from .rasterize import rasterize_gaussians
from .rasterize import rasterize_gaussians, rasterize_indices
from .sh import spherical_harmonics
from .utils import (
map_gaussian_to_intersects,
bin_and_sort_gaussians,
compute_cumulative_intersects,
compute_cov2d_bounds,
compute_cumulative_intersects,
get_tile_bin_edges,
map_gaussian_to_intersects,
)
from .sh import spherical_harmonics
from .version import __version__
import warnings


__all__ = [
"__version__",
"project_gaussians",
"rasterize_gaussians",
"spherical_harmonics",
# utils
"rasterize_indices",
"bin_and_sort_gaussians",
"compute_cumulative_intersects",
"compute_cov2d_bounds",
Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ def call_cuda(*args, **kwargs):
get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
rasterize_forward = _make_lazy_cuda_func("rasterize_forward")
nd_rasterize_forward = _make_lazy_cuda_func("nd_rasterize_forward")
rasterize_indices = _make_lazy_cuda_func("rasterize_indices")
82 changes: 82 additions & 0 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,88 @@ nd_rasterize_forward_tensor(
return std::make_tuple(out_img, final_Ts, final_idx);
}

// Rasterize gaussians but only record *which* gaussians contribute to *which*
// pixels, instead of alpha-compositing colors on the GPU. This lets the alpha
// composition be done in python with more flexibility.
//
// The `rasterize_indices` kernel is launched twice:
//   1st pass: count, per pixel, how many gaussians contribute before the
//             early-stopping transmittance cutoff (chunk_starts == nullptr).
//   2nd pass: with per-pixel offsets derived from those counts, write the
//             flattened (gaussian_id, pixel_id) pairs.
//
// Returns a tuple of int32 tensors on the same device as `xys`:
//   out_gaussian_ids [n_elems], out_pixel_ids [n_elems], chunk_cnts [H * W].
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
rasterize_indices_tensor(
    const std::tuple<int, int, int> tile_bounds,
    const std::tuple<int, int, int> block,
    const std::tuple<int, int, int> img_size,
    const torch::Tensor &gaussian_ids_sorted,
    const torch::Tensor &tile_bins,
    const torch::Tensor &xys,
    const torch::Tensor &conics,
    const torch::Tensor &opacities
) {
    CHECK_INPUT(gaussian_ids_sorted);
    CHECK_INPUT(tile_bins);
    CHECK_INPUT(xys);
    CHECK_INPUT(conics);
    CHECK_INPUT(opacities);

    dim3 tile_bounds_dim3;
    tile_bounds_dim3.x = std::get<0>(tile_bounds);
    tile_bounds_dim3.y = std::get<1>(tile_bounds);
    tile_bounds_dim3.z = std::get<2>(tile_bounds);

    dim3 block_dim3;
    block_dim3.x = std::get<0>(block);
    block_dim3.y = std::get<1>(block);
    block_dim3.z = std::get<2>(block);

    dim3 img_size_dim3;
    img_size_dim3.x = std::get<0>(img_size);
    img_size_dim3.y = std::get<1>(img_size);
    img_size_dim3.z = std::get<2>(img_size);

    const int img_width = img_size_dim3.x;
    const int img_height = img_size_dim3.y;

    // First pass: count the number of gaussians contributing to each pixel.
    // Note: early stopping is applied.
    // Zero-initialize (not torch::empty) so pixels that no launched thread
    // writes -- e.g. if the caller's grid does not cover the full image --
    // hold a count of 0 instead of uninitialized garbage, which would corrupt
    // the cumsum below.
    torch::Tensor chunk_cnts = torch::zeros(
        {img_height * img_width}, xys.options().dtype(torch::kInt32)
    );

    rasterize_indices<<<tile_bounds_dim3, block_dim3>>>(
        tile_bounds_dim3,
        img_size_dim3,
        gaussian_ids_sorted.contiguous().data_ptr<int>(),
        (int2 *)tile_bins.contiguous().data_ptr<int>(),
        (float2 *)xys.contiguous().data_ptr<float>(),
        (float3 *)conics.contiguous().data_ptr<float>(),
        opacities.contiguous().data_ptr<float>(),
        nullptr, // chunk_starts
        chunk_cnts.contiguous().data_ptr<int>(),
        nullptr, // out_gaussian_ids
        nullptr  // out_pixel_ids
    );

    // Second pass: allocate memory and write out the gaussian and pixel ids.
    torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
    // Total number of (gaussian, pixel) pairs. Guard the degenerate
    // zero-pixel image, where cumsum[-1] would be an invalid index.
    int64_t n_elems = cumsum.numel() > 0 ? cumsum[-1].item<int64_t>() : 0;
    // Exclusive prefix sum: start offset of each pixel's output chunk.
    torch::Tensor chunk_starts = cumsum - chunk_cnts;

    torch::Tensor out_gaussian_ids = torch::empty(
        {n_elems}, xys.options().dtype(torch::kInt32)
    );
    torch::Tensor out_pixel_ids = torch::empty(
        {n_elems}, xys.options().dtype(torch::kInt32)
    );
    rasterize_indices<<<tile_bounds_dim3, block_dim3>>>(
        tile_bounds_dim3,
        img_size_dim3,
        gaussian_ids_sorted.contiguous().data_ptr<int>(),
        (int2 *)tile_bins.contiguous().data_ptr<int>(),
        (float2 *)xys.contiguous().data_ptr<float>(),
        (float3 *)conics.contiguous().data_ptr<float>(),
        opacities.contiguous().data_ptr<float>(),
        chunk_starts.contiguous().data_ptr<int>(),
        nullptr, // chunk_cnts
        out_gaussian_ids.contiguous().data_ptr<int>(),
        out_pixel_ids.contiguous().data_ptr<int>()
    );

    return std::make_tuple(out_gaussian_ids, out_pixel_ids, chunk_cnts);
}


std::
Expand Down
12 changes: 12 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ std::tuple<
const torch::Tensor &background
);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
rasterize_indices_tensor(
const std::tuple<int, int, int> tile_bounds,
const std::tuple<int, int, int> block,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we update this interface so that we hide the block size from the user as in #129 ?

const std::tuple<int, int, int> img_size,
const torch::Tensor &gaussian_ids_sorted,
const torch::Tensor &tile_bins,
const torch::Tensor &xys,
const torch::Tensor &conics,
const torch::Tensor &opacities
);

std::tuple<
torch::Tensor,
torch::Tensor,
Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nd_rasterize_backward", &nd_rasterize_backward_tensor);
m.def("rasterize_forward", &rasterize_forward_tensor);
m.def("rasterize_backward", &rasterize_backward_tensor);
m.def("rasterize_indices", &rasterize_indices_tensor);
m.def("project_gaussians_forward", &project_gaussians_forward_tensor);
m.def("project_gaussians_backward", &project_gaussians_backward_tensor);
m.def("compute_sh_forward", &compute_sh_forward_tensor);
Expand Down
128 changes: 128 additions & 0 deletions gsplat/cuda/csrc/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,134 @@ __global__ void rasterize_forward(
}
}


// Identify which gaussians contribute to which pixels, without compositing.
//
// Launched twice from `rasterize_indices_tensor`:
//   * First pass (chunk_starts == nullptr): each in-bounds pixel counts its
//     contributing gaussians into chunk_cnts; out_* buffers are unused.
//   * Second pass (chunk_starts != nullptr): each in-bounds pixel writes its
//     (gaussian_id, pixel_id) pairs into out_gaussian_ids / out_pixel_ids
//     starting at offset chunk_starts[pix_id]; chunk_cnts is unused.
// The same alpha / transmittance early-stopping rules are applied in both
// passes, so the counts and the writes agree exactly.
__global__ void rasterize_indices(
    const dim3 tile_bounds,
    const dim3 img_size,
    const int32_t* __restrict__ gaussian_ids_sorted,
    const int2* __restrict__ tile_bins,
    const float2* __restrict__ xys,
    const float3* __restrict__ conics,
    const float* __restrict__ opacities,
    const int* __restrict__ chunk_starts,
    int* __restrict__ chunk_cnts,
    int* __restrict__ out_gaussian_ids,
    int* __restrict__ out_pixel_ids
) {
    // each thread draws one pixel, but also timeshares caching gaussians in a
    // shared tile
    auto block = cg::this_thread_block();
    int32_t tile_id =
        block.group_index().y * tile_bounds.x + block.group_index().x;
    unsigned i =
        block.group_index().y * block.group_dim().y + block.thread_index().y;
    unsigned j =
        block.group_index().x * block.group_dim().x + block.thread_index().x;

    float px = (float)j;
    float py = (float)i;
    int32_t pix_id = i * img_size.x + j;

    // Threads outside the image are kept alive so they can still help load
    // gaussians into shared memory, but they must never touch the per-pixel
    // global buffers: for them pix_id is out of range, so indexing
    // chunk_starts / chunk_cnts with it would be an out-of-bounds access.
    bool inside = (i < img_size.y && j < img_size.x);
    bool done = !inside;

    bool first_pass = true;
    int base = 0; // start offset of this pixel's output chunk (2nd pass only)
    if (chunk_starts != nullptr) {
        first_pass = false;
        if (inside) {
            base = chunk_starts[pix_id];
        }
    }

    // have all threads in tile process the same gaussians in batches
    // first collect gaussians between range.x and range.y in batches
    // which gaussians to look through in this tile
    int2 range = tile_bins[tile_id];
    int num_batches = (range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE;

    __shared__ int32_t id_batch[BLOCK_SIZE];
    __shared__ float3 xy_opacity_batch[BLOCK_SIZE];
    __shared__ float3 conic_batch[BLOCK_SIZE];

    // current visibility left to render
    float T = 1.f;

    // collect and process batches of gaussians
    // each thread loads one gaussian at a time before rasterizing its
    // designated pixel
    int tr = block.thread_rank();
    int cnt = 0; // gaussians recorded/counted for this pixel so far
    for (int b = 0; b < num_batches; ++b) {
        // resync all threads before beginning next batch
        // end early if entire tile is done
        if (__syncthreads_count(done) >= BLOCK_SIZE) {
            break;
        }

        // each thread fetch 1 gaussian from front to back
        // index of gaussian to load
        int batch_start = range.x + BLOCK_SIZE * b;
        int idx = batch_start + tr;
        if (idx < range.y) {
            int32_t g_id = gaussian_ids_sorted[idx];
            id_batch[tr] = g_id;
            const float2 xy = xys[g_id];
            const float opac = opacities[g_id];
            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
            conic_batch[tr] = conics[g_id];
        }

        // wait for other threads to collect the gaussians in batch
        block.sync();

        // process gaussians in the current batch for this pixel
        int batch_size = min(BLOCK_SIZE, range.y - batch_start);
        for (int t = 0; (t < batch_size) && !done; ++t) {
            const float3 conic = conic_batch[t];
            const float3 xy_opac = xy_opacity_batch[t];
            const float opac = xy_opac.z;
            const float2 delta = {xy_opac.x - px, xy_opac.y - py};
            const float sigma = 0.5f * (conic.x * delta.x * delta.x +
                                        conic.z * delta.y * delta.y) +
                                conic.y * delta.x * delta.y;
            const float alpha = min(0.999f, opac * __expf(-sigma));
            // skip gaussians with invalid sigma or negligible contribution
            if (sigma < 0.f || alpha < 1.f / 255.f) {
                continue;
            }

            const float next_T = T * (1.f - alpha);
            if (next_T <= 1e-4f) { // this pixel is done
                // the gaussian that drives T below the cutoff is itself not
                // recorded -- consistently skipped in both passes, so counts
                // and writes stay in agreement
                done = true;
                break;
            }

            if (first_pass) {
                // First pass of this function we count the number of gaussians
                // that contribute to each pixel
                cnt += 1;
            } else {
                // Second pass we write out the gaussian ids and pixel ids
                int32_t g = id_batch[t];
                out_gaussian_ids[base + cnt] = g;
                out_pixel_ids[base + cnt] = pix_id;
                cnt += 1;
            }

            T = next_T;
        }
    }

    // only in-bounds threads own a slot in chunk_cnts (see note above)
    if (first_pass && inside) {
        chunk_cnts[pix_id] = cnt;
    }
}


// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
const float3& __restrict__ mean3d,
Expand Down
14 changes: 14 additions & 0 deletions gsplat/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,18 @@ __global__ void nd_rasterize_forward(
int* __restrict__ final_index,
float* __restrict__ out_img,
const float* __restrict__ background
);

// Two-pass kernel recording which gaussians contribute to which pixels
// (no color compositing; see the definition in forward.cu).
//   * chunk_starts == nullptr: fill per-pixel contribution counts into
//     chunk_cnts (out_gaussian_ids / out_pixel_ids unused).
//   * chunk_starts != nullptr: write flattened (gaussian id, pixel id) pairs
//     into out_gaussian_ids / out_pixel_ids at the per-pixel offsets in
//     chunk_starts (chunk_cnts unused).
__global__ void rasterize_indices(
    const dim3 tile_bounds,
    const dim3 img_size,
    const int32_t* __restrict__ gaussian_ids_sorted,
    const int2* __restrict__ tile_bins,
    const float2* __restrict__ xys,
    const float3* __restrict__ conics,
    const float* __restrict__ opacities,
    const int* __restrict__ chunk_starts,
    int* __restrict__ chunk_cnts,
    int* __restrict__ out_gaussian_ids,
    int* __restrict__ out_pixel_ids
);
Loading
Loading