From bd64a47414e182dc105c1a2fdb6691068518d060 Mon Sep 17 00:00:00 2001 From: FantasticOven2 <91100968+FantasticOven2@users.noreply.github.com> Date: Mon, 24 Feb 2025 10:09:59 -0800 Subject: [PATCH] minor fix to 2dgs corner case (#543) * minor fix to 2dgs corner case * reformat * torch.zeros --- gsplat/cuda/_wrapper.py | 4 ++-- gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu | 4 ++-- gsplat/rendering.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 5a1065af0..1d1167867 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -1695,14 +1695,14 @@ def rasterize_to_pixels_2dgs( if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512): padded_channels = (1 << (channels - 1).bit_length()) - channels colors = torch.cat( - [colors, torch.empty(*colors.shape[:-1], padded_channels, device=device)], + [colors, torch.zeros(*colors.shape[:-1], padded_channels, device=device)], dim=-1, ) if backgrounds is not None: backgrounds = torch.cat( [ backgrounds, - torch.empty( + torch.zeros( *backgrounds.shape[:-1], padded_channels, device=device ), ], diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu index bc56528ff..7cf971b19 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu @@ -272,9 +272,9 @@ fully_fused_projection_fwd_2dgs_tensor( torch::Tensor radii = torch::empty({C, N}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); - torch::Tensor depths = torch::empty({C, N}, means.options()); + torch::Tensor depths = torch::zeros({C, N}, means.options()); torch::Tensor ray_transforms = torch::empty({C, N, 3, 3}, means.options()); - torch::Tensor normals = torch::empty({C, N, 3}, means.options()); + torch::Tensor normals = torch::zeros({C, N, 3}, means.options()); if (C && N) { fully_fused_projection_fwd_2dgs_kernel diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 78da64abf..e72536015 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1233,7 +1233,9 @@ def rasterization_2dgs( # Rasterize to pixels if render_mode in ["RGB+D", "RGB+ED"]: colors = torch.cat((colors, depths[..., None]), dim=-1) - # backgrounds = torch.cat((backgrounds, torch.zeros((C, 1), device="cuda")), dim=-1) + backgrounds = torch.cat( + (backgrounds, torch.zeros((C, 1), device=colors.device)), dim=-1 + ) elif render_mode in ["D", "ED"]: colors = depths[..., None] else: # RGB