Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #553

Closed
wants to merge 13 commits into from
Closed

Dev #553

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,47 @@
# Installation
```bash
micromamba create -f environment.yml -y
micromamba activate gsplat
pip install -r examples/requirements.txt
pip install -e .
```

# How to run
Currently, there is a memory issue that pops up every now and then with the preferred densification strategy, mcmc or mcmc-style. See for instance: https://github.com/nerfstudio-project/gsplat/issues/487. It is therefore not safe to run with that strategy.

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python examples/simple_trainer.py default --use-bilateral-grid --data-dir path --test_every 0 --result-dir path --disable_viewer --steps-scaler 0.5
```
In the result-dir there will be a ckpt folder and a ply folder which has multiple files (one per GPU). They can be merged with either
```bash
python scripts/merge_plys.py ./results/plys ./results/merged.plys

```
or

```bash
python merge_pts.py --ckpts-folder ./results/ckpts --output-dir ./results/ckpts
```

The images and masks (if available) must have the following folder structure
```bash
├── images
│   └── L2PRO (or other device)
│   ├── camera_0
│   └── camera_1
├── masks
│   └── L2PRO (or other device, must be consistent with images folder)
│   ├── camera_0
│   └── camera_1
├── sparse
│   └── 0
│   ├── cameras.bin
│   ├── images.bin
│   └── points3D.bin
```
Note: the COLMAP loading dependency has a bug and does not work properly with the txt format. Only `*.bin` files are loaded correctly.


# gsplat

[![Core Tests.](https://github.com/nerfstudio-project/gsplat/actions/workflows/core_tests.yml/badge.svg?branch=main)](https://github.com/nerfstudio-project/gsplat/actions/workflows/core_tests.yml)
Expand Down
28 changes: 28 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: gsplat
channels:
- pytorch3d
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- setuptools==69.5.1
- python=3.10
- pip
- plyfile
- tqdm
- tyro
- pytorch=2.1.0
- torchvision=0.16.0
- pytorch-cuda=12.1
- typing_extensions
- numpy=1.26.4
- scikit-learn
- pip:
- viser
- tensorboard
- pyyaml
- open3d==0.18.0
- tqdm
- pillow
- plyfile
33 changes: 32 additions & 1 deletion examples/datasets/colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any, Dict, List, Optional
from typing_extensions import assert_never

import cv2
from PIL import Image
import cv2
import imageio.v2 as imageio
import numpy as np
import torch
Expand Down Expand Up @@ -182,6 +182,8 @@ def __init__(
image_dir_suffix = ""
colmap_image_dir = os.path.join(data_dir, "images")
image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
mask_dir = os.path.join(data_dir, "masks" + image_dir_suffix)

for d in [image_dir, colmap_image_dir]:
if not os.path.exists(d):
raise ValueError(f"Image folder {d} does not exist.")
Expand All @@ -197,6 +199,14 @@ def __init__(
image_files = sorted(_get_rel_paths(image_dir))
colmap_to_image = dict(zip(colmap_files, image_files))
image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]

# Only create mask paths if mask directory exists
image_mask_paths = []
if os.path.exists(mask_dir):
image_mask_paths = [os.path.join(mask_dir, colmap_to_image[f].replace('.jpg', '.png')) for f in image_names]
else:
print(f"Warning: Mask folder {mask_dir} does not exist. Proceeding without masks.")
image_mask_paths = [None] * len(image_names)

# 3D points and {image_name -> [point_idx]}
points = manager.points3D.astype(np.float32)
Expand Down Expand Up @@ -230,6 +240,7 @@ def __init__(

self.image_names = image_names # List[str], (num_images,)
self.image_paths = image_paths # List[str], (num_images,)
self.image_mask_paths = image_mask_paths # List[str], (num_images,)
self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
self.camera_ids = camera_ids # List[int], (num_images,)
self.Ks_dict = Ks_dict # Dict of camera_id -> K
Expand Down Expand Up @@ -357,6 +368,23 @@ def __len__(self):
def __getitem__(self, item: int) -> Dict[str, Any]:
index = self.indices[item]
image = imageio.imread(self.parser.image_paths[index])[..., :3]

# Handle mask loading with proper checks
image_mask = None
if self.parser.image_mask_paths[index] is not None:
image_mask = imageio.imread(self.parser.image_mask_paths[index])
# Handle different possible PNG formats:
if len(image_mask.shape) == 3:
if image_mask.shape[2] == 4: # RGBA
image_mask = image_mask[..., -1] # Take alpha channel
else: # RGB
image_mask = image_mask.mean(axis=2)
# Now convert to binary
image_mask = (image_mask > 127).astype(np.uint8)
else:
# If no mask exists, create a dummy mask of ones (no masking)
image_mask = np.ones(image.shape[:2], dtype=np.uint8)

camera_id = self.parser.camera_ids[index]
K = self.parser.Ks_dict[camera_id].copy() # undistorted K
params = self.parser.params_dict[camera_id]
Expand All @@ -372,20 +400,23 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
x, y, w, h = self.parser.roi_undist_dict[camera_id]
image = image[y : y + h, x : x + w]
image_mask = image_mask[y : y + h, x : x + w]

if self.patch_size is not None:
# Random crop.
h, w = image.shape[:2]
x = np.random.randint(0, max(w - self.patch_size, 1))
y = np.random.randint(0, max(h - self.patch_size, 1))
image = image[y : y + self.patch_size, x : x + self.patch_size]
image_mask = image_mask[y : y + self.patch_size, x : x + self.patch_size]
K[0, 2] -= x
K[1, 2] -= y

data = {
"K": torch.from_numpy(K).float(),
"camtoworld": torch.from_numpy(camtoworlds).float(),
"image": torch.from_numpy(image).float(),
"image_mask": ~torch.from_numpy(image_mask).bool(),
"image_id": item, # the index of the image in the dataset
}
if mask is not None:
Expand Down
40 changes: 33 additions & 7 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class Config:
render_traj_path: str = "interp"

# Path to the Mip-NeRF 360 dataset
data_dir: str = "data/360_v2/garden"
data_dir: str = "/home/paja/new_data/xplor/beach_structure/undistort/"
# Downsample factor for the dataset
data_factor: int = 4
data_factor: int = 1
# Directory to save results
result_dir: str = "results/garden"
result_dir: str = "results/office_lobby_new"
# Every N images there is a test image
test_every: int = 8
# Random crop size for training (experimental)
Expand All @@ -83,7 +83,7 @@ class Config:
# Number of training steps
max_steps: int = 30_000
# Steps to evaluate the model
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
eval_steps: List[int] = field(default_factory=lambda: [])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
Expand Down Expand Up @@ -153,7 +153,7 @@ class Config:
app_opt_reg: float = 1e-6

# Enable bilateral grid. (experimental)
use_bilateral_grid: bool = False
use_bilateral_grid: bool = True
# Shape of the bilateral grid (X, Y, W)
bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)

Expand Down Expand Up @@ -189,6 +189,23 @@ def adjust_steps(self, factor: float):
else:
assert_never(strategy)

def save_step_images(pixels, image_mask, save_dir, step):
    """Save the training image and its mask for one step as PNG files.

    Debugging helper: writes ``step_<step>_image.png`` and, when a mask is
    given, ``step_<step>_mask.png`` into ``save_dir``. Only the first
    element of the batch dimension is written.

    Args:
        pixels: Image tensor, documented by the author as [1, n, m, 3],
            with values scaled to [0, 1].
        image_mask: Mask tensor, documented by the author as [1, n, m, 1],
            or ``None`` when no mask is available (only the image is
            written in that case).
        save_dir: Output directory; created if it does not exist.
        step: Integer training step, used in the output file names.
    """
    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Convert pixels to uint8 image format (multiply by 255 since they were
    # divided earlier). Clip first so that values slightly outside [0, 1]
    # saturate instead of wrapping around in the uint8 cast.
    pixels_np = pixels[0].detach().cpu().numpy()
    pixels_img = (np.clip(pixels_np, 0.0, 1.0) * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(save_dir, f'step_{step:04d}_image.png'), pixels_img)

    # The mask is optional; without one there is nothing more to save.
    if image_mask is None:
        return

    # Convert mask to binary uint8 format (0 or 255)
    mask_img = (image_mask[0].detach().cpu().numpy() * 255).astype(np.uint8)
    # Drop the trailing single-channel dimension so a 2-D grayscale PNG is written
    imageio.imwrite(os.path.join(save_dir, f'step_{step:04d}_mask.png'), mask_img[..., 0])

def create_splats_with_optimizers(
parser: Parser,
Expand Down Expand Up @@ -575,8 +592,14 @@ def train(self):
data = next(trainloader_iter)

camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4]
image_mask = None
if image_mask in data:
image_mask = data["image_mask"].to(device)
image_mask = image_mask.permute(1,2,0).unsqueeze(0)
Ks = data["K"].to(device) # [1, 3, 3]
pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
#assert pixels.shape == image_mask.shape, f"pixels.shape {pixels.shape}, image_mask.shape {image_mask.shape}"
#save_step_images(pixels, image_mask, ".", step)
num_train_rays_per_step = (
pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
)
Expand Down Expand Up @@ -615,6 +638,9 @@ def train(self):
else:
colors, depths = renders, None

if image_mask is not None:
colors = colors * image_mask
pixels = pixels * image_mask
if cfg.use_bilateral_grid:
grid_y, grid_x = torch.meshgrid(
(torch.arange(height, device=self.device) + 0.5) / height,
Expand Down Expand Up @@ -1096,8 +1122,8 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
Config(
init_opa=0.5,
init_scale=0.1,
opacity_reg=0.01,
scale_reg=0.01,
opacity_reg=0.001, # opacity reg down removes black floaters
scale_reg=0.05, # scale_reg up helps to regulate huge splats
strategy=MCMCStrategy(verbose=True),
),
),
Expand Down
105 changes: 34 additions & 71 deletions gsplat/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,89 +173,52 @@ def all_to_all_tensor_list(
splits: List[Union[int, Tensor]],
output_splits: Optional[List[Union[int, Tensor]]] = None,
) -> List[Tensor]:
"""Split and exchange tensors between all ranks in a many-to-many fashion.

Args:
world_size: The total number of ranks.
tensor_list: A list of tensors to split and exchange. The size of the first
dimension of all the tensors in this list should be the same. The rest
dimensions can be arbitrary. Shape: [(N, *), (N, *), ...]
splits: A list of integers representing the number of elements to send to each
rank. It will be used to split the tensor in the `tensor_list`.
The sum of the elements in this list should be equal to N. The size of this
list should be equal to the `world_size`.
output_splits: Splits of the output tensors. Could be pre-calculated via
`all_to_all_int32(world_size, splits)`. If not provided, it will
be calculated internally.

Returns:
A list of tensors exchanged between all ranks, where the i-th element is
corresponding to the i-th tensor in `tensor_list`. Note the shape of the
returned tensors might be different from the input tensors, depending on the
splits.

Examples:

.. code-block:: python

>>> # on rank 0
>>> # tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
>>> # splits = [2, 1]

>>> # on rank 1
>>> # tensor_list = [torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> # splits = [1, 1]

>>> collected = all_to_all_tensor_list(world_rank, world_size, tensor_list, splits)

>>> # on rank 0
>>> # [torch.tensor([1, 2, 7]), torch.tensor([4, 5, 9])]
>>> # on rank 1
>>> # [torch.tensor([3, 8]), torch.tensor([6, 10])]

"""
if world_size == 1:
return tensor_list

# Validate inputs
N = len(tensor_list[0])
for tensor in tensor_list:
assert len(tensor) == N, "All tensors should have the same first dimension size"
assert len(splits) == world_size, "The length of splits should be equal to world_size"

assert (
len(splits) == world_size
), "The length of splits should be equal to world_size"

# concatenate tensors and record their sizes
data = torch.cat([t.reshape(N, -1) for t in tensor_list], dim=-1)
sizes = [t.numel() // N for t in tensor_list]
# Pre-calculate output splits if not provided to avoid redundant computation
if output_splits is None:
output_splits = all_to_all_int32(world_size, splits, device=tensor_list[0].device)

# all_to_all
if output_splits is not None:
collected_splits = output_splits
else:
collected_splits = all_to_all_int32(world_size, splits, device=data.device)
collected = [
torch.empty((l, *data.shape[1:]), dtype=data.dtype, device=data.device)
for l in collected_splits
]
# torch.split requires tuple of integers
splits = [s.item() if isinstance(s, Tensor) else s for s in splits]
if data.requires_grad:
# differentiable all_to_all
distF.all_to_all(collected, data.split(splits, dim=0))
else:
# non-differentiable all_to_all
torch.distributed.all_to_all(collected, list(data.split(splits, dim=0)))
collected = torch.cat(collected, dim=0)

# split the collected tensor and reshape to the original shape
out_tensor_tuple = torch.split(collected, sizes, dim=-1)
# Process tensors one at a time to reduce peak memory
out_tensor_list = []
for out_tensor, tensor in zip(out_tensor_tuple, tensor_list):
for tensor in tensor_list:
# Reshape tensor without creating a new copy if possible
flat_tensor = tensor.reshape(N, -1)

# Pre-allocate output tensors
collected = [
torch.empty((l, flat_tensor.shape[1]),
dtype=flat_tensor.dtype,
device=flat_tensor.device)
for l in output_splits
]

# Normalize splits for torch.split
norm_splits = [s.item() if isinstance(s, Tensor) else s for s in splits]

# Perform all_to_all operation
if flat_tensor.requires_grad:
distF.all_to_all(collected, flat_tensor.split(norm_splits, dim=0))
else:
torch.distributed.all_to_all(collected, list(flat_tensor.split(norm_splits, dim=0)))

# Concatenate and reshape immediately
out_tensor = torch.cat(collected, dim=0)
out_tensor = out_tensor.view(-1, *tensor.shape[1:])
out_tensor_list.append(out_tensor)
return out_tensor_list

# Clear intermediate tensors
del collected, flat_tensor
torch.cuda.empty_cache()

return out_tensor_list

def _find_free_port():
import socket
Expand Down
Loading
Loading