diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ca9271e81..9a994d2c2 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -352,7 +352,9 @@ def __init__( scene_scale=self.scene_scale ) elif isinstance(self.cfg.strategy, MCMCStrategy): - self.strategy_state = self.cfg.strategy.initialize_state() + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) else: assert_never(self.cfg.strategy) diff --git a/gsplat/strategy/mcmc.py b/gsplat/strategy/mcmc.py index c07e17376..4b0c3679d 100644 --- a/gsplat/strategy/mcmc.py +++ b/gsplat/strategy/mcmc.py @@ -28,6 +28,8 @@ class MCMCStrategy(Strategy): refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000. refine_every (int): Refine GSs every this steps. Default to 100. min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005. + prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this + value will be pruned. Default to 0.1. verbose (bool): Whether to print verbose information. Default to False. Examples: @@ -37,7 +39,7 @@ class MCMCStrategy(Strategy): >>> optimizers: Dict[str, torch.optim.Optimizer] = ... >>> strategy = MCMCStrategy() >>> strategy.check_sanity(params, optimizers) - >>> strategy_state = strategy.initialize_state() + >>> strategy_state = strategy.initialize_state(scene_scale=scene_scale) >>> for step in range(1000): ... render_image, render_alpha, info = rasterization(...) ... loss = ... 
@@ -52,16 +54,17 @@ class MCMCStrategy(Strategy): refine_stop_iter: int = 25_000 refine_every: int = 100 min_opacity: float = 0.005 + prune_scale3d: float = 0.1 verbose: bool = False - def initialize_state(self) -> Dict[str, Any]: + def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: """Initialize and return the running state for this strategy.""" n_max = 51 binoms = torch.zeros((n_max, n_max)) for n in range(n_max): for k in range(n + 1): binoms[n, k] = math.comb(n, k) - return {"binoms": binoms} + return {"binoms": binoms, "scene_scale": scene_scale} def check_sanity( self, @@ -125,7 +128,7 @@ def step_post_backward( and step % self.refine_every == 0 ): # teleport GSs - n_relocated_gs = self._relocate_gs(params, optimizers, binoms) + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) if self.verbose: print(f"Step {step}: Relocated {n_relocated_gs} GSs.") @@ -150,9 +153,15 @@ def _relocate_gs( params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], optimizers: Dict[str, torch.optim.Optimizer], binoms: Tensor, + state: Dict[str, Any], ) -> int: opacities = torch.sigmoid(params["opacities"].flatten()) dead_mask = opacities <= self.min_opacity + is_too_big = ( + torch.exp(params["scales"]).max(dim=-1).values + > self.prune_scale3d * state["scene_scale"] + ) + dead_mask = dead_mask | is_too_big n_gs = dead_mask.sum().item() if n_gs > 0: relocate(