Skip to content

Commit

Permalink
Refactor interfaces to hide tile_bounds and allow dynamic block_size (#…
Browse files Browse the repository at this point in the history
…129)

* refactor block size to be dynamic, hide tile_bounds from interface, update tests to reflect this change

* block_size->block_width, lint

* lint
  • Loading branch information
kerrj authored Feb 20, 2024
1 parent aa1ff65 commit 10bc1d0
Show file tree
Hide file tree
Showing 16 changed files with 145 additions and 312 deletions.
27 changes: 15 additions & 12 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import torch
import tyro
from gsplat.project_gaussians import _ProjectGaussians
from gsplat.rasterize import _RasterizeGaussians
from gsplat.project_gaussians import project_gaussians
from gsplat.rasterize import rasterize_gaussians
from PIL import Image
from torch import Tensor, optim

Expand All @@ -25,17 +25,10 @@ def __init__(
self.gt_image = gt_image.to(device=self.device)
self.num_points = num_points

BLOCK_X, BLOCK_Y = 16, 16
fov_x = math.pi / 2.0
self.H, self.W = gt_image.shape[0], gt_image.shape[1]
self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
self.tile_bounds = (
(self.W + BLOCK_X - 1) // BLOCK_X,
(self.H + BLOCK_Y - 1) // BLOCK_Y,
1,
)
self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)
self.block = torch.tensor([BLOCK_X, BLOCK_Y, 1], device=self.device)

self._init_gaussians()

Expand Down Expand Up @@ -87,9 +80,18 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
mse_loss = torch.nn.MSELoss()
frames = []
times = [0] * 3 # project, rasterize, backward
B_SIZE = 16
for iter in range(iterations):
start = time.time()
xys, depths, radii, conics, compensation, num_tiles_hit, cov3d = _ProjectGaussians.apply(
(
xys,
depths,
radii,
conics,
compensation,
num_tiles_hit,
cov3d,
) = project_gaussians(
self.means,
self.scales,
1,
Expand All @@ -102,12 +104,12 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
self.H / 2,
self.H,
self.W,
self.tile_bounds,
B_SIZE,
)
torch.cuda.synchronize()
times[0] += time.time() - start
start = time.time()
out_img = _RasterizeGaussians.apply(
out_img = rasterize_gaussians(
xys,
depths,
radii,
Expand All @@ -117,6 +119,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
torch.sigmoid(self.opacities),
self.H,
self.W,
B_SIZE,
self.background,
)
torch.cuda.synchronize()
Expand Down
223 changes: 0 additions & 223 deletions examples/test_rasterize.py

This file was deleted.

20 changes: 14 additions & 6 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pure PyTorch implementations of various functions"""

import struct

import torch
Expand Down Expand Up @@ -233,9 +234,9 @@ def clip_near_plane(p, viewmat, clip_thresh=0.01):
return p_view, p_view[..., 2] < clip_thresh


def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16):
def get_tile_bbox(pix_center, pix_radius, tile_bounds, block_width):
tile_size = torch.tensor(
[BLOCK_X, BLOCK_Y], dtype=torch.float32, device=pix_center.device
[block_width, block_width], dtype=torch.float32, device=pix_center.device
)
tile_center = pix_center / tile_size
tile_radius = pix_radius[..., None] / tile_size
Expand Down Expand Up @@ -268,9 +269,14 @@ def project_gaussians_forward(
fullmat,
intrins,
img_size,
tile_bounds,
block_width,
clip_thresh=0.01,
):
tile_bounds = (
(img_size[0] + block_width - 1) // block_width,
(img_size[1] + block_width - 1) // block_width,
1,
)
fx, fy, cx, cy = intrins
tan_fovx = 0.5 * img_size[0] / fx
tan_fovy = 0.5 * img_size[1] / fy
Expand All @@ -281,7 +287,7 @@ def project_gaussians_forward(
)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
xys = project_pix(fullmat, means3d, img_size, (cx, cy))
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds)
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds, block_width)
tile_area = (tile_max[..., 0] - tile_min[..., 0]) * (
tile_max[..., 1] - tile_min[..., 1]
)
Expand Down Expand Up @@ -318,7 +324,7 @@ def project_gaussians_forward(


def map_gaussian_to_intersects(
num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
num_points, xys, depths, radii, cum_tiles_hit, tile_bounds, block_width
):
num_intersects = cum_tiles_hit[-1]
isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device)
Expand All @@ -328,7 +334,9 @@ def map_gaussian_to_intersects(
if radii[idx] <= 0:
break

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)
tile_min, tile_max = get_tile_bbox(
xys[idx], radii[idx], tile_bounds, block_width
)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item()

Expand Down
15 changes: 8 additions & 7 deletions gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ __global__ void rasterize_backward_kernel(
// first collect gaussians between range.x and range.y in batches
// which gaussians to look through in this tile
const int2 range = tile_bins[tile_id];
const int num_batches = (range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_size = block.size();
const int num_batches = (range.y - range.x + block_size - 1) / block_size;

__shared__ int32_t id_batch[BLOCK_SIZE];
__shared__ float3 xy_opacity_batch[BLOCK_SIZE];
__shared__ float3 conic_batch[BLOCK_SIZE];
__shared__ float3 rgbs_batch[BLOCK_SIZE];
__shared__ int32_t id_batch[MAX_BLOCK_SIZE];
__shared__ float3 xy_opacity_batch[MAX_BLOCK_SIZE];
__shared__ float3 conic_batch[MAX_BLOCK_SIZE];
__shared__ float3 rgbs_batch[MAX_BLOCK_SIZE];

// df/d_out for this pixel
const float3 v_out = v_output[pix_id];
Expand All @@ -206,8 +207,8 @@ __global__ void rasterize_backward_kernel(
// 0 index will be furthest back in batch
// index of gaussian to load
// batch end is the index of the last gaussian in the batch
const int batch_end = range.y - 1 - BLOCK_SIZE * b;
int batch_size = min(BLOCK_SIZE, batch_end + 1 - range.x);
const int batch_end = range.y - 1 - block_size * b;
int batch_size = min(block_size, batch_end + 1 - range.x);
const int idx = batch_end - tr;
if (idx >= range.x) {
int32_t g_id = gaussian_ids_sorted[idx];
Expand Down
Loading

0 comments on commit 10bc1d0

Please sign in to comment.