Skip to content

Commit

Permalink
updating project gaussians torch impl and numerical tests (#118)
Browse files Browse the repository at this point in the history
* updating project gaussians torch impl and numerical tests

* zero out gaussians not in view when projecting

* update tests to also check masked regions

* switch in-place with torch.where

* rm torch compile

* lint

---------

Co-authored-by: Justin Kerr <[email protected]>
  • Loading branch information
vye16 and kerrj authored Feb 8, 2024
1 parent 210ed53 commit 9fffa1a
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 109 deletions.
141 changes: 73 additions & 68 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,35 +114,21 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
def quat_to_rotmat(quat: Tensor) -> Tensor:
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1)
return torch.stack(
mat = torch.stack(
[
torch.stack(
[
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
],
dim=-1,
),
torch.stack(
[
2 * (x * y + w * z),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
],
dim=-1,
),
torch.stack(
[
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x**2 + y**2),
],
dim=-1,
),
1 - 2 * (y ** 2 + z ** 2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x ** 2 + z ** 2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x ** 2 + y ** 2),
],
dim=-2,
dim=-1,
)
return mat.reshape(quat.shape[:-1] + (3, 3))


def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor:
Expand All @@ -169,35 +155,40 @@ def project_cov3d_ewa(
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
W = viewmat[..., :3, :3] # (..., 3, 3)
p = viewmat[..., :3, 3] # (..., 3)
t = torch.matmul(W, mean3d[..., None])[..., 0] + p # (..., 3)
t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3)

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz ** 2 # (...,)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)
x_clamp = t[..., 2] * torch.clamp(t[..., 0] * rz, min=-lim_x, max=lim_x)
y_clamp = t[..., 2] * torch.clamp(t[..., 1] * rz, min=-lim_y, max=lim_y)
t = torch.stack([x_clamp, y_clamp, t[..., 2]], dim=-1)

min_lim_x = t[..., 2] * torch.min(lim_x, torch.max(-lim_x, t[..., 0] / t[..., 2]))
min_lim_y = t[..., 2] * torch.min(lim_y, torch.max(-lim_y, t[..., 1] / t[..., 2]))
t = torch.cat([min_lim_x[..., None], min_lim_y[..., None], t[..., 2, None]], dim=-1)

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz**2 # (...,)
O = torch.zeros_like(rz)
J = torch.stack(
[
torch.stack([fx * rz, torch.zeros_like(rz), -fx * t[..., 0] * rz2], dim=-1),
torch.stack([torch.zeros_like(rz), fy * rz, -fy * t[..., 1] * rz2], dim=-1),
],
dim=-2,
) # (..., 2, 3)
T = J @ W # (..., 2, 3)
cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2)
[fx * rz, O, -fx * t[..., 0] * rz2, O, fy * rz, -fy * t[..., 1] * rz2],
dim=-1,
).reshape(*rz.shape, 2, 3)
T = torch.matmul(J, W) # (..., 2, 3)
cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2))
# add a little blur along axes and (TODO save upper triangular elements)
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d
return cov2d[..., :2, :2]


def compute_cov2d_bounds(cov2d: Tensor, eps=1e-6):
det = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] ** 2
det = torch.clamp(det, min=eps)
def compute_cov2d_bounds(cov2d_mat: Tensor):
"""
param: cov2d matrix (*, 2, 2)
returns: conic parameters (*, 3)
"""
det_all = cov2d_mat[..., 0, 0] * cov2d_mat[..., 1, 1] - cov2d_mat[..., 0, 1] ** 2
valid = det_all != 0
# det = torch.clamp(det, min=eps)
det = det_all[valid]
cov2d = cov2d_mat[valid]
conic = torch.stack(
[
cov2d[..., 1, 1] / det,
Expand All @@ -207,30 +198,34 @@ def compute_cov2d_bounds(cov2d: Tensor, eps=1e-6):
dim=-1,
) # (..., 3)
b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,)
v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,)
return conic, radius, det > eps
radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device)
conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device)
radius_all[valid] = radius
conic_all[valid] = conic
return conic_all, radius_all, valid


def ndc2pix(x, W):
return 0.5 * ((x + 1.0) * W - 1.0)
def ndc2pix(x, W, c):
return 0.5 * W * x - 0.5 + c


def project_pix(mat, p, img_size, eps=1e-6):
def project_pix(fullmat, p, img_size, center, eps=1e-6):
p_hom = F.pad(p, (0, 1), value=1.0)
p_hom = torch.einsum("...ij,...j->...i", mat, p_hom)
rw = 1.0 / torch.clamp(p_hom[..., 3], min=eps)
p_hom = torch.einsum("...ij,...j->...i", fullmat, p_hom)
rw = 1.0 / (p_hom[..., 3] + eps)
p_proj = p_hom[..., :3] * rw[..., None]
u = ndc2pix(p_proj[..., 0], img_size[0])
v = ndc2pix(p_proj[..., 1], img_size[1])
u = ndc2pix(p_proj[..., 0], img_size[0], center[0])
v = ndc2pix(p_proj[..., 1], img_size[1], center[1])
return torch.stack([u, v], dim=-1)


def clip_near_plane(p, viewmat, clip_thresh=0.01):
R = viewmat[..., :3, :3]
T = viewmat[..., :3, 3]
p_view = torch.matmul(R, p[..., None])[..., 0] + T
R = viewmat[:3, :3]
T = viewmat[:3, 3]
p_view = torch.einsum("ij,nj->ni", R, p) + T[None]
return p_view, p_view[..., 2] < clip_thresh


Expand Down Expand Up @@ -266,21 +261,21 @@ def project_gaussians_forward(
glob_scale,
quats,
viewmat,
projmat,
fx,
fy,
fullmat,
intrins,
img_size,
tile_bounds,
clip_thresh=0.01,
):
tan_fovx = 0.5 * img_size[1] / fx
tan_fovy = 0.5 * img_size[0] / fy
fx, fy, cx, cy = intrins
tan_fovx = 0.5 * img_size[0] / fx
tan_fovy = 0.5 * img_size[1] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
center = project_pix(projmat, means3d, img_size)
tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds)
xys = project_pix(fullmat, means3d, img_size, (cx, cy))
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds)
tile_area = (tile_max[..., 0] - tile_min[..., 0]) * (
tile_max[..., 1] - tile_min[..., 1]
)
Expand All @@ -289,10 +284,20 @@ def project_gaussians_forward(
num_tiles_hit = tile_area
depths = p_view[..., 2]
radii = radius.to(torch.int32)
xys = center
conics = conic

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask
radii = torch.where(~mask, 0, radii)
conic = torch.where(~mask[..., None], 0, conic)
xys = torch.where(~mask[..., None], 0, xys)
cov3d = torch.where(~mask[..., None, None], 0, cov3d)
cov2d = torch.where(~mask[..., None, None], 0, cov2d)
num_tiles_hit = torch.where(~mask, 0, num_tiles_hit)
depths = torch.where(~mask, 0, depths)

i, j = torch.triu_indices(3, 3)
cov3d_triu = cov3d[..., i, j]
i, j = torch.triu_indices(2, 2)
cov2d_triu = cov2d[..., i, j]
return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask


def map_gaussian_to_intersects(
Expand Down
Loading

0 comments on commit 9fffa1a

Please sign in to comment.