Skip to content

Commit

Permalink
fix simple trainer (#105)
Browse files Browse the repository at this point in the history
Co-authored-by: Quei-An Chen <[email protected]>
  • Loading branch information
kwea123 and Quei-An Chen authored Jan 21, 2024
1 parent 3776352 commit 97732cd
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import torch
import tyro
from gsplat.project_gaussians import ProjectGaussians
from gsplat.rasterize import RasterizeGaussians
from gsplat.project_gaussians import _ProjectGaussians
from gsplat.rasterize import _RasterizeGaussians
from PIL import Image
from torch import Tensor, optim

Expand Down Expand Up @@ -71,6 +71,8 @@ def _init_gaussians(self):
],
device=self.device,
)
self.background = torch.zeros(3, device=self.device)

self.means.requires_grad = True
self.scales.requires_grad = True
self.quats.requires_grad = True
Expand All @@ -87,7 +89,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
times = [0] * 3 # project, rasterize, backward
for iter in range(iterations):
start = time.time()
xys, depths, radii, conics, num_tiles_hit, cov3d = ProjectGaussians.apply(
xys, depths, radii, conics, num_tiles_hit, cov3d = _ProjectGaussians.apply(
self.means,
self.scales,
1,
Expand All @@ -105,7 +107,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
torch.cuda.synchronize()
times[0] += time.time() - start
start = time.time()
out_img = RasterizeGaussians.apply(
out_img = _RasterizeGaussians.apply(
xys,
depths,
radii,
Expand All @@ -115,6 +117,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = Fals
torch.sigmoid(self.opacities),
self.H,
self.W,
self.background,
)
torch.cuda.synchronize()
times[1] += time.time() - start
Expand Down

0 comments on commit 97732cd

Please sign in to comment.