diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 513116a15..8fe5951c6 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -114,35 +114,21 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): def quat_to_rotmat(quat: Tensor) -> Tensor: assert quat.shape[-1] == 4, quat.shape w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1) - return torch.stack( + mat = torch.stack( [ - torch.stack( - [ - 1 - 2 * (y**2 + z**2), - 2 * (x * y - w * z), - 2 * (x * z + w * y), - ], - dim=-1, - ), - torch.stack( - [ - 2 * (x * y + w * z), - 1 - 2 * (x**2 + z**2), - 2 * (y * z - w * x), - ], - dim=-1, - ), - torch.stack( - [ - 2 * (x * z - w * y), - 2 * (y * z + w * x), - 1 - 2 * (x**2 + y**2), - ], - dim=-1, - ), + 1 - 2 * (y ** 2 + z ** 2), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x ** 2 + z ** 2), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x ** 2 + y ** 2), ], - dim=-2, + dim=-1, ) + return mat.reshape(quat.shape[:-1] + (3, 3)) def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor: @@ -169,35 +155,40 @@ def project_cov3d_ewa( assert viewmat.shape[-2:] == (4, 4), viewmat.shape W = viewmat[..., :3, :3] # (..., 3, 3) p = viewmat[..., :3, 3] # (..., 3) - t = torch.matmul(W, mean3d[..., None])[..., 0] + p # (..., 3) + t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3) + + rz = 1.0 / t[..., 2] # (...,) + rz2 = rz ** 2 # (...,) lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) + x_clamp = t[..., 2] * torch.clamp(t[..., 0] * rz, min=-lim_x, max=lim_x) + y_clamp = t[..., 2] * torch.clamp(t[..., 1] * rz, min=-lim_y, max=lim_y) + t = torch.stack([x_clamp, y_clamp, t[..., 2]], dim=-1) - min_lim_x = t[..., 2] * torch.min(lim_x, torch.max(-lim_x, t[..., 0] / t[..., 2])) - min_lim_y = t[..., 2] * torch.min(lim_y, torch.max(-lim_y, t[..., 1] / t[..., 2])) - t = torch.cat([min_lim_x[..., None], min_lim_y[..., None], t[..., 2, None]], dim=-1) - - rz = 1.0 / t[..., 2] # (...,) - rz2 = rz**2 # (...,) + O = torch.zeros_like(rz) J = torch.stack( - [ - torch.stack([fx * rz, torch.zeros_like(rz), -fx * t[..., 0] * rz2], dim=-1), - torch.stack([torch.zeros_like(rz), fy * rz, -fy * t[..., 1] * rz2], dim=-1), - ], - dim=-2, - ) # (..., 2, 3) - T = J @ W # (..., 2, 3) - cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2) + [fx * rz, O, -fx * t[..., 0] * rz2, O, fy * rz, -fy * t[..., 1] * rz2], + dim=-1, + ).reshape(*rz.shape, 2, 3) + T = torch.matmul(J, W) # (..., 2, 3) + cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2)) # add a little blur along axes and (TODO save upper triangular elements) cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3 cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3 - return cov2d + return cov2d[..., :2, :2] -def compute_cov2d_bounds(cov2d: Tensor, eps=1e-6): - det = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] ** 2 - det = torch.clamp(det, min=eps) +def compute_cov2d_bounds(cov2d_mat: Tensor): + """ + param: cov2d matrix (*, 2, 2) + returns: conic parameters (*, 3) + """ + det_all = cov2d_mat[..., 0, 0] * cov2d_mat[..., 1, 1] - cov2d_mat[..., 0, 1] ** 2 + valid = det_all != 0 + # det = torch.clamp(det, min=eps) + det = det_all[valid] + cov2d = cov2d_mat[valid] conic = torch.stack( [ cov2d[..., 1, 1] / det, @@ -207,30 +198,34 @@ def compute_cov2d_bounds(cov2d: Tensor, eps=1e-6): dim=-1, ) # (..., 3) b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,) - v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) - v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) + v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) + v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,) - return conic, radius, det > eps + radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device) + conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device) + radius_all[valid] = radius + conic_all[valid] = conic + return conic_all, radius_all, valid -def ndc2pix(x, W): - return 0.5 * ((x + 1.0) * W - 1.0) +def ndc2pix(x, W, c): + return 0.5 * W * x - 0.5 + c -def project_pix(mat, p, img_size, eps=1e-6): +def project_pix(fullmat, p, img_size, center, eps=1e-6): p_hom = F.pad(p, (0, 1), value=1.0) - p_hom = torch.einsum("...ij,...j->...i", mat, p_hom) - rw = 1.0 / torch.clamp(p_hom[..., 3], min=eps) + p_hom = torch.einsum("...ij,...j->...i", fullmat, p_hom) + rw = 1.0 / (p_hom[..., 3] + eps) p_proj = p_hom[..., :3] * rw[..., None] - u = ndc2pix(p_proj[..., 0], img_size[0]) - v = ndc2pix(p_proj[..., 1], img_size[1]) + u = ndc2pix(p_proj[..., 0], img_size[0], center[0]) + v = ndc2pix(p_proj[..., 1], img_size[1], center[1]) return torch.stack([u, v], dim=-1) def clip_near_plane(p, viewmat, clip_thresh=0.01): - R = viewmat[..., :3, :3] - T = viewmat[..., :3, 3] - p_view = torch.matmul(R, p[..., None])[..., 0] + T + R = viewmat[:3, :3] + T = viewmat[:3, 3] + p_view = torch.einsum("ij,nj->ni", R, p) + T[None] return p_view, p_view[..., 2] < clip_thresh @@ -266,21 +261,21 @@ def project_gaussians_forward( glob_scale, quats, viewmat, - projmat, - fx, - fy, + fullmat, + intrins, img_size, tile_bounds, clip_thresh=0.01, ): - tan_fovx = 0.5 * img_size[1] / fx - tan_fovy = 0.5 * img_size[0] / fy + fx, fy, cx, cy = intrins + tan_fovx = 0.5 * img_size[0] / fx + tan_fovy = 0.5 * img_size[1] / fy p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh) cov3d = scale_rot_to_cov3d(scales, glob_scale, quats) cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy) conic, radius, det_valid = compute_cov2d_bounds(cov2d) - center = project_pix(projmat, means3d, img_size) - tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds) + xys = project_pix(fullmat, means3d, img_size, (cx, cy)) + tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds) tile_area = (tile_max[..., 0] - tile_min[..., 0]) * ( tile_max[..., 1] - tile_min[..., 1] ) @@ -289,10 +284,20 @@ def project_gaussians_forward( num_tiles_hit = tile_area depths = p_view[..., 2] radii = radius.to(torch.int32) - xys = center - conics = conic - return cov3d, xys, depths, radii, conics, num_tiles_hit, mask + radii = torch.where(~mask, 0, radii) + conic = torch.where(~mask[..., None], 0, conic) + xys = torch.where(~mask[..., None], 0, xys) + cov3d = torch.where(~mask[..., None, None], 0, cov3d) + cov2d = torch.where(~mask[..., None, None], 0, cov2d) + num_tiles_hit = torch.where(~mask, 0, num_tiles_hit) + depths = torch.where(~mask, 0, depths) + + i, j = torch.triu_indices(3, 3) + cov3d_triu = cov3d[..., i, j] + i, j = torch.triu_indices(2, 2) + cov2d_triu = cov2d[..., i, j] + return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask def map_gaussian_to_intersects( diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index f3ed543dd..cd37677cf 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -1,29 +1,67 @@ import pytest +import traceback import torch +from torch.func import vjp # type: ignore +from gsplat import _torch_impl +import gsplat.cuda as _C + + +torch.manual_seed(42) device = torch.device("cuda:0") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -def test_project_gaussians_forward(): - from gsplat import _torch_impl - import gsplat.cuda as _C +def projection_matrix(fx, fy, W, H, n=0.01, f=1000.0): + return torch.tensor( + [ + [2.0 * fx / W, 0.0, 0.0, 0.0], + [0.0, 2.0 * fy / H, 0.0, 0.0], + [0.0, 0.0, (f + n) / (f - n), -2 * f * n / (f - n)], + [0.0, 0.0, 1.0, 0.0], + ], + device=device, + ) + + +def check_close(a, b, atol=1e-5, rtol=1e-5): + try: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except AssertionError: + traceback.print_exc() + diff = torch.abs(a - b).detach() + print(f"{diff.max()=} {diff.mean()=}") + import ipdb - torch.manual_seed(42) + ipdb.set_trace() + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_project_gaussians_forward(): num_points = 100 means3d = torch.randn((num_points, 3), device=device, requires_grad=True) - scales = torch.randn((num_points, 3), device=device) - glob_scale = 0.3 + scales = torch.rand((num_points, 3), device=device) + 0.2 + glob_scale = 1.0 quats = torch.randn((num_points, 4), device=device) quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) - viewmat = torch.eye(4, device=device) - projmat = torch.eye(4, device=device) - fx, fy = 3.0, 3.0 + H, W = 512, 512 + cx, cy = W / 2, H / 2 + # 90 degree FOV + fx, fy = W / 2, W / 2 clip_thresh = 0.01 + viewmat = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 8.0], + [0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + projmat = projection_matrix(fx, fy, W, H) + fullmat = projmat @ viewmat BLOCK_X, BLOCK_Y = 16, 16 tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 @@ -35,57 +73,206 @@ def test_project_gaussians_forward(): glob_scale, quats, viewmat, - projmat, + fullmat, fx, fy, - W / 2, - H / 2, + cx, + cy, H, W, tile_bounds, clip_thresh, ) + masks = num_tiles_hit > 0 + + with torch.no_grad(): + ( + _cov3d, + _, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + fullmat, + (fx, fy, cx, cy), + (W, H), + tile_bounds, + clip_thresh, + ) + + check_close(masks, _masks, atol=1e-5, rtol=1e-5) + check_close(cov3d, _cov3d) + check_close(xys, _xys) + check_close(depths, _depths) + check_close(radii, _radii) + check_close(conics, _conics) + check_close(num_tiles_hit, _num_tiles_hit) + print("passed project_gaussians_forward test") + + +def test_project_gaussians_backward(): + num_points = 100 + + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.rand((num_points, 3), device=device) + 0.2 + glob_scale = 1.0 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + + H, W = 512, 512 + cx, cy = W / 2, H / 2 + # 90 degree FOV + fx, fy = W / 2, W / 2 + clip_thresh = 0.01 + viewmat = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 8.0], + [0.0, 0.0, 0.0, 1.0], + ], + device=device, + ) + projmat = projection_matrix(fx, fy, W, H) + # projmat = torch.eye(4, device=device) + fullmat = projmat @ viewmat + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 ( - _cov3d, - _xys, - _depths, - _radii, - _conics, - _num_tiles_hit, - _masks, + cov3d, + cov2d, + xys, + depths, + radii, + conics, + _, + masks, ) = _torch_impl.project_gaussians_forward( means3d, scales, glob_scale, quats, viewmat, - projmat, - fx, - fy, - (H, W), + fullmat, + (fx, fy, cx, cy), + (W, H), tile_bounds, clip_thresh, ) - # TODO: failing - # torch.testing.assert_close( - # cov3d[_masks], - # _cov3d.view(-1, 9)[_masks][:, [0, 1, 2, 4, 5, 8]], - # atol=1e-5, - # rtol=1e-5, - # ) - # torch.testing.assert_close( - # xys[_masks], - # _xys[_masks], - # atol=1e-4, - # rtol=1e-4, - # ) - # torch.testing.assert_close(depths[_masks], _depths[_masks]) - # torch.testing.assert_close(radii[_masks], _radii[_masks]) - # torch.testing.assert_close(conics[_masks], _conics[_masks]) - # torch.testing.assert_close(num_tiles_hit[_masks], _num_tiles_hit[_masks]) + # Test backward pass + + v_xys = torch.randn_like(xys) + # v_depths = torch.randn_like(depths) + v_depths = torch.zeros_like(depths) + # scale gradients by pixels to account for finite difference + v_conics = torch.randn_like(conics) * 1e-3 + v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward( + num_points, + means3d, + scales, + glob_scale, + quats, + viewmat, + fullmat, + fx, + fy, + cx, + cy, + H, + W, + cov3d, + radii, + conics, + v_xys, + v_depths, + v_conics, + ) + + def scale_rot_to_cov3d_partial(scale, quat): + """ + scale (*, 3), quat (*, 3) -> cov3d (upper tri) (*, 6) + """ + cov3d = _torch_impl.scale_rot_to_cov3d(scale, glob_scale, quat) + i, j = torch.triu_indices(3, 3) + cov3d_triu = cov3d[..., i, j] + return cov3d_triu + + def project_cov3d_ewa_partial(mean3d, cov3d): + """ + mean3d (*, 3), cov3d (upper tri) (*, 6) -> cov2d (upper tri) (*, 3) + """ + tan_fovx = 0.5 * W / fx + tan_fovy = 0.5 * H / fy + + cov3d_mat = torch.zeros(*cov3d.shape[:-1], 3, 3, device=device) + i, j = torch.triu_indices(3, 3) + cov3d_mat[..., i, j] = cov3d + cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]] + cov2d = _torch_impl.project_cov3d_ewa( + mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy + ) + ii, jj = torch.triu_indices(2, 2) + return cov2d[..., ii, jj] + + def compute_cov2d_bounds_partial(cov2d): + """ + cov2d (upper tri) (*, 3) -> conic (upper tri) (*, 3) + """ + cov2d_mat = torch.zeros(*cov2d.shape[:-1], 2, 2, device=device) + i, j = torch.triu_indices(2, 2) + cov2d_mat[..., i, j] = cov2d + cov2d_mat[..., 1, 0] = cov2d[..., 1] + conic, _, _ = _torch_impl.compute_cov2d_bounds(cov2d_mat) + return conic + + def project_pix_partial(mean3d): + """ + mean3d (*, 3) -> xy (*, 2) + """ + return _torch_impl.project_pix(fullmat, mean3d, (W, H), (cx, cy)) + + def compute_depth_partial(mean3d): + """ + mean3d (*, 3) -> depth (*) + """ + p_view, _ = _torch_impl.clip_near_plane(mean3d, viewmat, clip_thresh) + depth = p_view[..., 2] + return depth + + _, vjp_scale_rot_to_cov3d = vjp(scale_rot_to_cov3d_partial, scales, quats) # type: ignore + _, vjp_project_cov3d_ewa = vjp(project_cov3d_ewa_partial, means3d, cov3d) # type: ignore + _, vjp_compute_cov2d_bounds = vjp(compute_cov2d_bounds_partial, cov2d) # type: ignore + _, vjp_project_pix = vjp(project_pix_partial, means3d) # type: ignore + _, vjp_compute_depth = vjp(compute_depth_partial, means3d) # type: ignore + + _v_cov2d = vjp_compute_cov2d_bounds(v_conics)[0] + _v_mean3d_cov2d, _v_cov3d = vjp_project_cov3d_ewa(_v_cov2d) + _v_mean3d_xy = vjp_project_pix(v_xys)[0] + _v_mean3d_depth = vjp_compute_depth(v_depths)[0] + _v_mean3d = _v_mean3d_cov2d + _v_mean3d_xy + _v_mean3d_depth + _v_scale, _v_quat = vjp_scale_rot_to_cov3d(_v_cov3d) + + atol = 5e-4 + rtol = 1e-5 + check_close(v_cov2d, _v_cov2d, atol=atol, rtol=rtol) + check_close(v_cov3d, _v_cov3d, atol=atol, rtol=rtol) + check_close(v_mean3d[:, :2], _v_mean3d[:, :2], atol=atol, rtol=rtol) + check_close(v_scale, _v_scale, atol=atol, rtol=rtol) + check_close(v_quat, _v_quat, atol=atol, rtol=rtol) + print("passed project_gaussians_backward test") if __name__ == "__main__": test_project_gaussians_forward() + test_project_gaussians_backward()