From 87a0a704d0173a85663abc2b12ab740b5b7de8c0 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 29 Jan 2025 09:11:25 +0000 Subject: [PATCH 01/13] op optimization --- examples/datasets/colmap.py | 2 +- gsplat/strategy/ops.py | 39 +++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 11ad2a4b2..2aaacffdd 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional from typing_extensions import assert_never -import cv2 from PIL import Image +import cv2 import imageio.v2 as imageio import numpy as np import torch diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 83c90a25e..7416e3563 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -339,7 +339,6 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: if isinstance(v, torch.Tensor): state[k] = torch.cat((v, v_new)) - @torch.no_grad() def inject_noise_to_position( params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], @@ -347,23 +346,33 @@ def inject_noise_to_position( state: Dict[str, Tensor], scaler: float, ): + """Inplace inject noise to positions based on covariance and opacity.""" + device = params["means"].device + + # Calculate opacities more efficiently opacities = torch.sigmoid(params["opacities"].flatten()) - scales = torch.exp(params["scales"]) + + # Compute scaling factor for noise more efficiently + with torch.cuda.amp.autocast(enabled=True): + op_weights = 1 / (1 + torch.exp(-100 * (1 - opacities - 0.995))) + + # Generate noise directly in the right shape + noise = torch.randn_like(params["means"], device=device) * scaler + + # Scale noise by opacity weights + noise *= op_weights.unsqueeze(-1) + + # Get covariance matrices covars, _ = quat_scale_to_covar_preci( - params["quats"], - scales, + quats=params["quats"], + scales=torch.exp(params["scales"]), compute_covar=True, compute_preci=False, triu=False, ) - - def op_sigmoid(x, k=100, x0=0.995): - return 1 / (1 + torch.exp(-k * (x - x0))) - - noise = ( - torch.randn_like(params["means"]) - * (op_sigmoid(1 - opacities)).unsqueeze(-1) - * scaler - ) - noise = torch.einsum("bij,bj->bi", covars, noise) - params["means"].add_(noise) + + # Apply covariance scaling efficiently using batched operations + noise = torch.bmm(covars, noise.unsqueeze(-1)).squeeze(-1) + + # Update means inplace + params["means"].add_(noise) \ No newline at end of file From 0a09c255b6ea4e96598ecb5aefe9af65c46a0ee5 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 29 Jan 2025 09:11:50 +0000 Subject: [PATCH 02/13] much lighter on memory --- gsplat/distributed.py | 105 ++++++++++++++---------------------------- 1 file changed, 34 insertions(+), 71 deletions(-) diff --git a/gsplat/distributed.py b/gsplat/distributed.py index cab559df5..11ed06e54 100644 --- a/gsplat/distributed.py +++ b/gsplat/distributed.py @@ -173,89 +173,52 @@ def all_to_all_tensor_list( splits: List[Union[int, Tensor]], output_splits: Optional[List[Union[int, Tensor]]] = None, ) -> List[Tensor]: - """Split and exchange tensors between all ranks in a many-to-many fashion. - - Args: - world_size: The total number of ranks. - tensor_list: A list of tensors to split and exchange. The size of the first - dimension of all the tensors in this list should be the same. The rest - dimensions can be arbitrary. Shape: [(N, *), (N, *), ...] 
-        splits: A list of integers representing the number of elements to send to each
-            rank. It will be used to split the tensor in the `tensor_list`.
-            The sum of the elements in this list should be equal to N. The size of this
-            list should be equal to the `world_size`.
-        output_splits: Splits of the output tensors. Could be pre-calculated via
-            `all_to_all_int32(world_size, splits)`. If not provided, it will
-            be calculated internally.
-
-    Returns:
-        A list of tensors exchanged between all ranks, where the i-th element is
-        corresponding to the i-th tensor in `tensor_list`. Note the shape of the
-        returned tensors might be different from the input tensors, depending on the
-        splits.
-
-    Examples:
-
-    .. code-block:: python
-
-        >>> # on rank 0
-        >>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
-        >>> # splits = [2, 1]
-
-        >>> # on rank 1
-        >>> # tensor_list = [torch.tensor([7, 8]), torch.tensor([9, 10])]
-        >>> # splits = [1, 1]
-
-        >>> collected = all_to_all_tensor_list(world_rank, world_size, tensor_list, splits)
-
-        >>> # on rank 0
-        >>> # [torch.tensor([1, 2, 7]), torch.tensor([4, 5, 9])]
-        >>> # on rank 1
-        >>> # [torch.tensor([3, 8]), torch.tensor([6, 10])]
-
-    """
     if world_size == 1:
         return tensor_list
 
+    # Validate inputs
     N = len(tensor_list[0])
     for tensor in tensor_list:
         assert len(tensor) == N, "All tensors should have the same first dimension size"
+    assert len(splits) == world_size, "The length of splits should be equal to world_size"
 
-    assert (
-        len(splits) == world_size
-    ), "The length of splits should be equal to world_size"
-
-    # concatenate tensors and record their sizes
-    data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1)
-    sizes = [t.numel() // N for t in tensor_list]
+    # Pre-calculate output splits if not provided to avoid redundant computation
+    if output_splits is None:
+        output_splits = all_to_all_int32(world_size, splits, device=tensor_list[0].device)
 
-    # all_to_all
-    if output_splits is not None:
-        collected_splits = output_splits
-    else:
-        collected_splits = all_to_all_int32(world_size, splits, device=data.device)
-    collected = [
-        torch.empty((l, *data.shape[1:]), dtype=data.dtype, device=data.device)
-        for l in collected_splits
-    ]
-    # torch.split requires tuple of integers
-    splits = [s.item() if isinstance(s, Tensor) else s for s in splits]
-    if data.requires_grad:
-        # differentiable all_to_all
-        distF.all_to_all(collected, data.split(splits, dim=0))
-    else:
-        # non-differentiable all_to_all
-        torch.distributed.all_to_all(collected, list(data.split(splits, dim=0)))
-    collected = torch.cat(collected, dim=0)
-
-    # split the collected tensor and reshape to the original shape
-    out_tensor_tuple = torch.split(collected, sizes, dim=-1)
+    # Process tensors one at a time to reduce peak memory
     out_tensor_list = []
-    for out_tensor, tensor in zip(out_tensor_tuple, tensor_list):
+    for tensor in tensor_list:
+        # Reshape tensor without creating a new copy if possible
+        flat_tensor = tensor.reshape(N, -1)
+
+        # Pre-allocate output tensors
+        collected = [
+            torch.empty((l, flat_tensor.shape[1]),
+                        dtype=flat_tensor.dtype,
+                        device=flat_tensor.device)
+            for l in output_splits
+        ]
+
+        # Normalize splits for torch.split
+        norm_splits = [s.item() if isinstance(s, Tensor) else s for s in splits]
+
+        # Perform all_to_all operation
+        if flat_tensor.requires_grad:
+            distF.all_to_all(collected, flat_tensor.split(norm_splits, dim=0))
+        else:
+            torch.distributed.all_to_all(collected, list(flat_tensor.split(norm_splits, dim=0)))
+
+        # Concatenate and reshape immediately
+        out_tensor = torch.cat(collected, dim=0)
         out_tensor = out_tensor.view(-1, *tensor.shape[1:])
         out_tensor_list.append(out_tensor)
 
-    return out_tensor_list
+        # Clear intermediate tensors
+        del collected, flat_tensor
+        torch.cuda.empty_cache()
+
+    return out_tensor_list
 
 def _find_free_port():
     import socket

From f70bfb612eabc07bb464f0c6c317efa9243fd92f Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Wed, 29 Jan 2025 09:12:46 +0000
Subject: [PATCH 03/13] add azure blob download script

---
 scripts/download_blob_azure.py | 42 ++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 scripts/download_blob_azure.py

diff --git a/scripts/download_blob_azure.py b/scripts/download_blob_azure.py
new file mode 100644
index 000000000..86a323c46
--- /dev/null
+++ b/scripts/download_blob_azure.py
@@ -0,0 +1,42 @@
+from azure.storage.blob import BlobServiceClient
+import os
+
+# Don't clean directory this time since we want to keep existing files
+local_path = os.path.expanduser("~/new_data")
+if not os.path.exists(local_path):
+    os.makedirs(local_path)
+
+connection_string = "DefaultEndpointsProtocol=https;AccountName=odyingest;AccountKey=;EndpointSuffix=core.windows.net"
+container_name = "3droutput"
+
+# Create the blob service client
+blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+container_client = blob_service_client.get_container_client(container_name)
+
+# Create ALL directories up front
+directories = set()
+for blob in container_client.list_blobs():
+    dir_path = os.path.dirname(os.path.join(local_path, blob.name))
+    directories.add(dir_path)
+
+# Create all directory paths first
+for dir_path in sorted(directories):
+    os.makedirs(dir_path, exist_ok=True)
+    print(f"Created directory: {dir_path}")
+
+# Now download only files that don't exist
+for blob in container_client.list_blobs():
+    if blob.size > 0:
+        file_path = os.path.join(local_path, blob.name)
+        if not os.path.exists(file_path):  # Only download if file doesn't exist
+            try:
+                print(f"Downloading {blob.name}...")
+                with open(file_path, "wb") as file:
+                    data = container_client.download_blob(blob.name).readall()
+                    file.write(data)
+            except Exception as e:
+                print(f"Error downloading {blob.name}: {str(e)}")
+        else:
+            print(f"Skipping {blob.name} - already exists")
+
+print("Download complete!")

From 8e66dd84833a70614e700660b4d70de09424f82b Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Wed, 29 Jan 2025 09:14:06 +0000
Subject: [PATCH 04/13] mem logging in rendering

---
 gsplat/rendering.py | 64 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/gsplat/rendering.py b/gsplat/rendering.py
index 78da64abf..94bf64922 100644
--- a/gsplat/rendering.py
+++ b/gsplat/rendering.py
@@ -24,6 +24,66 @@
 )
 from .utils import depth_to_normal, get_projection_matrix
 
+import torch.distributed as dist  # torch itself is already imported at the top of this module
+from typing import Optional
+
+def log_gpu_memory(stage: str, tensor: Optional[torch.Tensor] = None, verbose: bool = True):
+    """
+    Logs the current and peak GPU memory usage for the GPU associated with the given tensor,
+    along with the process rank.
+
+    Args:
+        stage (str): A descriptive label for the current stage in the code.
+        tensor (Optional[torch.Tensor]): A tensor to infer the GPU index from. If None,
+            defaults to the current device.
+        verbose (bool): If True, logging is performed. Default is True.
+    """
+    if not verbose:
+        return
+
+    if not torch.cuda.is_available():
+        print(f"[{stage}] CUDA is not available.")
+        return
+
+    # Determine the device; torch.cuda.current_device() returns a bare index,
+    # so wrap it in a torch.device for the attribute checks below.
+    if tensor is not None:
+        device = tensor.device
+    else:
+        device = torch.device("cuda", torch.cuda.current_device())
+
+    if dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+
+    if device.type != 'cuda':
+        print(f"[{stage}] Tensor is not on a CUDA device. Current device: {device}")
+        return
+
+    gpu_index = device.index  # Get the GPU index
+
+    if gpu_index is None:
+        # Handle cases where the device is specified as 'cuda' without an index
+        gpu_index = 0
+
+    if gpu_index >= torch.cuda.device_count():
+        print(f"[{stage}] Invalid GPU index: {gpu_index}. Total GPUs available: {torch.cuda.device_count()}.")
+        return
+
+    mem_alloc = torch.cuda.memory_allocated(gpu_index) / (1024 ** 2)  # Convert to MB
+    mem_reserved = torch.cuda.memory_reserved(gpu_index) / (1024 ** 2)  # Convert to MB
+    mem_peak = torch.cuda.max_memory_allocated(gpu_index) / (1024 ** 2)  # Convert to MB
+
+    print(f"[{stage}] Rank {rank}/{world_size} GPU {gpu_index}: "
+          f"Allocated = {mem_alloc:.2f} MB, "
+          f"Reserved = {mem_reserved:.2f} MB, "
+          f"Peak Allocated = {mem_peak:.2f} MB")
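+
+# Illustrative usage (a sketch; it mirrors the commented-out calls that this patch
+# adds inside `rasterization` below, where `means` is defined):
+#
+#   log_gpu_memory("Before projection", means)
+#   proj_results = fully_fused_projection(...)
+#   log_gpu_memory("After projection", means)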
+
 
 def rasterization(
     means: Tensor,  # [N, 3]
@@ -293,6 +353,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
     # Silently change C from local #Cameras to global #Cameras.
     C = len(viewmats)
 
+    #log_gpu_memory("Before projection", means)
     # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars.
     proj_results = fully_fused_projection(
         means,
@@ -313,6 +374,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
         camera_model=camera_model,
     )
 
+    #log_gpu_memory("After projection", means)
     if packed:
         # The results are packed into shape [nnz, ...]. All elements are valid.
         (
@@ -494,6 +556,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
     # Identify intersecting tiles
     tile_width = math.ceil(width / float(tile_size))
     tile_height = math.ceil(height / float(tile_size))
+    #log_gpu_memory("before isect_tiles", means)
     tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
         means2d,
         radii,
@@ -506,6 +569,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
         camera_ids=camera_ids,
         gaussian_ids=gaussian_ids,
     )
+    #log_gpu_memory("after isect_tiles", means)
 
     # print("rank", world_rank, "Before isect_offset_encode")
     isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height)
 

From 8ad417ccc1c83c5b764ad691de18ef075b84bfbe Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Wed, 29 Jan 2025 09:16:22 +0000
Subject: [PATCH 05/13] alleviate vram consumption

---
 examples/simple_trainer.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index ca9271e81..9b91d54f6 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -56,11 +56,11 @@ class Config:
     render_traj_path: str = "interp"
 
     # Path to the Mip-NeRF 360 dataset
-    data_dir: str = "data/360_v2/garden"
+    data_dir: str = "/home/paja/new_data/xplor/office_lobby/undistort/"
     # Downsample factor for the dataset
-    data_factor: int = 4
+    data_factor: int = 1
     # Directory to save results
-    result_dir: str = "results/garden"
+    result_dir: str = "results/office_lobby4k_default_full"
     # Every N images there is a test image
     test_every: int = 8
     # Random crop size for training (experimental)
@@ -83,7 +83,7 @@ class Config:
     # Number of training steps
     max_steps: int = 30_000
     # Steps to evaluate the model
-    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+    eval_steps: List[int] = field(default_factory=lambda: [])
     # Steps to save the model
     save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
     # Whether to save ply file (storage size can be large)
@@ -153,7 +153,7 @@ class Config:
     app_opt_reg: float = 1e-6
 
     # Enable bilateral grid. (experimental)
-    use_bilateral_grid: bool = False
+    use_bilateral_grid: bool = True
     # Shape of the bilateral grid (X, Y, W)
     bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)
 
@@ -1096,8 +1096,8 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
             Config(
                 init_opa=0.5,
                 init_scale=0.1,
-                opacity_reg=0.01,
-                scale_reg=0.01,
+                opacity_reg=0.001, # opacity reg down removes black floaters
+                scale_reg=0.05, # scale_reg up helps to regulate huge splats
                 strategy=MCMCStrategy(verbose=True),
             ),
         ),

From d8abe65cb62f51012389e799dc4934488382492a Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Thu, 30 Jan 2025 08:19:49 +0000
Subject: [PATCH 06/13] implement masks

---
 examples/datasets/colmap.py | 18 ++++++++++++++++++
 examples/simple_trainer.py  | 27 +++++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py
index 2aaacffdd..7f7af4631 100644
--- a/examples/datasets/colmap.py
+++ b/examples/datasets/colmap.py
@@ -182,9 +182,13 @@ def __init__(
         image_dir_suffix = ""
         colmap_image_dir = os.path.join(data_dir, "images")
         image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
+        mask_dir = os.path.join(data_dir, "masks" + image_dir_suffix)
         for d in [image_dir, colmap_image_dir]:
             if not os.path.exists(d):
                 raise ValueError(f"Image folder {d} does not exist.")
+        for d in [mask_dir, colmap_image_dir]:
+            if not os.path.exists(d):
+                raise ValueError(f"Image folder {d} does not exist.")
 
         # Downsampled images may have different names vs images used for COLMAP,
         # so we need to map between the two sorted lists of files.
@@ -197,6 +201,7 @@ def __init__(
         image_files = sorted(_get_rel_paths(image_dir))
         colmap_to_image = dict(zip(colmap_files, image_files))
         image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
+        image_mask_paths = [os.path.join(mask_dir, colmap_to_image[f].replace('.jpg', '.png')) for f in image_names]
 
         # 3D points and {image_name -> [point_idx]}
         points = manager.points3D.astype(np.float32)
@@ -230,6 +235,7 @@ def __init__(
 
         self.image_names = image_names  # List[str], (num_images,)
         self.image_paths = image_paths  # List[str], (num_images,)
+        self.image_mask_paths = image_mask_paths  # List[str], (num_images,)
         self.camtoworlds = camtoworlds  # np.ndarray, (num_images, 4, 4)
         self.camera_ids = camera_ids  # List[int], (num_images,)
         self.Ks_dict = Ks_dict  # Dict of camera_id -> K
@@ -357,6 +363,16 @@ def __len__(self):
     def __getitem__(self, item: int) -> Dict[str, Any]:
         index = self.indices[item]
         image = imageio.imread(self.parser.image_paths[index])[..., :3]
+        image_mask = imageio.imread(self.parser.image_mask_paths[index])
+
+        # Handle different possible PNG formats:
+        if len(image_mask.shape) == 3:
+            if image_mask.shape[2] == 4:  # RGBA
+                image_mask = image_mask[..., -1]  # Take alpha channel
+            else:  # RGB
+                image_mask = image_mask.mean(axis=2)
+        # Now convert to binary
+        image_mask = (image_mask > 127).astype(np.uint8)
         camera_id = self.parser.camera_ids[index]
         K = self.parser.Ks_dict[camera_id].copy()  # undistorted K
         params = self.parser.params_dict[camera_id]
@@ -379,6 +395,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
             x = np.random.randint(0, max(w - self.patch_size, 1))
             y = np.random.randint(0, max(h - self.patch_size, 1))
             image = image[y : y + self.patch_size, x : x + self.patch_size]
+            image_mask = image_mask[y : y + self.patch_size, x : x + self.patch_size]
             K[0, 2] -= x
             K[1, 2] -= y
 
@@ -386,6 +403,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
             "K": torch.from_numpy(K).float(),
             "camtoworld": torch.from_numpy(camtoworlds).float(),
             "image": torch.from_numpy(image).float(),
+            "image_mask": ~torch.from_numpy(image_mask).bool(),
             "image_id": item,  # the index of the image in the dataset
         }
         if mask is not None:
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 9b91d54f6..d1f2d99ea 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -56,11 +56,11 @@ class Config:
     render_traj_path: str = "interp"
 
     # Path to the Mip-NeRF 360 dataset
-    data_dir: str = "/home/paja/new_data/xplor/office_lobby/undistort/"
+    data_dir: str = "/home/paja/new_data/basketball/"
     # Downsample factor for the dataset
     data_factor: int = 1
     # Directory to save results
-    result_dir: str = "results/office_lobby4k_default_full"
+    result_dir: str = "results/basketball_mcmc_1_350k_scale"
     # Every N images there is a test image
     test_every: int = 8
     # Random crop size for training (experimental)
@@ -189,6 +189,23 @@ def adjust_steps(self, factor: float):
         else:
             assert_never(strategy)
 
+def save_step_images(pixels, image_mask, save_dir, step):
+    """
+    Save both pixels and mask for each step
+    pixels: [1,n,m,3]
+    image_mask: [1,n,m,1]
+    """
+    # Create directory if it doesn't exist
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Convert pixels to uint8 image format (multiply by 255 since they were divided earlier)
+    pixels_img = (pixels[0].detach().cpu().numpy() * 255).astype(np.uint8)
+    # Convert mask to binary uint8 format
+    mask_img = (image_mask[0].detach().cpu().numpy() * 255).astype(np.uint8)
+
+    # Save both images
+    imageio.imwrite(os.path.join(save_dir, f'step_{step:04d}_image.png'), pixels_img)
+    imageio.imwrite(os.path.join(save_dir, f'step_{step:04d}_mask.png'), mask_img[..., 0])  # Remove single channel dimension
 
 def create_splats_with_optimizers(
     parser: Parser,
@@ -575,8 +592,12 @@ def train(self):
                 data = next(trainloader_iter)
 
             camtoworlds = camtoworlds_gt = data["camtoworld"].to(device)  # [1, 4, 4]
+            image_mask = data["image_mask"].to(device)
+            image_mask = image_mask.permute(1,2,0).unsqueeze(0)
             Ks = data["K"].to(device)  # [1, 3, 3]
             pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
+            #assert pixels.shape == image_mask.shape, f"pixels.shape {pixels.shape}, image_mask.shape {image_mask.shape}"
+            #save_step_images(pixels, image_mask, ".", step)
             num_train_rays_per_step = (
                 pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
             )
@@ -615,6 +636,8 @@ def train(self):
             else:
                 colors, depths = renders, None
 
+            colors = colors * image_mask
+            pixels = pixels * image_mask
             if cfg.use_bilateral_grid:
                 grid_y, grid_x = torch.meshgrid(
                     (torch.arange(height, device=self.device) + 0.5) / height,

From d333f88b19f92431144aa4940baca22882bdf10f Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Thu, 30 Jan 2025 08:20:27 +0000
Subject: [PATCH 07/13] lower densification threshold

---
 gsplat/strategy/default.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py
index 30c152d99..c399735dd 100644
--- a/gsplat/strategy/default.py
+++ b/gsplat/strategy/default.py
@@ -77,7 +77,7 @@ class DefaultStrategy(Strategy):
     """
 
    prune_opa: float = 0.005
-    grow_grad2d: float = 0.0002
+    grow_grad2d: float = 0.00012
     grow_scale3d: float = 0.01
     grow_scale2d: float = 0.05
     prune_scale3d: float = 0.1

From 26ad57473b866c66971d3e67415fa9c05e9a9433 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 16:54:01 +0000
Subject: [PATCH 08/13] check mask path

---
 examples/datasets/colmap.py | 41 +++++++++++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 14 deletions(-)

diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py
index 7f7af4631..ee8d1f265 100644
--- a/examples/datasets/colmap.py
+++ b/examples/datasets/colmap.py
@@ -183,12 +183,10 @@ def __init__(
         colmap_image_dir = os.path.join(data_dir, "images")
         image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
         mask_dir = os.path.join(data_dir, "masks" + image_dir_suffix)
+
         for d in [image_dir, colmap_image_dir]:
             if not os.path.exists(d):
                 raise ValueError(f"Image folder {d} does not exist.")
-        for d in [mask_dir, colmap_image_dir]:
-            if not os.path.exists(d):
-                raise ValueError(f"Image folder {d} does not exist.")
 
         # Downsampled images may have different names vs images used for COLMAP,
         # so we need to map between the two sorted lists of files.
@@ -197,7 +199,14 @@ def __init__(
         image_files = sorted(_get_rel_paths(image_dir))
         colmap_to_image = dict(zip(colmap_files, image_files))
         image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
-        image_mask_paths = [os.path.join(mask_dir, colmap_to_image[f].replace('.jpg', '.png')) for f in image_names]
+
+        # Only create mask paths if mask directory exists
+        image_mask_paths = []
+        if os.path.exists(mask_dir):
+            image_mask_paths = [os.path.join(mask_dir, colmap_to_image[f].replace('.jpg', '.png')) for f in image_names]
+        else:
+            print(f"Warning: Mask folder {mask_dir} does not exist. Proceeding without masks.")
+            image_mask_paths = [None] * len(image_names)
 
         # 3D points and {image_name -> [point_idx]}
         points = manager.points3D.astype(np.float32)
@@ -363,16 +368,23 @@ def __len__(self):
     def __getitem__(self, item: int) -> Dict[str, Any]:
         index = self.indices[item]
         image = imageio.imread(self.parser.image_paths[index])[..., :3]
-        image_mask = imageio.imread(self.parser.image_mask_paths[index])
-
-        # Handle different possible PNG formats:
-        if len(image_mask.shape) == 3:
-            if image_mask.shape[2] == 4:  # RGBA
-                image_mask = image_mask[..., -1]  # Take alpha channel
-            else:  # RGB
-                image_mask = image_mask.mean(axis=2)
-        # Now convert to binary
-        image_mask = (image_mask > 127).astype(np.uint8)
+
+        # Handle mask loading with proper checks
+        image_mask = None
+        if self.parser.image_mask_paths[index] is not None:
+            image_mask = imageio.imread(self.parser.image_mask_paths[index])
+            # Handle different possible PNG formats:
+            if len(image_mask.shape) == 3:
+                if image_mask.shape[2] == 4:  # RGBA
+                    image_mask = image_mask[..., -1]  # Take alpha channel
+                else:  # RGB
+                    image_mask = image_mask.mean(axis=2)
+            # Now convert to binary
+            image_mask = (image_mask > 127).astype(np.uint8)
+        else:
+            # If no mask exists, create a dummy all-zero mask (it is inverted to all-True below, i.e. no masking)
+            image_mask = np.zeros(image.shape[:2], dtype=np.uint8)
+
         camera_id = self.parser.camera_ids[index]
         K = self.parser.Ks_dict[camera_id].copy()  # undistorted K
         params = self.parser.params_dict[camera_id]
@@ -388,6 +400,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
             image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
             x, y, w, h = self.parser.roi_undist_dict[camera_id]
             image = image[y : y + h, x : x + w]
+            image_mask = image_mask[y : y + h, x : x + w]
 
         if self.patch_size is not None:
             # Random crop.
From ea2f1f3db79bfbf8ad9a5a32c62ebe3db6838bf1 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 17:38:41 +0000
Subject: [PATCH 09/13] default is more stable

---
 examples/simple_trainer.py | 15 +++++++++------
 gsplat/strategy/default.py |  2 +-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index d1f2d99ea..b1d2d3079 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -56,11 +56,11 @@ class Config:
     render_traj_path: str = "interp"
 
     # Path to the Mip-NeRF 360 dataset
-    data_dir: str = "/home/paja/new_data/basketball/"
+    data_dir: str = "/home/paja/new_data/xplor/beach_structure/undistort/"
     # Downsample factor for the dataset
     data_factor: int = 1
     # Directory to save results
-    result_dir: str = "results/basketball_mcmc_1_350k_scale"
+    result_dir: str = "results/office_lobby_new"
     # Every N images there is a test image
     test_every: int = 8
     # Random crop size for training (experimental)
@@ -592,8 +592,10 @@ def train(self):
                 data = next(trainloader_iter)
 
             camtoworlds = camtoworlds_gt = data["camtoworld"].to(device)  # [1, 4, 4]
-            image_mask = data["image_mask"].to(device)
-            image_mask = image_mask.permute(1,2,0).unsqueeze(0)
+            image_mask = None
+            if "image_mask" in data:
+                image_mask = data["image_mask"].to(device)
+                image_mask = image_mask.permute(1,2,0).unsqueeze(0)
             Ks = data["K"].to(device)  # [1, 3, 3]
             pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
             #assert pixels.shape == image_mask.shape, f"pixels.shape {pixels.shape}, image_mask.shape {image_mask.shape}"
@@ -636,8 +638,9 @@ def train(self):
             else:
                 colors, depths = renders, None
 
-            colors = colors * image_mask
-            pixels = pixels * image_mask
+            if image_mask is not None:
+                colors = colors * image_mask
+                pixels = pixels * image_mask
             if cfg.use_bilateral_grid:
                 grid_y, grid_x = torch.meshgrid(
                     (torch.arange(height, device=self.device) + 0.5) / height,
diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py
index c399735dd..8c22dcd57 100644
--- a/gsplat/strategy/default.py
+++ b/gsplat/strategy/default.py
@@ -89,7 +89,7 @@ class DefaultStrategy(Strategy):
     refine_every: int = 100
     pause_refine_after_reset: int = 0
     absgrad: bool = False
-    revised_opacity: bool = False
+    revised_opacity: bool = True
     verbose: bool = False
 
     key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d"

From 08d1bdd2b3c20bc22e1bf6cfa5459d48938ab27b Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 17:41:21 +0000
Subject: [PATCH 10/13] Merge plys and pts

---
 scripts/merge_plys.py | 193 ++++++++++++++++++++++++++++++++++++++++++
 scripts/merge_pts.py  | 127 +++++++++++++++++++++++++++
 2 files changed, 320 insertions(+)
 create mode 100644 scripts/merge_plys.py
 create mode 100644 scripts/merge_pts.py

diff --git a/scripts/merge_plys.py b/scripts/merge_plys.py
new file mode 100644
index 000000000..6748cece0
--- /dev/null
+++ b/scripts/merge_plys.py
@@ -0,0 +1,193 @@
+import open3d as o3d
+import numpy as np
+import torch
+from typing import List, Dict
+from pathlib import Path
+
+
+def load_ply(filepath: str) -> Dict[str, torch.Tensor]:
+    """
+    Load a PLY file and extract all its attributes into a dictionary of torch tensors.
+ """ + print(f"Loading PLY from {filepath}") + pcd = o3d.t.io.read_point_cloud(filepath) + + # Initialize dictionary for all attributes + data = {} + + # Extract point positions (means) + means = pcd.point.positions.numpy() + data["means"] = torch.from_numpy(means) + + # Extract scales + scales = np.column_stack([ + pcd.point[f"scale_{i}"].numpy() for i in range(3) + ]) + data["scales"] = torch.from_numpy(scales) + + # Extract quaternions + quats = np.column_stack([ + pcd.point[f"rot_{i}"].numpy() for i in range(4) + ]) + data["quats"] = torch.from_numpy(quats) + + # Extract opacities + data["opacities"] = torch.from_numpy(pcd.point["opacity"].numpy()) + + # Check if we have SH coefficients or colors + if "f_dc_0" in pcd.point: + # Count number of SH coefficients + dc_count = sum(1 for key in pcd.point if key.startswith("f_dc_")) + rest_count = sum(1 for key in pcd.point if key.startswith("f_rest_")) + + if rest_count > 0: # We have SH coefficients + # Extract SH0 coefficients + sh0 = np.column_stack([ + pcd.point[f"f_dc_{i}"].numpy() for i in range(dc_count) + ]) + data["sh0"] = torch.from_numpy(sh0).reshape(-1, dc_count // 3, 3).transpose(1, 2) + + # Extract SHN coefficients + shN = np.column_stack([ + pcd.point[f"f_rest_{i}"].numpy() for i in range(rest_count) + ]) + data["shN"] = torch.from_numpy(shN).reshape(-1, rest_count // 3, 3).transpose(1, 2) + else: # We have colors + colors = np.column_stack([ + pcd.point[f"f_dc_{i}"].numpy() for i in range(dc_count) + ]) + data["colors"] = torch.from_numpy(colors * 0.2820947917738781 + 0.5) + + return data + + +def merge_plys(filepaths: List[str]) -> Dict[str, torch.Tensor]: + """ + Merge multiple PLY files into a single dictionary of torch tensors. + """ + print(f"Merging {len(filepaths)} PLY files") + merged_data = {} + + # Load and merge each PLY file + for filepath in filepaths: + data = load_ply(filepath) + + # For first file, initialize merged_data + if not merged_data: + merged_data = {k: [v] for k, v in data.items()} + else: + # Verify that the current file has the same attributes + assert set(data.keys()) == set(merged_data.keys()), \ + f"PLY file {filepath} has different attributes than previous files" + + # Append tensors to lists + for k, v in data.items(): + merged_data[k].append(v) + + # Concatenate all tensors + return {k: torch.cat(v, dim=0) for k, v in merged_data.items()} + + +def save_merged_ply(merged_data: Dict[str, torch.Tensor], output_path: str): + """ + Save merged data as a PLY file using the same format as the original save_ply function. + """ + # Convert to ParameterDict if it isn't already + if not isinstance(merged_data, torch.nn.ParameterDict): + param_dict = torch.nn.ParameterDict() + for k, v in merged_data.items(): + if k != "colors": # Skip colors if present + param_dict[k] = torch.nn.Parameter(v) + merged_data = param_dict + + # Use the provided save_ply function + colors = merged_data.get("colors") if "colors" in merged_data else None + save_ply(merged_data, output_path, colors) + + +def process_plys(input_dir: str, output_path: str, pattern: str = "*.ply"): + """ + Process all PLY files in a directory and merge them into a single file. 
+ + Args: + input_dir: Directory containing input PLY files + output_path: Path where to save the merged PLY file + pattern: Glob pattern to match PLY files (default: "*.ply") + """ + input_path = Path(input_dir) + ply_files = sorted(str(p) for p in input_path.glob(pattern)) + + if not ply_files: + raise ValueError(f"No PLY files found in {input_dir} matching pattern {pattern}") + + print(f"Found {len(ply_files)} PLY files") + + # Merge all PLY files + merged_data = merge_plys(ply_files) + + # Save merged result + save_merged_ply(merged_data, output_path) + print(f"Saved merged PLY to {output_path}") + +def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): + # Convert all tensors to numpy arrays in one go + print(f"Saving ply to {dir}") + numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} + + means = numpy_data["means"] + scales = numpy_data["scales"] + quats = numpy_data["quats"] + opacities = numpy_data["opacities"] + ply_data = { + "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32), + "normals": o3d.core.Tensor(np.zeros_like(means), dtype=o3d.core.Dtype.Float32), + "opacity": o3d.core.Tensor( + opacities.reshape(-1, 1), dtype=o3d.core.Dtype.Float32 + ), + } + + if colors is not None: + color = colors.detach().cpu().numpy().copy() # + for j in range(color.shape[1]): + # Needs to be converted to shs as that's what all viewers take. + ply_data[f"f_dc_{j}"] = o3d.core.Tensor( + (color[:, j : j + 1] - 0.5) / 0.2820947917738781, + dtype=o3d.core.Dtype.Float32, + ) + else: + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy() + + # Add sh0 and shN data + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + ply_data[f"{prefix}_{j}"] = o3d.core.Tensor( + data[:, j : j + 1], dtype=o3d.core.Dtype.Float32 + ) + + # Add scales and quats data + for name, data in [("scale", scales), ("rot", quats)]: + for i in range(data.shape[1]): + ply_data[f"{name}_{i}"] = o3d.core.Tensor( + data[:, i : i + 1], dtype=o3d.core.Dtype.Float32 + ) + + pcd = o3d.t.geometry.PointCloud(ply_data) + + success = o3d.t.io.write_point_cloud(dir, pcd) + assert success, "Could not save ply file." 
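+
+# Example usage (hypothetical paths): merge all per-rank PLY exports from a
+# multi-GPU run into a single file:
+#
+#   process_plys("./results/plys", "./results/merged.ply")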
+
+# Command line interface
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Merge multiple PLY files into a single PLY file.')
+    parser.add_argument('input_dir', type=str, help='Directory containing input PLY files')
+    parser.add_argument('output_path', type=str, help='Path where to save the merged PLY file')
+    parser.add_argument('--pattern', type=str, default='*.ply',
+                        help='Glob pattern to match PLY files (default: *.ply)')
+
+    args = parser.parse_args()
+
+    process_plys(args.input_dir, args.output_path, args.pattern)
\ No newline at end of file
diff --git a/scripts/merge_pts.py b/scripts/merge_pts.py
new file mode 100644
index 000000000..de43a8264
--- /dev/null
+++ b/scripts/merge_pts.py
@@ -0,0 +1,127 @@
+import os
+import glob
+import torch
+import argparse
+import open3d as o3d
+from pathlib import Path
+from collections import defaultdict
+import numpy as np
+
+def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
+    # Convert all tensors to numpy arrays in one go
+    print(f"Saving ply to {dir}")
+    numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}
+
+    means = numpy_data["means"]
+    scales = numpy_data["scales"]
+    quats = numpy_data["quats"]
+    opacities = numpy_data["opacities"]
+    ply_data = {
+        "positions": o3d.core.Tensor(means, dtype=o3d.core.Dtype.Float32),
+        "normals": o3d.core.Tensor(np.zeros_like(means), dtype=o3d.core.Dtype.Float32),
+        "opacity": o3d.core.Tensor(
+            opacities.reshape(-1, 1), dtype=o3d.core.Dtype.Float32
+        ),
+    }
+
+    if colors is not None:
+        color = colors.detach().cpu().numpy().copy()
+        for j in range(color.shape[1]):
+            # Needs to be converted to shs as that's what all viewers take.
+            ply_data[f"f_dc_{j}"] = o3d.core.Tensor(
+                (color[:, j : j + 1] - 0.5) / 0.2820947917738781,
+                dtype=o3d.core.Dtype.Float32,
+            )
+    else:
+        sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy()
+        shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1).copy()
+
+        # Add sh0 and shN data
+        for i, data in enumerate([sh0, shN]):
+            prefix = "f_dc" if i == 0 else "f_rest"
+            for j in range(data.shape[1]):
+                ply_data[f"{prefix}_{j}"] = o3d.core.Tensor(
+                    data[:, j : j + 1], dtype=o3d.core.Dtype.Float32
+                )
+
+    # Add scales and quats data
+    for name, data in [("scale", scales), ("rot", quats)]:
+        for i in range(data.shape[1]):
+            ply_data[f"{name}_{i}"] = o3d.core.Tensor(
+                data[:, i : i + 1], dtype=o3d.core.Dtype.Float32
+            )
+
+    pcd = o3d.t.geometry.PointCloud(ply_data)
+    success = o3d.t.io.write_point_cloud(dir, pcd)
+    assert success, "Could not save ply file."
+
+def merge_checkpoints(ckpts_folder: str, output_dir: str = None):
+    """
+    Load and merge checkpoint files from a folder into PLY files.
+
+    Args:
+        ckpts_folder: Folder containing checkpoint files
+        output_dir: Output directory for PLY files. If None, uses ckpts_folder parent
+    """
+    ckpts_folder = Path(ckpts_folder)
+    if not ckpts_folder.exists():
+        raise ValueError(f"Checkpoint folder does not exist: {ckpts_folder}")
+
+    # If no output directory specified, create a 'merged_ply' folder next to ckpts
+    if output_dir is None:
+        output_dir = ckpts_folder.parent / "merged_ply"
+    output_dir = Path(output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    # Find all checkpoint files
+    ckpt_files = list(ckpts_folder.glob("ckpt_*_rank*.pt"))
+    if not ckpt_files:
+        raise ValueError(f"No checkpoint files found in: {ckpts_folder}")
+
+    # Group checkpoints by step number
+    step_groups = defaultdict(list)
+    for ckpt_file in ckpt_files:
+        # Extract step number from filename (assumes format ckpt_STEP_rank*.pt)
+        step = int(ckpt_file.name.split('_')[1])
+        step_groups[step].append(ckpt_file)
+
+    print(f"Found checkpoints for {len(step_groups)} steps: {sorted(step_groups.keys())}")
+
+    # Process each step
+    for step, files in sorted(step_groups.items()):
+        print(f"\nProcessing step {step}...")
+        print(f"Found {len(files)} rank files:")
+        for f in files:
+            print(f"  {f}")
+
+        # Load all checkpoints for this step
+        ckpts = [torch.load(f, map_location='cpu') for f in files]
+
+        # Create a new ParameterDict to store merged parameters
+        merged_splats = torch.nn.ParameterDict()
+
+        # Get the keys from first checkpoint's splats
+        splat_keys = ckpts[0]['splats'].keys()
+
+        # Merge parameters for each key
+        for key in splat_keys:
+            merged_data = torch.cat([ckpt['splats'][key] for ckpt in ckpts])
+            merged_splats[key] = torch.nn.Parameter(merged_data)
+
+        # Save merged data as PLY
+        output_ply = output_dir / f"gaussian_step_{step}.ply"
+        save_ply(merged_splats, str(output_ply))
+        print(f"Successfully saved merged PLY to {output_ply}")
+
+def main():
+    parser = argparse.ArgumentParser(description='Merge checkpoint files from a folder into PLY files')
+    parser.add_argument('--ckpts-folder', type=str, required=True,
+                        help='Folder containing checkpoint files')
+    parser.add_argument('--output-dir', type=str, default=None,
+                        help='Output directory for PLY files (optional)')
+
+    args = parser.parse_args()
+    merge_checkpoints(args.ckpts_folder, args.output_dir)
+
+if __name__ == '__main__':
+    main()

From dbe5b8674e3cf00376b82f7fb64c25e7c3986a68 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 18:03:07 +0000
Subject: [PATCH 11/13] update environment.yml

---
 environment.yml | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 environment.yml

diff --git a/environment.yml b/environment.yml
new file mode 100644
index 000000000..5fd7707dc
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,28 @@
+name: gsplat
+channels:
+  - pytorch3d
+  - pytorch
+  - nvidia
+  - conda-forge
+  - defaults
+dependencies:
+  - setuptools==69.5.1
+  - python=3.10
+  - pip
+  - plyfile
+  - tqdm
+  - tyro
+  - pytorch=2.1.0
+  - torchvision=0.16.0
+  - pytorch-cuda=12.1
+  - typing_extensions
+  - numpy=1.26.4
+  - scikit-learn
+  - pip:
+    - viser
+    - tensorboard
+    - pyyaml
+    - open3d==0.18.0
+    - tqdm
+    - pillow
+    - plyfile
\ No newline at end of file

From 4fabfd33d82fa1ba8ae31e9f6c687764a32598b7 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 18:03:23 +0000
Subject: [PATCH 12/13] update readme

---
 README.md | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/README.md b/README.md
index c4186c26d..c29ba3343 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,47 @@
+# Installation
+```bash
+micromamba create -f environment.yml -y
+micromamba activate gsplat
+pip install -r examples/requirements.txt
+pip install -e .
+```
+
+# How to run
+Currently, there is a memory issue that pops up every now and then with the preferred densification strategy `mcmc` (and MCMC-style strategies in general). See for instance: https://github.com/nerfstudio-project/gsplat/issues/487. It is not yet safe to run, so the command below uses the `default` strategy.
+
+```bash
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python examples/simple_trainer.py default --use-bilateral-grid --data-dir path --test_every 0 --result-dir path --disable_viewer --steps-scaler 0.5
+```
+In the result dir there will be a ckpt folder and a ply folder, each containing multiple files (one per GPU). They can be merged with either
+```bash
+python scripts/merge_plys.py ./results/plys ./results/merged.ply
+```
+or
+
+```bash
+python scripts/merge_pts.py --ckpts-folder ./results/ckpts --output-dir ./results/ckpts
+```
+
+The images and masks (if available) must have the following folder structure
+```bash
+├── images
+│   └── L2PRO
+│   ├── camera_0
+│   └── camera_1
+├── masks
+│   └── L2PRO
+│   ├── camera_0
+│   └── camera_1
+├── sparse
+│   └── 0
+│   ├── cameras.bin
+│   ├── images.bin
+│   └── points3D.bin
+```
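+
+# Mask format
+What the loader in `examples/datasets/colmap.py` expects: each mask is a PNG stored under `masks/` with the same relative path as its image (the `.jpg` extension replaced by `.png`). Pixel values > 127 mark regions that are excluded from the photometric loss; the dataset binarizes the mask and inverts it before the trainer multiplies it into both the rendered and the ground-truth colors. A minimal sketch for writing a compatible mask (the resolution, camera folder, and file name below are hypothetical and must match the corresponding image):
+
+```python
+import numpy as np
+import imageio.v2 as imageio
+
+# 255 = exclude from the loss, 0 = keep (the loader binarizes at >127, then inverts)
+mask = np.zeros((1080, 1920), dtype=np.uint8)
+mask[:200, :] = 255  # e.g. blank out a rig visible at the top of every frame
+imageio.imwrite("masks/L2PRO/camera_0/000000.png", mask)
+```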
+Attention: the colmap loading dependency has a bug and does not work properly with the txt format; only *.bin files load correctly.
+
+
 # gsplat
 
 [![Core Tests.](https://github.com/nerfstudio-project/gsplat/actions/workflows/core_tests.yml/badge.svg?branch=main)](https://github.com/nerfstudio-project/gsplat/actions/workflows/core_tests.yml)

From 396d054064d17152884cbfc89637b4786f2101e4 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Tue, 4 Feb 2025 18:05:38 +0000
Subject: [PATCH 13/13] update readme

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index c29ba3343..9f0529b7b 100644
--- a/README.md
+++ b/README.md
@@ -26,11 +26,11 @@ python scripts/merge_pts.py --ckpts-folder ./results/ckpts --output-dir ./results/ckpts
 The images and masks (if available) must have the following folder structure
 ```bash
 ├── images
-│   └── L2PRO
+│   └── L2PRO (or other device)
 │   ├── camera_0
 │   └── camera_1
 ├── masks
-│   └── L2PRO
+│   └── L2PRO (or other device, must be consistent with images folder)
 │   ├── camera_0
 │   └── camera_1
 ├── sparse