Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ubc-vision/nf-soft-mining i…
Browse files Browse the repository at this point in the history
…nto main
  • Loading branch information
Shakiba Kheradmand committed Jun 29, 2024
2 parents 2543dd5 + 3b5c99a commit 7365b8e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
15 changes: 9 additions & 6 deletions examples/losses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
import torch.nn as nn

from nerfacc.losses import DistortionLoss
class NeRFLoss(nn.Module):
def __init__(self, lambda_opacity=0.0, lambda_distortion=0.01):
def __init__(self, lambda_opacity=0.0, lambda_distortion=0.0):
super().__init__()

self.lambda_opacity = lambda_opacity
Expand All @@ -18,9 +17,13 @@ def forward(self, rgb, target, opp=None, distkwargs=None):
# encourage opacity to be either 0 or 1 to avoid floater
d['opacity'] = self.lambda_opacity*(-o*torch.log(o))

if self.lambda_distortion > 0 and distkwargs is not None:
d['distortion'] = self.lambda_distortion * \
DistortionLoss.apply(distkwargs['ws'], distkwargs['deltas'],
distkwargs['ts'], distkwargs['rays_a'])
if self.lambda_distortion > 0:
raise NotImplementedError
# TODO(Shakiba): push distortion loss code to nerfacc
# if self.lambda_distortion > 0 and distkwargs is not None:
# from nerfacc.losses import DistortionLoss
# d['distortion'] = self.lambda_distortion * \
# DistortionLoss.apply(distkwargs['ws'], distkwargs['deltas'],
# distkwargs['ts'], distkwargs['rays_a'])

return d
8 changes: 4 additions & 4 deletions examples/train_ngp_nerf_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
parser.add_argument(
"--data_root",
type=str,
default=str("/ubc/cs/research/kmyi/shakiba/g/data/nerf_llff_data"),
# default=str("../../data/nerf_synthetic"),
# default=str("../../data/nerf_llff_data"),
default=str("../../data/nerf_synthetic"),
help="the root dir of the dataset",
)
parser.add_argument(
Expand All @@ -47,7 +47,7 @@
parser.add_argument(
"--scene",
type=str,
default="trex",
default="mic",
choices=NERF_SYNTHETIC_SCENES + LLFF_NDC_SCENES,
help="which scene to use",
)
Expand Down Expand Up @@ -200,7 +200,7 @@
lpips_norm_fn = lambda x: x[None, ...].permute(0, 3, 1, 2) * 2 - 1
lpips_fn = lambda x, y: lpips_net(lpips_norm_fn(x), lpips_norm_fn(y)).mean()

loss_fn = NeRFLoss(lambda_distortion=1e-1, lambda_opacity=1e-3)
loss_fn = NeRFLoss()

gradval = None
lossperpix_prev = None
Expand Down
13 changes: 12 additions & 1 deletion examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@
"stump",
]

LLFF_NDC_SCENES = [
"fern",
"flower",
"fortress",
"horns",
"leaves",
"orchids",
"room_llff",
"trex",
]


def set_random_seed(seed):
random.seed(seed)
Expand Down Expand Up @@ -184,7 +195,7 @@ def render_image_with_propnet(

def prop_sigma_fn(t_starts, t_ends, proposal_network):
t_origins = chunk_rays.origins[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :]
t_dirs = chunk_rays.viewdirs[..., None, :].detach()
positions = t_origins + t_dirs * (t_starts + t_ends)[..., None] / 2.0
sigmas = proposal_network(positions)
if opaque_bkgd:
Expand Down

0 comments on commit 7365b8e

Please sign in to comment.