From c8fd74dacea9afca82fb06b3b5bdc7ad3435b23f Mon Sep 17 00:00:00 2001
From: Jeffrey Hu
Date: Mon, 28 Oct 2024 17:01:44 -0700
Subject: [PATCH] cleanup

---
 examples/benchmarks/mcmc_deblur.sh |  5 +---
 examples/blur_kernel.py            | 17 ++++----------
 examples/simple_trainer.py         | 37 +++++++++---------------------
 3 files changed, 16 insertions(+), 43 deletions(-)

diff --git a/examples/benchmarks/mcmc_deblur.sh b/examples/benchmarks/mcmc_deblur.sh
index d497515f9..c3d5a2822 100644
--- a/examples/benchmarks/mcmc_deblur.sh
+++ b/examples/benchmarks/mcmc_deblur.sh
@@ -1,12 +1,11 @@
 SCENE_DIR="data/deblur_dataset/real_defocus_blur"
 SCENE_LIST="defocuscake defocuscaps defocuscisco defocuscoral defocuscupcake defocuscups defocusdaisy defocussausage defocusseal defocustools"
-SCENE_LIST="defocuscake defocustools defocussausage defocuscupcake defocuscups defocuscoral defocusdaisy defocusseal defocuscaps defocuscisco"
 DATA_FACTOR=4
 RENDER_TRAJ_PATH="spiral"
 
 CAP_MAX=250000
 
-RESULT_DIR="results/benchmark_mcmc_deblur/c0.2_a10"
+RESULT_DIR="results/benchmark_mcmc_deblur"
 for SCENE in $SCENE_LIST;
 do
     echo "Running $SCENE"
@@ -15,8 +14,6 @@ do
     CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
         --strategy.cap-max $CAP_MAX \
         --blur_opt \
-        --blur_a 10 \
-        --blur_c 0.2 \
         --render_traj_path $RENDER_TRAJ_PATH \
         --data_dir $SCENE_DIR/$SCENE/ \
         --result_dir $RESULT_DIR/$SCENE
diff --git a/examples/blur_kernel.py b/examples/blur_kernel.py
index a28f3244c..90b5206b7 100644
--- a/examples/blur_kernel.py
+++ b/examples/blur_kernel.py
@@ -2,7 +2,6 @@
 import torch.nn as nn
 from torch import Tensor
 import torch.nn.functional as F
-from kornia.filters import median_blur
 
 from examples.mlp import create_mlp
 from gsplat.utils import log_transform
@@ -10,11 +9,8 @@
 class BlurOptModule(nn.Module):
     """Blur optimization module."""
 
-    def __init__(self, cfg, n: int, embed_dim: int = 4):
+    def __init__(self, n: int, embed_dim: int = 4):
         super().__init__()
-        self.a = cfg.blur_a
-        self.c = cfg.blur_c
-
         self.embeds = torch.nn.Embedding(n, embed_dim)
         self.means_encoder = get_encoder(3, 3)
         self.depths_encoder = get_encoder(3, 1)
@@ -80,15 +76,10 @@ def mask_mean_loss(self, blur_mask: Tensor, step: int, eps: float = 1e-2):
         x = blur_mask.mean()
         if step <= 2000:
             a = 20
-            b = 1
-            c = 0.2
         else:
-            a = self.a
-            b = 1
-            c = self.c
-        print(x.item(), a, b, c)
-        meanloss = a * (1 / (1 - x + eps) - 1) + b * (1 / (x + eps) - 1)
-        return c * meanloss
+            a = 10
+        meanloss = a * (1 / (1 - x + eps) - 1) + (1 / (x + eps) - 1)
+        return meanloss
 
 
 def get_encoder(num_freqs: int, input_dims: int):
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 1cd997697..8e3696790 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -83,7 +83,7 @@ class Config:
 
     # Number of training steps
     max_steps: int = 30_000
     # Steps to evaluate the model
-    eval_steps: List[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
+    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
     # Steps to save the model
     save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
@@ -153,10 +153,7 @@ class Config:
     # Learning rate for blur optimization
     blur_opt_lr: float = 1e-3
     # Regularization for blur mask mean
-    blur_mean_reg: float = 0.001
-    # Regularization for blur mask smoothness
-    blur_a: float = 4
-    blur_c: float = 0.5
+    blur_mean_reg: float = 0.0002
 
     # Enable bilateral grid. (experimental)
     use_bilateral_grid: bool = False
@@ -655,6 +652,7 @@ def train(self):
                 bkgd = torch.rand(1, 3, device=device)
                 colors = colors + bkgd * (1.0 - alphas)
             if cfg.blur_opt:
+                blur_mask = self.blur_module.predict_mask(image_ids, depths)
                 renders_blur, _, _ = self.rasterize_splats(
                     camtoworlds=camtoworlds,
                     Ks=Ks,
@@ -664,16 +662,11 @@ def train(self):
                     near_plane=cfg.near_plane,
                     far_plane=cfg.far_plane,
                     image_ids=image_ids,
-                    render_mode="RGB+ED",
+                    render_mode="RGB",
                     masks=masks,
                     blur=True,
                 )
-                colors_blur, depths_blur = (
-                    renders_blur[..., 0:3],
-                    renders_blur[..., 3:4],
-                )
-                blur_mask = self.blur_module.predict_mask(image_ids, depths)
-                colors = (1 - blur_mask) * colors + blur_mask * colors_blur
+                colors = (1 - blur_mask) * colors + blur_mask * renders_blur[..., 0:3]
 
             self.cfg.strategy.step_pre_backward(
                 params=self.splats,
@@ -871,10 +864,8 @@ def train(self):
             # eval the full set
             if step in [i - 1 for i in cfg.eval_steps]:
                 self.eval(step, stage="train")
-                self.eval(step)
+                self.eval(step, stage="val")
                 self.render_traj(step)
-            if (step + 1) % 1000 == 0 or step == 0:
-                self.eval(step, stage="train", vis_skip=True)
 
             # run compression
             if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
@@ -892,7 +883,7 @@ def train(self):
                 self.viewer.update(step, num_train_rays_per_step)
 
     @torch.no_grad()
-    def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
+    def eval(self, step: int, stage: str = "val"):
         """Entry for evaluation."""
         print("Running evaluation...")
         cfg = self.cfg
@@ -904,13 +895,10 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
         dataloader = torch.utils.data.DataLoader(
             dataset, batch_size=1, shuffle=False, num_workers=1
         )
-        train_vis_image_ids = np.linspace(0, len(dataloader) - 1, 7).astype(int)
 
         ellipse_time = 0
         metrics = defaultdict(list)
         for i, data in enumerate(dataloader):
-            if vis_skip and stage == "train" and i not in train_vis_image_ids:
-                continue
             camtoworlds = data["camtoworld"].to(device)
             Ks = data["K"].to(device)
             pixels = data["image"].to(device) / 255.0
@@ -943,6 +931,8 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
             colors = torch.clamp(colors, 0.0, 1.0)
             canvas_list = [pixels, colors]
             if self.cfg.blur_opt and stage == "train":
+                blur_mask = self.blur_module.predict_mask(image_ids, depths)
+                canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
                 renders_blur, _, _ = self.rasterize_splats(
                     camtoworlds=camtoworlds,
                     Ks=Ks,
@@ -952,16 +942,11 @@ def eval(self, step: int, stage: str = "val", vis_skip: bool = False):
                     near_plane=cfg.near_plane,
                     far_plane=cfg.far_plane,
                     image_ids=image_ids,
-                    render_mode="RGB+ED",
+                    render_mode="RGB",
                     masks=masks,
                     blur=True,
                 )
-                colors_blur, depths_blur = (
-                    renders_blur[..., 0:3],
-                    renders_blur[..., 3:4],
-                )
-                blur_mask = self.blur_module.predict_mask(image_ids, depths)
-                canvas_list.append(blur_mask.repeat(1, 1, 1, 3))
+                colors_blur = renders_blur[..., 0:3]
                 canvas_list.append(torch.clamp(colors_blur, 0.0, 1.0))
                 colors = (1 - blur_mask) * colors + blur_mask * colors_blur
             colors = torch.clamp(colors, 0.0, 1.0)
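
Note: a minimal standalone sketch (not part of the patch) of the simplified mask_mean_loss hard-coded above, with the warmup and steady-state weights inlined exactly as in the diff; the driver loop and its shapes at the bottom are purely illustrative:

import torch

def mask_mean_loss(blur_mask: torch.Tensor, step: int, eps: float = 1e-2) -> torch.Tensor:
    # Same barrier as the patched BlurOptModule.mask_mean_loss:
    # 1 / (1 - x) blows up as the mask mean approaches 1 (everything blurred),
    # 1 / x blows up as it approaches 0 (mask unused). The weight `a` makes
    # the all-blurred failure mode costlier, and is relaxed after 2k steps.
    x = blur_mask.mean()
    a = 20 if step <= 2000 else 10
    return a * (1 / (1 - x + eps) - 1) + (1 / (x + eps) - 1)

# Illustration only: ignoring eps, the loss is minimized near
# x = 1 / (1 + sqrt(a)) ~ 0.24 for a = 10, so the asymmetric weighting
# biases the learned blur mask toward mostly-sharp pixels.
for m in (0.05, 0.24, 0.5, 0.95):
    mask = torch.full((1, 32, 32, 1), m)
    print(f"mask mean {m:.2f} -> loss {mask_mean_loss(mask, step=5000).item():.2f}")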