From c6a57b35cf8fc55c4525de4871fb2304d5e6ed1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Devernay?= Date: Thu, 20 Apr 2023 16:26:39 -0700 Subject: [PATCH] add support for "vertical" orientation and "focus" centering methods - See https://github.com/nerfstudio-project/nerfstudio/pull/1543 - Also default to "vertical" which works better than "up" (https://github.com/nerfstudio-project/nerfstudio/issues/1765) - Also rename train_split_percentage to train_split_fraction (https://github.com/nerfstudio-project/nerfstudio/pull/1497) --- .../developer_guides/pipelines/dataparsers.md | 10 +- nerfstudio/cameras/camera_utils.py | 107 ++++++++++++++++-- .../data/dataparsers/heritage_dataparser.py | 16 ++- .../data/dataparsers/mipnerf360_dataparser.py | 8 +- .../data/dataparsers/nerfstudio_dataparser.py | 14 +-- .../data/dataparsers/nuscenes_dataparser.py | 8 +- .../dataparsers/phototourism_dataparser.py | 14 +-- .../data/dataparsers/record3d_dataparser.py | 4 +- .../data/dataparsers/sdfstudio_dataparser.py | 4 +- 9 files changed, 135 insertions(+), 50 deletions(-) diff --git a/docs/developer_guides/pipelines/dataparsers.md b/docs/developer_guides/pipelines/dataparsers.md index 458abdeb..a6015204 100644 --- a/docs/developer_guides/pipelines/dataparsers.md +++ b/docs/developer_guides/pipelines/dataparsers.md @@ -67,10 +67,14 @@ class NerfstudioDataParserConfig(DataParserConfig): """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px.""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" - orientation_method: Literal["pca", "up"] = "up" + orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical" """The method to use for orientation.""" - train_split_percentage: float = 0.9 - """The percent of images to use for training. The remaining images are for eval.""" + center_method: Literal["poses", "focus", "none"] = "poses" + """The method to use to center the poses.""" + auto_scale_poses: bool = True + """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" + train_split_fraction: float = 0.9 + """The fraction of images to use for training. The remaining images are for eval.""" @dataclass class Nerfstudio(DataParser): diff --git a/nerfstudio/cameras/camera_utils.py b/nerfstudio/cameras/camera_utils.py index 2adaffd6..3c8e2e76 100644 --- a/nerfstudio/cameras/camera_utils.py +++ b/nerfstudio/cameras/camera_utils.py @@ -407,35 +407,86 @@ def rotation_matrix(a: TensorType[3], b: TensorType[3]) -> TensorType[3, 3]: return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8)) +def focus_of_attention(poses: TensorType["num_poses":..., 4, 4], initial_focus: TensorType[3]) -> TensorType[3]: + """Compute the focus of attention of a set of cameras. Only cameras + that have the focus of attention in front of them are considered. + Args: + poses: The poses to orient. + initial_focus: The 3D point views to decide which cameras are initially activated. + Returns: + The 3D position of the focus of attention. 
+ """ + # References to the same method in third-party code: + # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145 + # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197 + active_directions = -poses[:, :3, 2:3] + active_origins = poses[:, :3, 3:4] + # initial value for testing if the focus_pt is in front or behind + focus_pt = initial_focus + # Prune cameras which have the current have the focus_pt behind them. + active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 + done = False + # We need at least two active cameras, else fallback on the previous solution. + # This may be the "poses" solution if no cameras are active on first iteration, e.g. + # they are in an outward-looking configuration. + while torch.sum(active.int()) > 1 and not done: + active_directions = active_directions[active] + active_origins = active_origins[active] + # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions + m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1) + mt_m = torch.transpose(m, -2, -1) @ m + focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0] + active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 + if active.all(): + # the set of active cameras did not change, so we're done. + done = True + return focus_pt + + def auto_orient_and_center_poses( - poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_poses: bool = True + poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "vertical", "none"] = "vertical", + center_method: Literal["poses", "focus", "none"] = "poses", ) -> TensorType["num_poses":..., 3, 4]: """Orients and centers the poses. We provide two methods for orientation: pca and up. - pca: Orient the poses so that the principal component of the points is aligned with the axes. - This method works well when all of the cameras are in the same plane. + pca: Orient the poses so that the principal directions of the camera centers are aligned + with the axes, Z corresponding to the smallest principal component. + This method works well when all of the cameras are in the same plane, for example when + images are taken using a mobile robot. up: Orient the poses so that the average up vector is aligned with the z axis. This method works well when images are not at arbitrary angles. + vertical: Orient the poses so that the Z 3D direction projects close to the + y axis in images. This method works better if cameras are not all + looking in the same 3D direction, which may happen in camera arrays or in LLFF. + There are two centering methods: + poses: The poses are centered around the origin. + focus: The origin is set to the focus of attention of all cameras (the + closest point to cameras optical axes). Recommended for inward-looking + camera configurations. Args: poses: The poses to orient. method: The method to use for orientation. - center_poses: If True, the poses are centered around the origin. + center_method: The method to use to center the poses. Returns: The oriented poses. 
""" - translation = poses[..., :3, 3] + origins = poses[..., :3, 3] - mean_translation = torch.mean(translation, dim=0) - translation_diff = translation - mean_translation + mean_origin = torch.mean(origins, dim=0) + translation_diff = origins - mean_origin - if center_poses: - translation = mean_translation + if center_method == "poses": + translation = mean_origin + elif center_method == "focus": + translation = focus_of_attention(poses, mean_origin) + elif center_method == "none": + translation = torch.zeros_like(mean_origin) else: - translation = torch.zeros_like(mean_translation) + raise ValueError(f"Unknown value for center_method: {center_method}") if method == "pca": _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff) @@ -449,9 +500,41 @@ def auto_orient_and_center_poses( if oriented_poses.mean(axis=0)[2, 1] < 0: oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3] - elif method == "up": + elif method in ("up", "vertical"): up = torch.mean(poses[:, :3, 1], dim=0) up = up / torch.linalg.norm(up) + if method == "vertical": + # If cameras are not all parallel (e.g. not in an LLFF configuration), + # we can find the 3D direction that most projects vertically in all + # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares + # problem is solved by SVD. + x_axis_matrix = poses[:, :3, 0] + _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False) + # Singular values are S_i=||Xv_i|| for each right singular vector v_i. + # ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i + # are an orthonormal basis. + # ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the + # RMS of cosines between x axes and v_i. If the second smallest singular + # value corresponds to an angle error less than 10° (cos(80°)=0.17), + # this is probably a degenerate camera configuration (typical values + # are around 5° average error for the true vertical). In this case, + # rather than taking the vector corresponding to the smallest singular + # value, we project the "up" vector on the plane spanned by the two + # best singular vectors. We could also just fallback to the "up" + # solution. + if S[1] > 0.17 * math.sqrt(poses.shape[0]): + # regular non-degenerate configuration + up_vertical = Vh[2, :] + # It may be pointing up or down. Use "up" to disambiguate the sign. + up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical + else: + # Degenerate configuration: project "up" on the plane spanned by + # the last two right singular vectors (which are orthogonal to the + # first). v_0 is a unit vector, no need to divide by its norm when + # projecting. 
+ up = up - Vh[0, :] * torch.dot(up, Vh[0, :]) + # re-normalize + up = up / torch.linalg.norm(up) rotation = rotation_matrix(up, torch.Tensor([0, 0, 1])) transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1) @@ -461,5 +544,7 @@ def auto_orient_and_center_poses( transform[:3, 3] = -translation transform = transform[:3, :] oriented_poses = transform @ poses + else: + raise ValueError(f"Unknown value for method: {method}") return oriented_poses, transform diff --git a/nerfstudio/data/dataparsers/heritage_dataparser.py b/nerfstudio/data/dataparsers/heritage_dataparser.py index 0d84a7e1..9a899900 100644 --- a/nerfstudio/data/dataparsers/heritage_dataparser.py +++ b/nerfstudio/data/dataparsers/heritage_dataparser.py @@ -80,16 +80,14 @@ class HeritageDataParserConfig(DataParserConfig): """How much to scale the camera origins by.""" alpha_color: str = "white" """alpha color of background""" - train_split_percentage: float = 0.9 - """The percent of images to use for training. The remaining images are for eval.""" + train_split_fraction: float = 0.9 + """The fraction of images to use for training. The remaining images are for eval.""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" - orientation_method: Literal["pca", "up", "none"] = "up" + orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical" """The method to use for orientation.""" auto_scale_poses: bool = True """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" - center_poses: bool = True - """Whether to center the poses.""" @dataclass @@ -105,7 +103,7 @@ def __init__(self, config: HeritageDataParserConfig): self.data: Path = config.data self.scale_factor: float = config.scale_factor self.alpha_color = config.alpha_color - self.train_split_percentage = config.train_split_percentage + self.train_split_fraction = config.train_split_fraction # pylint: disable=too-many-statements def _generate_dataparser_outputs(self, split="train"): @@ -208,7 +206,7 @@ def _generate_dataparser_outputs(self, split="train"): # filter image_filenames and poses based on train/eval split percentage num_images = len(image_filenames) - num_train_images = math.ceil(num_images * self.config.train_split_percentage) + num_train_images = math.ceil(num_images * self.config.train_split_fraction) num_eval_images = num_images - num_train_images i_all = np.arange(num_images) i_train = np.linspace( @@ -225,7 +223,7 @@ def _generate_dataparser_outputs(self, split="train"): """ poses = camera_utils.auto_orient_and_center_poses( - poses, method=self.config.orientation_method, center_poses=self.config.center_poses + poses, method=self.config.orientation_method, center_method=self.config.center_method ) # Scale poses @@ -248,7 +246,7 @@ def _generate_dataparser_outputs(self, split="train"): poses, transform = camera_utils.auto_orient_and_center_poses( poses, method=self.config.orientation_method, - center_poses=False, + center_method="none", ) # scale pts accordingly diff --git a/nerfstudio/data/dataparsers/mipnerf360_dataparser.py b/nerfstudio/data/dataparsers/mipnerf360_dataparser.py index 803edeae..6634214c 100644 --- a/nerfstudio/data/dataparsers/mipnerf360_dataparser.py +++ b/nerfstudio/data/dataparsers/mipnerf360_dataparser.py @@ -53,10 +53,10 @@ class Mipnerf360DataParserConfig(DataParserConfig): """How much to downscale images. 
If not set, images are chosen such that the max dimension is <1600px.""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" - orientation_method: Literal["pca", "up", "none"] = "up" + orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical" """The method to use for orientation.""" - center_poses: bool = True - """Whether to center the poses.""" + center_method: Literal["poses", "focus", "none"] = "poses" + """The method to use to center the poses.""" auto_scale_poses: bool = True """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" eval_interval: int = 8 @@ -205,7 +205,7 @@ def _generate_dataparser_outputs(self, split="train"): poses, transform_matrix = camera_utils.auto_orient_and_center_poses( poses, method=orientation_method, - center_poses=self.config.center_poses, + center_method=self.config.center_method, ) # Scale poses diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index 83b74211..57ff56c9 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -54,14 +54,14 @@ class NerfstudioDataParserConfig(DataParserConfig): """How much to downscale images. If not set, images are chosen such that the max dimension is <1600px.""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" - orientation_method: Literal["pca", "up", "none"] = "up" + orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical" """The method to use for orientation.""" - center_poses: bool = True - """Whether to center the poses.""" + center_method: Literal["poses", "focus", "none"] = "poses" + """The method to use to center the poses.""" auto_scale_poses: bool = True """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" - train_split_percentage: float = 0.9 - """The percent of images to use for training. The remaining images are for eval.""" + train_split_fraction: float = 0.9 + """The fraction of images to use for training. The remaining images are for eval.""" @dataclass @@ -159,7 +159,7 @@ def _generate_dataparser_outputs(self, split="train"): # filter image_filenames and poses based on train/eval split percentage num_images = len(image_filenames) - num_train_images = math.ceil(num_images * self.config.train_split_percentage) + num_train_images = math.ceil(num_images * self.config.train_split_fraction) num_eval_images = num_images - num_train_images i_all = np.arange(num_images) i_train = np.linspace( @@ -184,7 +184,7 @@ def _generate_dataparser_outputs(self, split="train"): poses, _ = camera_utils.auto_orient_and_center_poses( poses, method=orientation_method, - center_poses=self.config.center_poses, + center_method=self.config.center_method, ) # Scale poses diff --git a/nerfstudio/data/dataparsers/nuscenes_dataparser.py b/nerfstudio/data/dataparsers/nuscenes_dataparser.py index 7a274530..bffece38 100644 --- a/nerfstudio/data/dataparsers/nuscenes_dataparser.py +++ b/nerfstudio/data/dataparsers/nuscenes_dataparser.py @@ -65,10 +65,8 @@ class NuScenesDataParserConfig(DataParserConfig): """Which cameras to use.""" mask_dir: Optional[Path] = None """Path to masks of dynamic objects.""" - - train_split_percentage: float = 0.9 - """The percent of images to use for training. The remaining images are for eval.""" - + train_split_fraction: float = 0.9 + """The fraction of images to use for training. 
The remaining images are for eval.""" verbose: bool = False """Load dataset with verbose messaging""" @@ -160,7 +158,7 @@ def _generate_dataparser_outputs(self, split="train"): # filter image_filenames and poses based on train/eval split percentage num_snapshots = len(samples) - num_train_snapshots = math.ceil(num_snapshots * self.config.train_split_percentage) + num_train_snapshots = math.ceil(num_snapshots * self.config.train_split_fraction) num_eval_snapshots = num_snapshots - num_train_snapshots i_all = np.arange(num_snapshots) i_train = np.linspace( diff --git a/nerfstudio/data/dataparsers/phototourism_dataparser.py b/nerfstudio/data/dataparsers/phototourism_dataparser.py index 7c23b1b0..c8b27eef 100644 --- a/nerfstudio/data/dataparsers/phototourism_dataparser.py +++ b/nerfstudio/data/dataparsers/phototourism_dataparser.py @@ -50,16 +50,16 @@ class PhototourismDataParserConfig(DataParserConfig): """How much to scale the camera origins by.""" alpha_color: str = "white" """alpha color of background""" - train_split_percentage: float = 0.9 - """The percent of images to use for training. The remaining images are for eval.""" + train_split_fraction: float = 0.9 + """The fraction of images to use for training. The remaining images are for eval.""" scene_scale: float = 1.0 """How much to scale the region of interest by.""" - orientation_method: Literal["pca", "up", "none"] = "up" + orientation_method: Literal["pca", "up", "vertical", "none"] = "vertical" """The method to use for orientation.""" + center_method: Literal["poses", "focus", "none"] = "poses" + """The method to use to center the poses.""" auto_scale_poses: bool = True """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" - center_poses: bool = True - """Whether to center the poses.""" @dataclass @@ -119,7 +119,7 @@ def _generate_dataparser_outputs(self, split="train"): # filter image_filenames and poses based on train/eval split percentage num_images = len(image_filenames) - num_train_images = math.ceil(num_images * self.config.train_split_percentage) + num_train_images = math.ceil(num_images * self.config.train_split_fraction) num_eval_images = num_images - num_train_images i_all = np.arange(num_images) i_train = np.linspace( @@ -138,7 +138,7 @@ def _generate_dataparser_outputs(self, split="train"): raise ValueError(f"Unknown dataparser split {split}") poses, _ = camera_utils.auto_orient_and_center_poses( - poses, method=self.config.orientation_method, center_poses=self.config.center_poses + poses, method=self.config.orientation_method, center_method=self.config.center_method ) # Scale poses diff --git a/nerfstudio/data/dataparsers/record3d_dataparser.py b/nerfstudio/data/dataparsers/record3d_dataparser.py index 099470d0..22a7aa8f 100644 --- a/nerfstudio/data/dataparsers/record3d_dataparser.py +++ b/nerfstudio/data/dataparsers/record3d_dataparser.py @@ -51,7 +51,7 @@ class Record3DDataParserConfig(DataParserConfig): """1/val_skip images to use for validation.""" aabb_scale: float = 4.0 """Scene scale.""" - orientation_method: Literal["pca", "up"] = "up" + orientation_method: Literal["pca", "vertical", "up"] = "vertical" """The method to use for orientation""" max_dataset_size: int = 300 """Max number of images to train on. If the dataset has more, images will be sampled approximately evenly. 
If -1, @@ -116,7 +116,7 @@ def _generate_dataparser_outputs(self, split: str = "train") -> DataparserOutput poses = torch.from_numpy(poses[:, :3, :4]) poses = camera_utils.auto_orient_and_center_poses( - pose_utils.to4x4(poses), method=self.config.orientation_method + pose_utils.to4x4(poses), method=self.config.orientation_method, center_method="poses" )[:, :3, :4] # Centering poses diff --git a/nerfstudio/data/dataparsers/sdfstudio_dataparser.py b/nerfstudio/data/dataparsers/sdfstudio_dataparser.py index 65e1439a..7a3f440e 100644 --- a/nerfstudio/data/dataparsers/sdfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/sdfstudio_dataparser.py @@ -269,8 +269,8 @@ def _generate_dataparser_outputs(self, split="train"): # pylint: disable=unused if self.config.auto_orient: camera_to_worlds, transform = camera_utils.auto_orient_and_center_poses( camera_to_worlds, - method="up", - center_poses=False, + method="vertical", + center_method="none", ) # we should also transform normal accordingly
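
Note on focus_of_attention: the loop added in camera_utils.py solves a standard least-squares intersection of lines (the Wikipedia link cited in the code). Writing o_i for the active camera origins and d_i for their unit viewing directions (the negated third column of each camera-to-world matrix), the focus point minimizes the summed squared distance to the optical axes:

    \min_p \sum_i \left\| (I - d_i d_i^\top)(p - o_i) \right\|^2
    \quad\Rightarrow\quad
    p^{*} = \Big(\sum_i M_i\Big)^{-1} \sum_i M_i\, o_i,
    \qquad M_i = (I - d_i d_i^\top)^\top (I - d_i d_i^\top).

This is what torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0) evaluates; using means instead of sums cancels out. Because each d_i is a unit vector, M_i reduces to the projector I - d_i d_i^\top, although the code keeps the general M^T M form. Cameras whose current focus estimate lies behind them are removed from the sums, and the while loop repeats until the active set stops changing.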
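
Note on the "vertical" orientation method: with X the n-by-3 matrix whose rows are the camera x-axes, the branch solves the total least-squares problem

    u^{*} = \arg\min_{\|u\| = 1} \|X u\|,

whose solution is the right singular vector of X associated with the smallest singular value (Vh[2]). Since every row of X is a unit vector, \sigma_i / \sqrt{n} is the RMS cosine between the camera x-axes and the singular direction v_i. The test S[1] > 0.17 * sqrt(n) (cos 80 degrees is about 0.17) therefore checks that only one direction is nearly orthogonal to all x-axes; if two directions are (for example when all cameras are roughly parallel, as in LLFF-style captures), the vertical is only determined up to a plane, and the mean "up" vector is projected onto span{v_1, v_2} rather than taking v_2 directly.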
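
A minimal usage sketch of the new API, assuming this patch is applied on top of nerfstudio. The look_at helper and the four-camera layout are made up for illustration; only auto_orient_and_center_poses, its method/center_method keywords, and the dataparser config fields come from the patch.

    import torch
    from nerfstudio.cameras import camera_utils

    def look_at(origin, target, world_up=(0.0, 0.0, 1.0)):
        # Hypothetical helper: build an OpenGL-style camera-to-world matrix
        # (camera looks along -Z, +Y is up), matching the convention the
        # patched code assumes when it reads -poses[:, :3, 2] as the view direction.
        origin = torch.tensor(origin, dtype=torch.float32)
        target = torch.tensor(target, dtype=torch.float32)
        up = torch.tensor(world_up, dtype=torch.float32)
        z = origin - target
        z = z / torch.linalg.norm(z)
        x = torch.linalg.cross(up, z)
        x = x / torch.linalg.norm(x)
        y = torch.linalg.cross(z, x)
        c2w = torch.eye(4)
        c2w[:3, 0], c2w[:3, 1], c2w[:3, 2], c2w[:3, 3] = x, y, z, origin
        return c2w

    # Four inward-looking cameras on a circle, all aimed at a point above the origin.
    positions = [(2.0, 0.0, 1.0), (0.0, 2.0, 1.0), (-2.0, 0.0, 1.0), (0.0, -2.0, 1.0)]
    poses = torch.stack([look_at(p, (0.0, 0.0, 0.5)) for p in positions])  # (4, 4, 4)

    oriented, transform = camera_utils.auto_orient_and_center_poses(
        poses, method="vertical", center_method="focus"
    )
    print(oriented.shape, transform.shape)  # (4, 3, 4) and (3, 4)

The same options are exposed on the dataparser configs touched here, e.g. NerfstudioDataParserConfig(orientation_method="vertical", center_method="poses", train_split_fraction=0.9), with the remaining fields left at their defaults.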