Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Oct 29, 2024
1 parent 8a02e74 commit c8fd74d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 43 deletions.
5 changes: 1 addition & 4 deletions examples/benchmarks/mcmc_deblur.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
SCENE_DIR="data/deblur_dataset/real_defocus_blur"
SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools"
SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco"

DATA_FACTOR=4
RENDER_TRAJ_PATH="spiral"
CAP_MAX=250000

RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10"
RESULT_DIR="results/benchmark_mcmc_deblur"
for SCENE in $SCENE_LIST;
do
echo "Running $SCENE"
Expand All @@ -15,8 +14,6 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--blur_opt \
--blur_a 10 \
--blur_c 0.2 \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE
Expand Down
17 changes: 4 additions & 13 deletions examples/blur_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from kornia.filters import median_blur
from examples.mlp import create_mlp
from gsplat.utils import log_transform


class BlurOptModule(nn.Module):
"""Blur optimization module."""

def __init__(self, cfg, n: int, embed_dim: int = 4):
def __init__(self, n: int, embed_dim: int = 4):
super().__init__()
self.a = cfg.blur_a
self.c = cfg.blur_c

self.embeds = torch.nn.Embedding(n, embed_dim)
self.means_encoder = get_encoder(3, 3)
self.depths_encoder = get_encoder(3, 1)
Expand Down Expand Up @@ -80,15 +76,10 @@ def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
x = blur_mask.mean()
if step <= 2000:
a = 20
b = 1
c = 0.2
else:
a = self.a
b = 1
c = self.c
print(x.item(), a, b, c)
meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
return c * meanloss
a = 10
meanloss = a * (1 / (1 - x + eps) - 1) + (1 / (x + eps) - 1)
return meanloss


def get_encoder(num_freqs: int, input_dims: int):
Expand Down
37 changes: 11 additions & 26 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Config:
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

Expand Down Expand Up @@ -153,10 +153,7 @@ class Config:
# Learning rate for blur optimization
blur_opt_lr: float = 1e-3
# Regularization for blur mask mean
blur_mean_reg: float = 0.001
# Regularization for blur mask smoothness
blur_a: float = 4
blur_c: float = 0.5
blur_mean_reg: float = 0.0002

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
Expand Down Expand Up @@ -655,6 +652,7 @@ def train(self):
bkgd = torch.rand(1, 3, device=device)
colors = colors + bkgd * (1.0 - alphas)
if cfg.blur_opt:
blur_mask = self.blur_module.predict_mask(image_ids, depths)
renders_blur, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
Expand All @@ -664,16 +662,11 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED",
render_mode="RGB",
masks=masks,
blur=True,
)
colors_blur, depths_blur = (
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths)
colors = (1 - blur_mask) * colors + blur_mask * colors_blur
colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3]

self.cfg.strategy.step_pre_backward(
params=self.splats,
Expand Down Expand Up @@ -871,10 +864,8 @@ def train(self):
# eval the full set
if step in [i - 1 for i in cfg.eval_steps]:
self.eval(step, stage="train")
self.eval(step)
self.eval(step, stage="val")
self.render_traj(step)
if (step + 1) % 1000 == 0 or step == 0:
self.eval(step, stage="train", vis_skip=True)

# run compression
if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
Expand All @@ -892,7 +883,7 @@ def train(self):
self.viewer.update(step, num_train_rays_per_step)

@torch.no_grad()
def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
def eval(self, step: int, stage: str = "val"):
"""Entry for evaluation."""
print("Running evaluation...")
cfg = self.cfg
Expand All @@ -904,13 +895,10 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=1
)
train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int)

ellipse_time = 0
metrics = defaultdict(list)
for i, data in enumerate(dataloader):
if vis_skip and stage == "train" and i not in train_vis_image_ids:
continue
camtoworlds = data["camtoworld"].to(device)
Ks = data["K"].to(device)
pixels = data["image"].to(device) / 255.0
Expand Down Expand Up @@ -943,6 +931,8 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
colors = torch.clamp(colors, 0.0, 1.0)
canvas_list = [pixels, colors]
if self.cfg.blur_opt and stage == "train":
blur_mask = self.blur_module.predict_mask(image_ids, depths)
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
renders_blur, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
Expand All @@ -952,16 +942,11 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED",
render_mode="RGB",
masks=masks,
blur=True,
)
colors_blur, depths_blur = (
renders_blur[..., 0:3],
renders_blur[..., 3:4],
)
blur_mask = self.blur_module.predict_mask(image_ids, depths)
canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
colors_blur = renders_blur[..., 0:3]
canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0))
colors = (1 - blur_mask) * colors + blur_mask * colors_blur
colors = torch.clamp(colors, 0.0, 1.0)
Expand Down

0 comments on commit c8fd74d

Please sign in to comment.