Skip to content

Commit

Permalink
minor fix to 2dgs corner case (#543)
Browse files Browse the repository at this point in the history
* minor fix to 2dgs corner case

* reformat

* torch.zeros
  • Loading branch information
FantasticOven2 authored Feb 24, 2025
1 parent ddf88c6 commit bd64a47
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,14 +1695,14 @@ def rasterize_to_pixels_2dgs(
if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512):
padded_channels = (1 << (channels - 1).bit_length()) - channels
colors = torch.cat(
[colors, torch.empty(*colors.shape[:-1], padded_channels, device=device)],
[colors, torch.zeros(*colors.shape[:-1], padded_channels, device=device)],
dim=-1,
)
if backgrounds is not None:
backgrounds = torch.cat(
[
backgrounds,
torch.empty(
torch.zeros(
*backgrounds.shape[:-1], padded_channels, device=device
),
],
Expand Down
4 changes: 2 additions & 2 deletions gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ fully_fused_projection_fwd_2dgs_tensor(
torch::Tensor radii =
torch::empty({C, N}, means.options().dtype(torch::kInt32));
torch::Tensor means2d = torch::empty({C, N, 2}, means.options());
torch::Tensor depths = torch::empty({C, N}, means.options());
torch::Tensor depths = torch::zeros({C, N}, means.options());
torch::Tensor ray_transforms = torch::empty({C, N, 3, 3}, means.options());
torch::Tensor normals = torch::empty({C, N, 3}, means.options());
torch::Tensor normals = torch::zeros({C, N, 3}, means.options());

if (C && N) {
fully_fused_projection_fwd_2dgs_kernel<float>
Expand Down
4 changes: 3 additions & 1 deletion gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,9 @@ def rasterization_2dgs(
# Rasterize to pixels
if render_mode in ["RGB+D", "RGB+ED"]:
colors = torch.cat((colors, depths[..., None]), dim=-1)
# backgrounds = torch.cat((backgrounds, torch.zeros((C, 1), device="cuda")), dim=-1)
backgrounds = torch.cat(

This comment has been minimized.

Copy link
@KacperKazan

KacperKazan Feb 27, 2025

this is missing the check if backgrounds is not None as is in the rasterization function above

(backgrounds, torch.zeros((C, 1), device=colors.device)), dim=-1
)
elif render_mode in ["D", "ED"]:
colors = depths[..., None]
else: # RGB
Expand Down

0 comments on commit bd64a47

Please sign in to comment.