Skip to content

Commit

Permalink
Add prune_scale3d argument to MCMCStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
bchretien committed Feb 24, 2025
1 parent ddf88c6 commit 3649685
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
4 changes: 3 additions & 1 deletion examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ def __init__(
scene_scale=self.scene_scale
)
elif isinstance(self.cfg.strategy, MCMCStrategy):
self.strategy_state = self.cfg.strategy.initialize_state()
self.strategy_state = self.cfg.strategy.initialize_state(
scene_scale=self.scene_scale
)
else:
assert_never(self.cfg.strategy)

Expand Down
17 changes: 13 additions & 4 deletions gsplat/strategy/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class MCMCStrategy(Strategy):
refine_stop_iter (int): Stop refining GSs after this iteration. Default to 25_000.
refine_every (int): Refine GSs every this steps. Default to 100.
min_opacity (float): GSs with opacity below this value will be pruned. Default to 0.005.
prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this
value will be pruned. Default is 0.1.
verbose (bool): Whether to print verbose information. Default to False.
Examples:
Expand All @@ -37,7 +39,7 @@ class MCMCStrategy(Strategy):
>>> optimizers: Dict[str, torch.optim.Optimizer] = ...
>>> strategy = MCMCStrategy()
>>> strategy.check_sanity(params, optimizers)
>>> strategy_state = strategy.initialize_state()
>>> strategy_state = strategy.initialize_state(scene_scale=scene_scale)
>>> for step in range(1000):
... render_image, render_alpha, info = rasterization(...)
... loss = ...
Expand All @@ -52,16 +54,17 @@ class MCMCStrategy(Strategy):
refine_stop_iter: int = 25_000
refine_every: int = 100
min_opacity: float = 0.005
prune_scale3d: float = 0.1
verbose: bool = False

def initialize_state(self) -> Dict[str, Any]:
def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
"""Initialize and return the running state for this strategy."""
n_max = 51
binoms = torch.zeros((n_max, n_max))
for n in range(n_max):
for k in range(n + 1):
binoms[n, k] = math.comb(n, k)
return {"binoms": binoms}
return {"binoms": binoms, "scene_scale": scene_scale}

def check_sanity(
self,
Expand Down Expand Up @@ -125,7 +128,7 @@ def step_post_backward(
and step % self.refine_every == 0
):
# teleport GSs
n_relocated_gs = self._relocate_gs(params, optimizers, binoms)
n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state)
if self.verbose:
print(f"Step {step}: Relocated {n_relocated_gs} GSs.")

Expand All @@ -150,9 +153,15 @@ def _relocate_gs(
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
binoms: Tensor,
state: Dict[str, Any],
) -> int:
opacities = torch.sigmoid(params["opacities"].flatten())
dead_mask = opacities <= self.min_opacity
is_too_big = (
torch.exp(params["scales"]).max(dim=-1).values
> self.prune_scale3d * state["scene_scale"]
)
dead_mask = dead_mask | is_too_big
n_gs = dead_mask.sum().item()
if n_gs > 0:
relocate(
Expand Down

0 comments on commit 3649685

Please sign in to comment.