Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
baa59a1
torch.compile + precision optimizations
MarcelLieb Mar 9, 2026
2be6135
Merge remote-tracking branch 'origin/main'
MarcelLieb Mar 9, 2026
769dc7b
fix tests and formatting
MarcelLieb Mar 9, 2026
cd70255
add warmup run to remove compilation overhead from benchmark
MarcelLieb Mar 9, 2026
6631537
implement preprocessing on GPU using torchvision
MarcelLieb Mar 9, 2026
ed578e8
add a batched frame processing function
MarcelLieb Mar 9, 2026
efc3a4d
use multiprocessing to speed up post-processing
MarcelLieb Mar 9, 2026
18db1cb
implement batched frame processing
MarcelLieb Mar 10, 2026
3852f0d
Revert "use multiprocessing to speed up post-processing"
MarcelLieb Mar 10, 2026
855d435
Revert "add a batched frame processing function"
MarcelLieb Mar 10, 2026
d1bc295
Revert "implement preprocessing on GPU using torchvision"
MarcelLieb Mar 10, 2026
01a84c4
Revert "add warmup run to remove compilation overhead from benchmark"
MarcelLieb Mar 10, 2026
c9e182e
Merge remote-tracking branch 'origin/main'
MarcelLieb Mar 10, 2026
4a148e2
update uv.lock
MarcelLieb Mar 10, 2026
fa1b4cb
fix lint
MarcelLieb Mar 10, 2026
173be39
Merge branch 'root-main'
MarcelLieb Mar 10, 2026
e3b2b03
implement batched frame processing
MarcelLieb Mar 10, 2026
7824bf7
Merge remote-tracking branch 'fork/batch-processing' into batch-proce…
MarcelLieb Mar 10, 2026
809f376
move compilation to function call to improve flexibility
MarcelLieb Mar 10, 2026
2ead927
bound threads with batch size and fix lint
MarcelLieb Mar 10, 2026
8775cb9
add qualitative comparison helper script
MarcelLieb Mar 11, 2026
8bb35e8
fix tests
MarcelLieb Mar 11, 2026
12308e4
initial GPU pipeline draft
MarcelLieb Mar 12, 2026
75cafcd
fix tests
MarcelLieb Mar 14, 2026
8e205bf
Merge branch 'root-main'
MarcelLieb Mar 14, 2026
a0fbc91
Merge branch 'root-main'
MarcelLieb Mar 14, 2026
8eb92d1
Merge branch 'root-main'
MarcelLieb Mar 16, 2026
1309163
optimize VRAM usage
MarcelLieb Mar 16, 2026
2649c10
Move to channels first format
MarcelLieb Mar 16, 2026
64d9938
improve logic
MarcelLieb Mar 16, 2026
ed7340d
optimize clean_matte
MarcelLieb Mar 17, 2026
dc9e109
Add config options
MarcelLieb Mar 17, 2026
9bacf50
Add changes to single frame method
MarcelLieb Mar 17, 2026
b210153
clean up
MarcelLieb Mar 17, 2026
42cf2f8
use new methods
MarcelLieb Mar 17, 2026
e7967f0
Merge branch 'main' of https://github.com/MarcelLieb/CorridorKey
MarcelLieb Mar 17, 2026
1c56ae2
improved fast despeckle
MarcelLieb Mar 18, 2026
e7dc97d
Merge remote-tracking branch 'fork/main'
MarcelLieb Mar 18, 2026
15d3f66
small fixes
MarcelLieb Mar 18, 2026
bd26920
fix compositing
MarcelLieb Mar 18, 2026
39eb7fa
small fixes
MarcelLieb Mar 18, 2026
6484006
fix tests
MarcelLieb Mar 19, 2026
8221f78
match batch processing
MarcelLieb Mar 19, 2026
1f1970e
parameterize tests over backend
MarcelLieb Mar 19, 2026
2e5b6ab
feat: add tests for batched frame processing
MarcelLieb Mar 19, 2026
4ba6d62
Merge branch 'root-main'
MarcelLieb Mar 19, 2026
cfa2012
feat: cleanup + reorganization
MarcelLieb Mar 20, 2026
bfd0b31
feat: use full float16 precision
MarcelLieb Mar 20, 2026
db17b63
fix: remove channels_last format
MarcelLieb Mar 21, 2026
93f8df8
feat: add CLI options
MarcelLieb Mar 21, 2026
d5208ab
feat: remove redundant batch processing method
MarcelLieb Mar 21, 2026
a90900b
fix: parameter name
MarcelLieb Mar 21, 2026
b437337
feat: add safeguards for mlx
MarcelLieb Mar 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions CorridorKeyModule/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path

import numpy as np
import torch

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,12 +180,12 @@ def _wrap_mlx_output(raw: dict, despill_strength: float, auto_despeckle: bool, d

# Apply despeckle (MLX stubs this)
if auto_despeckle:
processed_alpha = cu.clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
processed_alpha = cu.clean_matte_opencv(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)
else:
processed_alpha = alpha

# Apply despill (MLX stubs this)
fg_despilled = cu.despill(fg, green_limit_mode="average", strength=despill_strength)
fg_despilled = cu.despill_opencv(fg, green_limit_mode="average", strength=despill_strength)

# Composite over checkerboard for comp output
h, w = fg.shape[:2]
Expand Down Expand Up @@ -223,6 +224,7 @@ def process_frame(
despill_strength=1.0,
auto_despeckle=True,
despeckle_size=400,
**_kwargs,
):
"""Delegate to MLX engine, then normalize output to Torch contract."""
# MLX engine expects uint8 input — convert if float
Expand Down Expand Up @@ -287,4 +289,6 @@ def create_engine(
from CorridorKeyModule.inference_engine import CorridorKeyEngine

logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device)
return CorridorKeyEngine(checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size)
return CorridorKeyEngine(
checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=torch.float16
)
109 changes: 107 additions & 2 deletions CorridorKeyModule/core/color_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF


def _is_tensor(x: np.ndarray | torch.Tensor) -> bool:
Expand Down Expand Up @@ -202,7 +204,7 @@ def apply_garbage_matte(
return predicted_matte * garbage_mask


def despill(
def despill_opencv(
image: np.ndarray | torch.Tensor, green_limit_mode: str = "average", strength: float = 1.0
) -> np.ndarray | torch.Tensor:
"""
Expand Down Expand Up @@ -247,7 +249,63 @@ def despill(
return despilled


def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray:
def despill_torch(image: torch.Tensor, strength: float) -> torch.Tensor:
    """
    Remove green spill from an RGB batch without leaving the device.

    Green above the average of red and blue is treated as spill; the removed
    amount is redistributed equally to red and blue so overall brightness is
    preserved.

    Args:
        image: torch Tensor [B, 3, H, W], RGB channel order.
        strength: blend factor; 0 (or less) returns the input unchanged,
            1 applies the full correction, values in between blend linearly.

    Returns:
        torch Tensor [B, 3, H, W] with the spill suppressed.
    """
    if strength <= 0.0:
        return image

    red, green, blue = image[:, 0], image[:, 1], image[:, 2]
    # Anything above the red/blue average counts as green spill.
    ceiling = (red + blue) / 2.0
    excess = torch.clamp(green - ceiling, min=0.0)
    corrected = torch.stack(
        [red + excess * 0.5, green - excess, blue + excess * 0.5],
        dim=1,
    )
    if strength < 1.0:
        # Linear blend between the input and the fully corrected frame.
        return image * (1.0 - strength) + corrected * strength
    return corrected


def connected_components(mask: torch.Tensor, min_component_width=1, max_iterations=100) -> torch.Tensor:
    """
    Label connected components of a binary mask via iterative max-pool flood fill.

    Adapted from: https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc
    Args:
        mask: torch Tensor [B, 1, H, W] binary 1 or 0
        min_component_width: int. Minimum width of connected components that are separated instead of merged.
        max_iterations: int. Maximum number of flood fill iterations. Adjust based on expected component sizes.
    Returns:
        comp: torch Tensor [B, 1, H, W] with connected component labels (0 = background, 1..N = components)

    Notes:
        - Which component receives which contiguous label is NOT deterministic
          across calls: the seed labels come from torch.randperm, so label
          order depends on the global RNG state (component membership and
          sizes are unaffected).
        - If max_iterations is reached before the fixed point, a single large
          component may still carry several distinct labels.
    """
    bs, _, H, W = mask.shape

    # Seed every pixel with a unique positive label (1..H*W), shared across the batch.
    # Reference implementation uses torch.arange instead of torch.randperm
    # torch.randperm converges considerably faster and more uniformly
    # NOTE(review): labels are propagated in float32, which is exact only while
    # W * H < 2**24 (~16.7M pixels, i.e. up to ~4K frames) — confirm expected resolutions.
    comp = (torch.randperm(W * H) + 1).repeat(bs, 1).view(mask.shape).float().to(mask.device)
    comp[mask != 1] = 0

    prev_comp = torch.zeros_like(comp)

    iteration = 0

    # Fixed-point iteration: each step propagates the maximum label within a
    # (2*min_component_width + 1) neighbourhood, but only onto foreground pixels,
    # so every connected region converges to its single largest seed label.
    while not torch.equal(comp, prev_comp) and iteration < max_iterations:
        prev_comp = comp.clone()
        comp[mask == 1] = F.max_pool2d(
            comp, kernel_size=(2 * min_component_width) + 1, stride=1, padding=min_component_width
        )[mask == 1]
        iteration += 1

    comp = comp.long()
    # Relabel components to have contiguous labels starting from 1
    # (background stays 0 because it is always the smallest unique value).
    unique_labels = torch.unique(comp)
    label_map = torch.zeros(unique_labels.max().item() + 1, dtype=torch.long, device=mask.device)
    label_map[unique_labels] = torch.arange(len(unique_labels), device=mask.device)
    comp = label_map[comp]

    return comp


def clean_matte_opencv(
alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5
) -> np.ndarray:
"""
Cleans up small disconnected components (like tracking markers) from a predicted alpha matte.
alpha_np: Numpy array [H, W] or [H, W, 1] float (0.0 - 1.0)
Expand Down Expand Up @@ -295,6 +353,39 @@ def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int =
return result_alpha


def clean_matte_torch(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor:
    """
    Clean up small disconnected components (like tracking markers) from a predicted alpha matte.
    Supports fully running on the GPU.

    Args:
        alpha: torch Tensor [B, 1, H, W] (0.0 - 1.0)
        area_threshold: minimum pixel count for a foreground component to be kept.
        dilation: approximate dilation radius used to restore edges of kept regions.
        blur_size: radius controlling the Gaussian kernel that softens the mask edges.

    Returns:
        torch Tensor [B, 1, H, W]: alpha multiplied by the cleaned keep-mask.
    """
    mask = alpha > 0.5  # [B, 1, H, W] binary foreground estimate

    # Find the largest connected components in the mask.
    # Only a limited amount of iterations is needed to find components above the area threshold.
    components = connected_components(mask, max_iterations=area_threshold // 8, min_component_width=2)
    sizes = torch.bincount(components.flatten())
    # Label 0 is the background; zero its count so it can never qualify as a
    # "large component". Otherwise the keep-mask would be 1 almost everywhere
    # and the dilation below would re-grow the speckles we just removed.
    sizes[0] = 0
    big_labels = torch.nonzero(sizes >= area_threshold)

    keep = torch.zeros_like(mask, dtype=torch.float32)
    keep[torch.isin(components, big_labels)] = 1.0

    # Dilate back to restore edges of large regions that the 0.5 threshold cut off
    if dilation > 0:
        # Each 5x5 max-pool application dilates by a radius of 2, so
        # dilation // 2 applications approximate the requested radius.
        repeats = dilation // 2
        for _ in range(repeats):
            keep = F.max_pool2d(keep, 5, stride=1, padding=2)

    # Blur for soft edges
    if blur_size > 0:
        k = int(blur_size * 2 + 1)
        keep = TF.gaussian_blur(keep, [k, k])

    return alpha * keep


def create_checkerboard(
width: int, height: int, checker_size: int = 64, color1: float = 0.2, color2: float = 0.4
) -> np.ndarray:
Expand All @@ -321,3 +412,17 @@ def create_checkerboard(

# Make it 3-channel
return np.stack([bg_img, bg_img, bg_img], axis=-1)


@functools.lru_cache(maxsize=4)
def get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor:
    """
    Return a cached checkerboard tensor [3, H, W] on device in linear space.

    The board uses 128-pixel squares with sRGB tones 0.15 and 0.55 and is
    linearized before caching, so callers can composite directly in linear light.
    """
    square = 128
    # Parity of (row block + column block) decides each square's tone; the
    # two arange vectors broadcast into the full [H, W] grid.
    row_blocks = (torch.arange(h, device=device) // square).unsqueeze(1)  # [H, 1]
    col_blocks = (torch.arange(w, device=device) // square).unsqueeze(0)  # [1, W]
    pattern = ((row_blocks + col_blocks) % 2).float()  # [H, W] of 0.0 / 1.0
    # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching.
    srgb = pattern * 0.4 + 0.15
    return srgb_to_linear(srgb.unsqueeze(0).expand(3, -1, -1))
Loading