diff --git a/CorridorKeyModule/backend.py b/CorridorKeyModule/backend.py index 2e50d119..f4a097a6 100644 --- a/CorridorKeyModule/backend.py +++ b/CorridorKeyModule/backend.py @@ -12,6 +12,7 @@ from pathlib import Path import numpy as np +import torch logger = logging.getLogger(__name__) @@ -179,12 +180,12 @@ def _wrap_mlx_output(raw: dict, despill_strength: float, auto_despeckle: bool, d # Apply despeckle (MLX stubs this) if auto_despeckle: - processed_alpha = cu.clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + processed_alpha = cu.clean_matte_opencv(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) else: processed_alpha = alpha # Apply despill (MLX stubs this) - fg_despilled = cu.despill(fg, green_limit_mode="average", strength=despill_strength) + fg_despilled = cu.despill_opencv(fg, green_limit_mode="average", strength=despill_strength) # Composite over checkerboard for comp output h, w = fg.shape[:2] @@ -223,6 +224,7 @@ def process_frame( despill_strength=1.0, auto_despeckle=True, despeckle_size=400, + **_kwargs, ): """Delegate to MLX engine, then normalize output to Torch contract.""" # MLX engine expects uint8 input — convert if float @@ -287,4 +289,6 @@ def create_engine( from CorridorKeyModule.inference_engine import CorridorKeyEngine logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device) - return CorridorKeyEngine(checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size) + return CorridorKeyEngine( + checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size, model_precision=torch.float16 + ) diff --git a/CorridorKeyModule/core/color_utils.py b/CorridorKeyModule/core/color_utils.py index 4a3b17e0..6ac059c9 100644 --- a/CorridorKeyModule/core/color_utils.py +++ b/CorridorKeyModule/core/color_utils.py @@ -6,6 +6,8 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF def _is_tensor(x: np.ndarray | torch.Tensor) -> bool: @@ -202,7 +204,7 @@ def apply_garbage_matte( return predicted_matte * garbage_mask -def despill( +def despill_opencv( image: np.ndarray | torch.Tensor, green_limit_mode: str = "average", strength: float = 1.0 ) -> np.ndarray | torch.Tensor: """ @@ -247,7 +249,63 @@ def despill( return despilled -def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray: +def despill_torch(image: torch.Tensor, strength: float) -> torch.Tensor: + """GPU despill — keeps data on device.""" + if strength <= 0.0: + return image + r, g, b = image[:, 0], image[:, 1], image[:, 2] + limit = (r + b) / 2.0 + spill = torch.clamp(g - limit, min=0.0) + g_new = g - spill + r_new = r + spill * 0.5 + b_new = b + spill * 0.5 + despilled = torch.stack([r_new, g_new, b_new], dim=1) + if strength < 1.0: + return image * (1.0 - strength) + despilled * strength + return despilled + + +def connected_components(mask: torch.Tensor, min_component_width=1, max_iterations=100) -> torch.Tensor: + """ + Adapted from: https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc + Args: + mask: torch Tensor [B, 1, H, W] binary 1 or 0 + min_component_width: int. Minimum width of connected components that are separated instead of merged. + max_iterations: int. Maximum number of flood fill iterations. Adjust based on expected component sizes. + Returns: + comp: torch Tensor [B, 1, H, W] with connected component labels (0 = background, 1..N = components) + """ + bs, _, H, W = mask.shape + + # Reference implementation uses torch.arange instead of torch.randperm + # torch.randperm converges considerably faster and more uniformly + comp = (torch.randperm(W * H) + 1).repeat(bs, 1).view(mask.shape).float().to(mask.device) + comp[mask != 1] = 0 + + prev_comp = torch.zeros_like(comp) + + iteration = 0 + + while not torch.equal(comp, prev_comp) and iteration < max_iterations: + prev_comp = comp.clone() + comp[mask == 1] = F.max_pool2d( + comp, kernel_size=(2 * min_component_width) + 1, stride=1, padding=min_component_width + )[mask == 1] + iteration += 1 + + comp = comp.long() + # Relabel components to have contiguous labels starting from 1 + unique_labels = torch.unique(comp) + label_map = torch.zeros(unique_labels.max().item() + 1, dtype=torch.long, device=mask.device) + label_map[unique_labels] = torch.arange(len(unique_labels), device=mask.device) + comp = label_map[comp] + + return comp + + +def clean_matte_opencv( + alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5 +) -> np.ndarray: """ Cleans up small disconnected components (like tracking markers) from a predicted alpha matte. alpha_np: Numpy array [H, W] or [H, W, 1] float (0.0 - 1.0) @@ -295,6 +353,39 @@ def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = return result_alpha +def clean_matte_torch(alpha: torch.Tensor, area_threshold: int, dilation: int, blur_size: int) -> torch.Tensor: + """ + Cleans up small disconnected components (like tracking markers) from a predicted alpha matte. + Supports fully running on the GPU + alpha_np: torch Tensor [B, 1, H, W] (0.0 - 1.0) + """ + _device = alpha.device + mask = alpha > 0.5 # [B, 1, H, W] + + # Find the largest connected components in the mask + # only a limited amount of iterations is needed to find components above the area threshold + components = connected_components(mask, max_iterations=area_threshold // 8, min_component_width=2) + sizes = torch.bincount(components.flatten()) + big_sizes = torch.nonzero(sizes >= area_threshold) + + mask = torch.zeros_like(mask, dtype=torch.float32) + mask[torch.isin(components, big_sizes)] = 1.0 + + # Dilate back to restore edges of large regions + if dilation > 0: + # How many applications with kernel size 5 are needed to achieve the desired dilation radius + repeats = dilation // 2 + for _ in range(repeats): + mask = F.max_pool2d(mask, 5, stride=1, padding=2) + + # Blur for soft edges + if blur_size > 0: + k = int(blur_size * 2 + 1) + mask = TF.gaussian_blur(mask, [k, k]) + + return alpha * mask + + def create_checkerboard( width: int, height: int, checker_size: int = 64, color1: float = 0.2, color2: float = 0.4 ) -> np.ndarray: @@ -321,3 +412,17 @@ def create_checkerboard( # Make it 3-channel return np.stack([bg_img, bg_img, bg_img], axis=-1) + + +@functools.lru_cache(maxsize=4) +def get_checkerboard_linear_torch(w: int, h: int, device: torch.device) -> torch.Tensor: + """Return a cached checkerboard tensor [3, H, W] on device in linear space.""" + checker_size = 128 + y_coords = torch.arange(h, device=device) // checker_size + x_coords = torch.arange(w, device=device) // checker_size + y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij") + checker = ((x_grid + y_grid) % 2).float() + # Map 0 -> 0.15, 1 -> 0.55 (sRGB), then convert to linear before caching + bg_srgb = checker * 0.4 + 0.15 # [H, W] + bg_srgb_3 = bg_srgb.unsqueeze(0).expand(3, -1, -1) + return srgb_to_linear(bg_srgb_3) diff --git a/CorridorKeyModule/inference_engine.py b/CorridorKeyModule/inference_engine.py old mode 100644 new mode 100755 index 44d28150..3f43588d --- a/CorridorKeyModule/inference_engine.py +++ b/CorridorKeyModule/inference_engine.py @@ -9,6 +9,9 @@ import numpy as np import torch import torch.nn.functional as F +import torchvision +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as TF from .core import color_utils as cu from .core.model_transformer import GreenFormer @@ -31,8 +34,8 @@ def __init__( self.checkpoint_path = checkpoint_path self.use_refiner = use_refiner - self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=model_precision, device=self.device) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=model_precision, device=self.device) if mixed_precision or model_precision != torch.float32: # Use faster matrix multiplication implementation @@ -47,22 +50,11 @@ def __init__( self.model_precision = model_precision - model = self._load_model().to(model_precision) + self.model = self._load_model() - # We only tested compilation on windows and linux. For other platforms compilation is disabled as a precaution. + # We only tested compilation on Windows and Linux. For other platforms compilation is disabled as a precaution. if sys.platform == "linux" or sys.platform == "win32": - # Try compiling the model. Fallback to eager mode if it fails. - try: - self.model = torch.compile(model) - # Trigger compilation with a dummy input - dummy_input = torch.zeros(1, 4, img_size, img_size, dtype=model_precision, device=self.device) - with torch.inference_mode(): - self.model(dummy_input) - except Exception as e: - logger.info(f"Model compilation failed with error: {e}") - logger.warning("Model compilation failed. Falling back to eager mode.") - torch.cuda.empty_cache() - self.model = model + self._compile() def _load_model(self) -> GreenFormer: logger.info("Loading CorridorKey from %s", self.checkpoint_path) @@ -119,99 +111,71 @@ def _load_model(self) -> GreenFormer: if len(unexpected) > 0: print(f"[Warning] Unexpected keys: {unexpected}") - return model - - @torch.inference_mode() - def process_frame( - self, - image: np.ndarray, - mask_linear: np.ndarray, - refiner_scale: float = 1.0, - input_is_linear: bool = False, - fg_is_straight: bool = True, - despill_strength: float = 1.0, - auto_despeckle: bool = True, - despeckle_size: int = 400, - ) -> dict[str, np.ndarray]: - """ - Process a single frame. - Args: - image: Numpy array [H, W, 3] (0.0-1.0 or 0-255). - - If input_is_linear=False (Default): Assumed sRGB. - - If input_is_linear=True: Assumed Linear. - mask_linear: Numpy array [H, W] or [H, W, 1] (0.0-1.0). Assumed Linear. - refiner_scale: Multiplier for Refiner Deltas (default 1.0). - input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. - If False, resizes in sRGB (standard). - fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). - If False, assumes FG output is Premultiplied. - despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. - auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. - despeckle_size: int. Minimum number of consecutive pixels required to keep an island. - Returns: - dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} - """ - # 1. Inputs Check & Normalization - if image.dtype == np.uint8: - image = image.astype(np.float32) / 255.0 + model = model.to(self.model_precision) - if mask_linear.dtype == np.uint8: - mask_linear = mask_linear.astype(np.float32) / 255.0 - - h, w = image.shape[:2] - - # Ensure Mask Shape - if mask_linear.ndim == 2: - mask_linear = mask_linear[:, :, np.newaxis] + return model + def _compile(self): + try: + compiled_model = torch.compile(self.model, mode="max-autotune") + # Trigger compilation with a dummy input + dummy_input = torch.zeros( + 1, 4, self.img_size, self.img_size, dtype=self.model_precision, device=self.device + ) + with torch.inference_mode(): + compiled_model(dummy_input) + self.model = compiled_model + + except Exception as e: + logger.info(f"Compilation error: {e}") + logger.warning("Model compilation failed. Falling back to eager mode.") + torch.cuda.empty_cache() + + def _preprocess_input( + self, image_batch: torch.Tensor, mask_batch_linear: torch.Tensor, input_is_linear: bool + ) -> torch.Tensor: # 2. Resize to Model Size # If input is linear, we resize in linear to preserve energy/highlights, # THEN convert to sRGB for the model. + image_batch = TF.resize( + image_batch, + [self.img_size, self.img_size], + interpolation=T.InterpolationMode.BILINEAR, + ) if input_is_linear: - # Resize in Linear - img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - # Convert to sRGB for Model - img_resized = cu.linear_to_srgb(img_resized_lin) - else: - # Standard sRGB Resize - img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) - - mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) + image_batch = cu.linear_to_srgb(image_batch) - if mask_resized.ndim == 2: - mask_resized = mask_resized[:, :, np.newaxis] + mask_batch_linear = TF.resize( + mask_batch_linear, + [self.img_size, self.img_size], + interpolation=T.InterpolationMode.BILINEAR, + ) # 3. Normalize (ImageNet) # Model expects sRGB input normalized - img_norm = (img_resized - self.mean) / self.std + image_batch = TF.normalize(image_batch, self.mean, self.std) # 4. Prepare Tensor - inp_np = np.concatenate([img_norm, mask_resized], axis=-1) # [H, W, 4] - inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device) - - # 5. Inference - # Hook for Refiner Scaling - handle = None - if refiner_scale != 1.0 and self.model.refiner is not None: + inp_concat = torch.concat((image_batch, mask_batch_linear), -3) # [4, H, W] - def scale_hook(module, input, output): - return output * refiner_scale - - handle = self.model.refiner.register_forward_hook(scale_hook) - - with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): - out = self.model(inp_t) - - if handle: - handle.remove() - - pred_alpha = out["alpha"] - pred_fg = out["fg"] # Output is sRGB (Sigmoid) + return inp_concat + def _postprocess_opencv( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + generate_comp: bool, + ) -> dict[str, np.ndarray]: # 6. Post-Process (Resize Back to Original Resolution) # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original. - res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy() - res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy() + res_alpha = pred_alpha.permute(1, 2, 0).cpu().numpy() + res_fg = pred_fg.permute(1, 2, 0).cpu().numpy() res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4) res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4) @@ -222,13 +186,13 @@ def scale_hook(module, input, output): # A. Clean Matte (Auto-Despeckle) if auto_despeckle: - processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) + processed_alpha = cu.clean_matte_opencv(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5) else: processed_alpha = res_alpha # B. Despill FG # res_fg is sRGB. - fg_despilled = cu.despill(res_fg, green_limit_mode="average", strength=despill_strength) + fg_despilled = cu.despill_opencv(res_fg, green_limit_mode="average", strength=despill_strength) # C. Premultiply (for EXR Output) # CONVERT TO LINEAR FIRST! EXRs must house linear color premultiplied by linear alpha. @@ -243,16 +207,19 @@ def scale_hook(module, input, output): # 7. Composite (on Checkerboard) for checking # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear) - bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) - bg_lin = cu.srgb_to_linear(bg_srgb) + if generate_comp: + bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55) + bg_lin = cu.srgb_to_linear(bg_srgb) - if fg_is_straight: - comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) - else: - # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) - comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) + if fg_is_straight: + comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha) + else: + # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight) + comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha) - comp_srgb = cu.linear_to_srgb(comp_lin) + comp_srgb = cu.linear_to_srgb(comp_lin) + else: + comp_srgb = None return { # type: ignore[return-value] # cu.* returns ndarray|Tensor but inputs are always ndarray here "alpha": res_alpha, # Linear, Raw Prediction @@ -260,3 +227,201 @@ def scale_hook(module, input, output): "comp": comp_srgb, # sRGB, Composite "processed": processed_rgba, # Linear/Premul, RGBA, Garbage Matted & Despilled } + + def _postprocess_torch( + self, + pred_alpha: torch.Tensor, + pred_fg: torch.Tensor, + w: int, + h: int, + fg_is_straight: bool, + despill_strength: float, + auto_despeckle: bool, + despeckle_size: int, + generate_comp: bool, + ) -> list[dict[str, np.ndarray]]: + """Post-process on GPU, transfer final results to CPU. + + When ``sync=True`` (default), blocks until transfer completes and + returns numpy arrays. When ``sync=False``, starts the DMA + non-blocking and returns a :class:`PendingTransfer` — call + ``.resolve()`` to get the numpy dict later. + """ + # Resize on GPU using torchvision (much faster than cv2 at 4K) + alpha = TF.resize( + pred_alpha.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + fg = TF.resize( + pred_fg.float(), + [h, w], + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + ) + + del pred_fg, pred_alpha + torch.cuda.empty_cache() + + # A. Clean matte + if auto_despeckle: + processed_alpha = cu.clean_matte_torch(alpha, despeckle_size, dilation=25, blur_size=5) + else: + processed_alpha = alpha + + # B. Despill on GPU + processed_fg = cu.despill_torch(fg, despill_strength) + + # C. sRGB → linear on GPU + processed_fg_lin = cu.srgb_to_linear(processed_fg) + + # D. Premultiply on GPU + processed_fg = cu.premultiply(processed_fg_lin, processed_alpha) + + # E. Pack RGBA on GPU + packed_processed = torch.cat([processed_fg, processed_alpha], dim=1) + + # F. Composite + if generate_comp: + bg_lin = cu.get_checkerboard_linear_torch(w, h, processed_fg.device) + if fg_is_straight: + comp = cu.composite_straight(processed_fg_lin, bg_lin, processed_alpha) + else: + comp = cu.composite_premul(processed_fg_lin, bg_lin, processed_alpha) + comp = cu.linear_to_srgb(comp) # [H, W, 3] opaque + else: + del processed_fg, processed_alpha + comp = [None] * alpha.shape[0] # placeholder + + alpha, fg, comp, packed_processed = ( + alpha.cpu().permute(0, 2, 3, 1).numpy(), + fg.cpu().permute(0, 2, 3, 1).numpy(), + comp.cpu().permute(0, 2, 3, 1).numpy() if generate_comp else comp, + packed_processed.cpu().permute(0, 2, 3, 1).numpy(), + ) + + out = [] + for i in range(alpha.shape[0]): + result = { + "alpha": alpha[i], + "fg": fg[i], + "comp": comp[i], + "processed": packed_processed[i], + } + out.append(result) + return out + + @torch.inference_mode() + def process_frame( + self, + image: np.ndarray, + mask_linear: np.ndarray, + refiner_scale: float = 1.0, + input_is_linear: bool = False, + fg_is_straight: bool = True, + despill_strength: float = 1.0, + auto_despeckle: bool = True, + despeckle_size: int = 400, + generate_comp: bool = True, + post_process_on_gpu: bool = True, + ) -> dict[str, np.ndarray] | list[dict[str, np.ndarray]]: + """ + Process a single frame. + Args: + image: Numpy array [H, W, 3] or [B, H, W, 3] (0.0-1.0 or 0-255). + - If input_is_linear=False (Default): Assumed sRGB. + - If input_is_linear=True: Assumed Linear. + mask_linear: Numpy array [H, W] or [B, H, W] or [H, W, 1] or [B, H, W, 1] (0.0-1.0). Assumed Linear. + refiner_scale: Multiplier for Refiner Deltas (default 1.0). + input_is_linear: bool. If True, resizes in Linear then transforms to sRGB. + If False, resizes in sRGB (standard). + fg_is_straight: bool. If True, assumes FG output is Straight (unpremultiplied). + If False, assumes FG output is Premultiplied. + despill_strength: float. 0.0 to 1.0 multiplier for the despill effect. + auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte. + despeckle_size: int. Minimum number of consecutive pixels required to keep an island. + generate_comp: bool. If True, also generates a composite on checkerboard for quick checking. + post_process_on_gpu: bool. If True, performs post-processing on GPU using PyTorch instead of OpenCV. + Returns: + dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Gray)} + """ + torch.compiler.cudagraph_mark_step_begin() + + # If input is a single image, add batch dimension + if image.ndim == 3: + image = image[np.newaxis, :] + mask_linear = mask_linear[np.newaxis, :] + + bs, h, w = image.shape[:3] + + # 1. Inputs Check & Normalization + image = TF.to_dtype( + torch.from_numpy(image).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ).to(self.device, non_blocking=True) + mask_linear = TF.to_dtype( + torch.from_numpy(mask_linear.reshape((bs, h, w, 1))).permute((0, 3, 1, 2)), + self.model_precision, + scale=True, + ).to(self.device, non_blocking=True) + + inp_t = self._preprocess_input(image, mask_linear, input_is_linear) + + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + del image, mask_linear + + # 5. Inference + # Hook for Refiner Scaling + handle = None + if refiner_scale != 1.0 and self.model.refiner is not None: + + def scale_hook(module, input, output): + return output * refiner_scale + + handle = self.model.refiner.register_forward_hook(scale_hook) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision): + prediction = self.model(inp_t) + + # Free up unused VRAM in order to keep peak usage down and avoid OOM errors + del inp_t + + if handle: + handle.remove() + + if post_process_on_gpu: + out = self._postprocess_torch( + prediction["alpha"], + prediction["fg"], + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + ) + else: + # Move prediction to CPU before post-processing + pred_alpha = prediction["alpha"].cpu().float() + pred_fg = prediction["fg"].cpu().float() + + out = [] + for i in range(bs): + result = self._postprocess_opencv( + pred_alpha[i], + pred_fg[i], + w, + h, + fg_is_straight, + despill_strength, + auto_despeckle, + despeckle_size, + generate_comp, + ) + out.append(result) + + if bs == 1: + return out[0] + + return out diff --git a/clip_manager.py b/clip_manager.py index 4bc1bd44..18d58c6a 100644 --- a/clip_manager.py +++ b/clip_manager.py @@ -38,6 +38,8 @@ class InferenceSettings: auto_despeckle: bool = True despeckle_size: int = 400 refiner_scale: float = 1.0 + generate_comp: bool = True + gpu_post_processing: bool = False # Core Paths @@ -763,6 +765,8 @@ def run_inference( auto_despeckle=settings.auto_despeckle, despeckle_size=settings.despeckle_size, refiner_scale=settings.refiner_scale, + generate_comp=settings.generate_comp, + post_process_on_gpu=settings.gpu_post_processing, ) pred_fg = res["fg"] # sRGB @@ -782,10 +786,11 @@ def run_inference( cv2.imwrite(os.path.join(matte_dir, f"{input_stem}.exr"), pred_alpha, EXR_WRITE_FLAGS) # 5. Generate Reference Comp - comp_srgb = res["comp"] - # Save Comp (PNG 8-bit) - comp_bgr = cv2.cvtColor((np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR) - cv2.imwrite(os.path.join(comp_dir, f"{input_stem}.png"), comp_bgr) + if res["comp"] is not None: + comp_srgb = res["comp"] + # Save Comp (PNG 8-bit) + comp_bgr = cv2.cvtColor((np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR) + cv2.imwrite(os.path.join(comp_dir, f"{input_stem}.png"), comp_bgr) # 6. Save Processed (RGBA EXR) if "processed" in res: diff --git a/corridorkey_cli.py b/corridorkey_cli.py index 575ae128..b040b3fe 100644 --- a/corridorkey_cli.py +++ b/corridorkey_cli.py @@ -43,6 +43,7 @@ run_videomama, scan_clips, ) +from CorridorKeyModule.backend import resolve_backend from device_utils import resolve_device logger = logging.getLogger(__name__) @@ -137,6 +138,8 @@ def _prompt_inference_settings( default_despeckle: bool | None = None, default_despeckle_size: int | None = None, default_refiner: float | None = None, + default_comp: bool | None = None, + default_gpu_post: bool | None = None, ) -> InferenceSettings: """Interactively prompt for inference settings, skipping any pre-filled values.""" console.print(Panel("Inference Settings", style="bold cyan")) @@ -189,12 +192,31 @@ def _prompt_inference_settings( except ValueError: refiner_scale = 1.0 + if resolve_backend() == "torch": + if default_comp is not None: + generate_comp = default_comp + else: + generate_comp = Confirm.ask( + "Generate composition previews", + default=True, + ) + + if default_gpu_post is not None: + gpu_post_processing = default_gpu_post + else: + gpu_post_processing = Confirm.ask( + "Use GPU accelerated post-processing [dim](experimental)[/dim]", + default=False, + ) + return InferenceSettings( input_is_linear=input_is_linear, despill_strength=despill_strength, auto_despeckle=auto_despeckle, despeckle_size=despeckle_size, refiner_scale=refiner_scale, + generate_comp=generate_comp, + gpu_post_processing=gpu_post_processing, ) @@ -273,6 +295,14 @@ def run_inference_cmd( Optional[float], typer.Option("--refiner", help="Refiner strength multiplier (default: prompt)"), ] = None, + generate_comp: Annotated[ + Optional[bool], + typer.Option("--comp/--no-comp", help="Generate comp previews (default: prompt)"), + ] = None, + gpu_post: Annotated[ + Optional[bool], + typer.Option("--gpu-post/--cpu-post", help="Use GPU post-processing (default: prompt)"), + ] = None, ) -> None: """Run CorridorKey inference on clips with Input + AlphaHint. @@ -300,6 +330,8 @@ def run_inference_cmd( default_despeckle=despeckle, default_despeckle_size=despeckle_size, default_refiner=refiner, + default_comp=generate_comp, + default_gpu_post=gpu_post, ) with ProgressContext() as ctx_progress: diff --git a/test_outputs.py b/test_outputs.py new file mode 100644 index 00000000..7c1e9645 --- /dev/null +++ b/test_outputs.py @@ -0,0 +1,143 @@ +import os + +import torch +from torchvision.io import read_image +from torchvision.utils import save_image + +from CorridorKeyModule.inference_engine import CorridorKeyEngine + +# there is some compile weirdness when generating the images +torch._dynamo.config.cache_size_limit = 1024 + + +def load_engine(img_size, precision, mixed_precision): + return CorridorKeyEngine( + checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", + img_size=img_size, + device="cuda", + model_precision=precision, + mixed_precision=mixed_precision, + ) + + +def generate_test_images(img_path, mask_path): + img = read_image(img_path).permute(1, 2, 0).numpy() + mask = read_image(mask_path).permute(1, 2, 0).numpy() + img_sizes = [512, 1024, 2048] + precisions = [torch.float16, torch.float32, torch.float64] + for precision in precisions: + for img_size in img_sizes: + # Reset stats + torch.cuda.reset_peak_memory_stats() + + if precision == torch.float64 and img_size > 1024: + continue + + engine = load_engine(img_size, precision) + out = engine.process_frame(img, mask) + + save_image( + torch.from_numpy(out["fg"]).permute(2, 0, 1), + f"./Output/foreground_{img_size}_{str(precision)[-7:]}.png", + ) + save_image( + torch.from_numpy(out["alpha"]).permute(2, 0, 1), f"./Output/alpha_{img_size}_{str(precision)[-7:]}.png" + ) + + peak_vram = torch.cuda.max_memory_allocated() / (1024**3) + print(f"Precision: {precision}, Image Size: {img_size}, Peak VRAM: {peak_vram:.2f} GB") + + +def compare_implementations(src, comparison, output_dir="./Output"): + for _, _, files in os.walk(src): + for file in files: + src_img = read_image(str(os.path.join(src, file))).float() + comp_img = read_image(str(os.path.join(comparison, file))).float() + + is_mask = src_img.shape[0] == 1 or (src_img[0] == src_img[1]).all() and (src_img[1] == src_img[2]).all() + + difference = (src_img - comp_img).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), torch.zeros_like(difference)), dim=0 + ) + print(difference.shape) + print(difference.min(), difference.max()) + else: + difference = difference.abs() + + os.makedirs(output_dir, exist_ok=True) + + save_image(difference, f"{output_dir}/diff_{file}") + + +def compare_floating_point_precision(folder, ref="float64"): + for _, _, files in os.walk(folder): + for file in files: + name, fmt = file.split(".") + typ, img_size, precision = name.split("_") + if precision != ref: + continue + float_ref = read_image(str(os.path.join(folder, file))).float() + float_32 = read_image(str(os.path.join(folder, f"{typ}_{img_size}_float32.{fmt}"))).float() + + is_mask = typ == "alpha" + + difference = (float_ref - float_32).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), torch.zeros_like(difference)), dim=0 + ) + else: + difference = difference.abs() + print( + is_mask, + difference.min().item(), + difference.max().item(), + difference.mean().item(), + difference.median().item(), + ) + + save_image(difference, f"./Output/prec_{ref}_{typ}_{img_size}.{fmt}") + + +def compare_img_sizes(folder, ref=1024): + for _, _, files in os.walk(folder): + for file in files: + name, fmt = file.split(".") + typ, img_size, precision = name.split("_") + if img_size != str(ref): + continue + if precision == "float64": + continue + img_ref = read_image(str(os.path.join(folder, file))).float() + img_2048 = read_image(str(os.path.join(folder, f"{typ}_2048_{precision}.{fmt}"))).float() + + is_mask = typ == "alpha" + + difference = (img_ref - img_2048).float() / 255 + + if is_mask: + difference = difference[0].unsqueeze(0) + difference = torch.cat( + (difference.clamp(-1, 0).abs(), difference.clamp(0, 1), torch.zeros_like(difference)), dim=0 + ) + else: + difference = difference.abs() + print( + is_mask, + difference.min().item(), + difference.max().item(), + difference.mean().item(), + difference.median().item(), + ) + + save_image(difference, f"./Output/img_{ref}_{typ}_{precision}.{fmt}") + + +if __name__ == "__main__": + compare_implementations("./Output/gpu_full_res/Comp", "./Output/gpu_fp16/Comp", "./Output/diff/fp16_vs_fp32") diff --git a/test_vram.py b/test_vram.py index 2f734d8d..2306caf9 100644 --- a/test_vram.py +++ b/test_vram.py @@ -6,32 +6,53 @@ from CorridorKeyModule.inference_engine import CorridorKeyEngine -def process_frame(engine): +def process_frame(engine: CorridorKeyEngine): img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8) mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8) engine.process_frame(img, mask) +def batch_process_frame(engine: CorridorKeyEngine, batch_size: int): + imgs = np.random.randint(0, 255, (batch_size, 2160, 3840, 3), dtype=np.uint8) + masks = np.random.randint(0, 255, (batch_size, 2160, 3840), dtype=np.uint8) + + engine.batch_process_frames(imgs, masks) + + def test_vram(): + torch.backends.cudnn.benchmark = True + print("Loading engine...") engine = CorridorKeyEngine( checkpoint_path="CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth", img_size=2048, device="cuda", model_precision=torch.float16, + mixed_precision=True, ) # Reset stats torch.cuda.reset_peak_memory_stats() - iterations = 24 + total_seconds = 6 + batch_size = 2 # works with a 16GB GPU + iterations = total_seconds * 24 // batch_size print(f"Running {iterations} inference passes...") - time = timeit.timeit(lambda: process_frame(engine), number=iterations) - print(f"Seconds per frame: {time / iterations}") + time = timeit.timeit( + lambda: batch_process_frame(engine, batch_size), + number=iterations, + setup=lambda: ( + batch_process_frame(engine, batch_size), + torch.cuda.synchronize(), + torch.cuda.empty_cache(), + print("Compilation and warmup complete, starting timed runs..."), + ), + ) + print(f"Seconds per frame: {time / (iterations * batch_size):.4f}") peak_vram = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Peak VRAM used: {peak_vram:.2f} GB") + print(f"Peak VRAM used: {peak_vram:.2f} GiB") if __name__ == "__main__": diff --git a/tests/test_color_utils.py b/tests/test_color_utils.py index 10ea0aef..4a97569d 100644 --- a/tests/test_color_utils.py +++ b/tests/test_color_utils.py @@ -246,32 +246,32 @@ class TestDespill: def test_pure_green_reduced_average_mode_numpy(self): """A pure green pixel should have green clamped to (R+B)/2 = 0.""" img = _to_np([[0.0, 1.0, 0.0]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) # Green should be 0 (clamped to avg of R=0, B=0) assert result[0, 1] == pytest.approx(0.0, abs=1e-6) def test_pure_green_reduced_max_mode_numpy(self): """With 'max' mode, green clamped to max(R, B) = 0 for pure green.""" img = _to_np([[0.0, 1.0, 0.0]]) - result = cu.despill(img, green_limit_mode="max", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="max", strength=1.0) assert result[0, 1] == pytest.approx(0.0, abs=1e-6) def test_pure_red_unchanged_numpy(self): """A pixel with no green excess should not be modified.""" img = _to_np([[1.0, 0.0, 0.0]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result, img, atol=1e-6) def test_strength_zero_is_noop_numpy(self): """strength=0 should return the input unchanged.""" img = _to_np([[0.2, 0.9, 0.1]]) - result = cu.despill(img, strength=0.0) + result = cu.despill_opencv(img, strength=0.0) np.testing.assert_allclose(result, img, atol=1e-7) def test_partial_green_average_mode_numpy(self): """Green slightly above (R+B)/2 should be reduced, not zeroed.""" img = _to_np([[0.4, 0.8, 0.2]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) limit = (0.4 + 0.2) / 2.0 # 0.3 expected_green = limit # green clamped to limit assert result[0, 1] == pytest.approx(expected_green, abs=1e-5) @@ -279,16 +279,16 @@ def test_partial_green_average_mode_numpy(self): def test_max_mode_higher_limit_than_average(self): """'max' mode uses max(R,B) which is >= (R+B)/2, so less despill.""" img = _to_np([[0.6, 0.8, 0.1]]) - result_avg = cu.despill(img, green_limit_mode="average", strength=1.0) - result_max = cu.despill(img, green_limit_mode="max", strength=1.0) + result_avg = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) + result_max = cu.despill_opencv(img, green_limit_mode="max", strength=1.0) # max(R,B)=0.6 vs avg(R,B)=0.35, so max mode removes less green assert result_max[0, 1] >= result_avg[0, 1] def test_fractional_strength_interpolates(self): """strength=0.5 should produce a result between original and fully despilled.""" img = _to_np([[0.2, 0.9, 0.1]]) - full = cu.despill(img, green_limit_mode="average", strength=1.0) - half = cu.despill(img, green_limit_mode="average", strength=0.5) + full = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) + half = cu.despill_opencv(img, green_limit_mode="average", strength=0.5) # Half-strength green should be between original green and fully despilled green assert half[0, 1] < img[0, 1] # less green than original assert half[0, 1] > full[0, 1] # more green than full despill @@ -300,8 +300,8 @@ def test_despill_torch(self): """Verify torch path matches numpy path.""" img_np = _to_np([[0.3, 0.9, 0.2]]) img_t = _to_torch([[0.3, 0.9, 0.2]]) - result_np = cu.despill(img_np, green_limit_mode="average", strength=1.0) - result_t = cu.despill(img_t, green_limit_mode="average", strength=1.0) + result_np = cu.despill_opencv(img_np, green_limit_mode="average", strength=1.0) + result_t = cu.despill_opencv(img_t, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result_np, result_t.numpy(), atol=1e-5) def test_green_below_limit_unchanged_numpy(self): @@ -315,7 +315,7 @@ def test_green_below_limit_unchanged_numpy(self): # G=0.3 is well below the average limit (0.8+0.6)/2 = 0.7 # spill_amount = max(0.3 - 0.7, 0) = 0 → output equals input img = _to_np([[0.8, 0.3, 0.6]]) - result = cu.despill(img, green_limit_mode="average", strength=1.0) + result = cu.despill_opencv(img, green_limit_mode="average", strength=1.0) np.testing.assert_allclose(result, img, atol=1e-6) @@ -335,7 +335,7 @@ def test_large_blob_preserved(self): """A single large opaque region should survive cleanup.""" matte = np.zeros((100, 100), dtype=np.float32) matte[20:80, 20:80] = 1.0 # 60x60 = 3600 pixels - result = cu.clean_matte(matte, area_threshold=300) + result = cu.clean_matte_opencv(matte, area_threshold=300) # Center of the blob should still be opaque assert result[50, 50] > 0.9 @@ -343,7 +343,7 @@ def test_small_blob_removed(self): """A tiny blob below the threshold should be removed.""" matte = np.zeros((100, 100), dtype=np.float32) matte[5:8, 5:8] = 1.0 # 3x3 = 9 pixels - result = cu.clean_matte(matte, area_threshold=300) + result = cu.clean_matte_opencv(matte, area_threshold=300) assert result[6, 6] == pytest.approx(0.0, abs=1e-5) def test_mixed_blobs(self): @@ -354,7 +354,7 @@ def test_mixed_blobs(self): # Small blob: 5x5 = 25 px matte[150:155, 150:155] = 1.0 - result = cu.clean_matte(matte, area_threshold=100) + result = cu.clean_matte_opencv(matte, area_threshold=100) assert result[35, 35] > 0.9 # large blob center preserved assert result[152, 152] < 0.01 # small blob removed @@ -362,7 +362,7 @@ def test_3d_input_preserved(self): """[H, W, 1] input should return [H, W, 1] output.""" matte = np.zeros((50, 50, 1), dtype=np.float32) matte[10:40, 10:40, 0] = 1.0 - result = cu.clean_matte(matte, area_threshold=100) + result = cu.clean_matte_opencv(matte, area_threshold=100) assert result.ndim == 3 assert result.shape[2] == 1 diff --git a/tests/test_inference_engine.py b/tests/test_inference_engine.py index 1223e39e..243b72bf 100644 --- a/tests/test_inference_engine.py +++ b/tests/test_inference_engine.py @@ -24,7 +24,7 @@ # --------------------------------------------------------------------------- -def _make_engine_with_mock(mock_greenformer, img_size=64): +def _make_engine_with_mock(mock_greenformer, img_size=64, device="cpu"): """Create a CorridorKeyEngine with a mocked model, bypassing __init__. Manually sets the attributes that __init__ would create, avoiding the @@ -33,12 +33,12 @@ def _make_engine_with_mock(mock_greenformer, img_size=64): from CorridorKeyModule.inference_engine import CorridorKeyEngine engine = object.__new__(CorridorKeyEngine) - engine.device = torch.device("cpu") + engine.device = torch.device(device) engine.img_size = img_size engine.checkpoint_path = "/fake/checkpoint.pth" engine.use_refiner = False - engine.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) - engine.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + engine.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32, device=torch.device(device)).reshape(3, 1, 1) + engine.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=torch.device(device)).reshape(3, 1, 1) engine.model = mock_greenformer engine.model_precision = torch.float32 engine.mixed_precision = True @@ -53,47 +53,82 @@ def _make_engine_with_mock(mock_greenformer, img_size=64): class TestProcessFrameOutputs: """Verify shape, dtype, and key presence of process_frame outputs.""" - def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """process_frame must return alpha, fg, comp, and processed.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert "alpha" in result assert "fg" in result assert "comp" in result assert "processed" in result - def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """All outputs should match the spatial dimensions of the input.""" h, w = sample_frame_rgb.shape[:2] engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].shape[:2] == (h, w) assert result["fg"].shape[:2] == (h, w) assert result["comp"].shape == (h, w, 3) assert result["processed"].shape == (h, w, 4) - def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """All outputs should be float32 numpy arrays.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") for key in ("alpha", "fg", "comp", "processed"): assert result[key].dtype == np.float32, f"{key} should be float32" - def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Alpha output must be in [0, 1] — values outside this range corrupt compositing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") alpha = result["alpha"] assert alpha.min() >= -0.01, f"alpha min {alpha.min():.4f} is below 0" assert alpha.max() <= 1.01, f"alpha max {alpha.max():.4f} is above 1" - def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """FG output must be in [0, 1] — required for downstream sRGB conversion and EXR export.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") fg = result["fg"] assert fg.min() >= -0.01, f"fg min {fg.min():.4f} is below 0" assert fg.max() <= 1.01, f"fg max {fg.max():.4f} is above 1" @@ -112,34 +147,64 @@ class TestProcessFrameColorSpace: When False (default), it resizes in sRGB directly. """ - def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Default sRGB path should not crash and should return valid outputs.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=False) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch" + )[0] + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=False, post_process_on_gpu=backend == "torch" + ) np.testing.assert_allclose(result["comp"], 0.545655, atol=1e-4) - def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Linear input path should convert to sRGB before model input.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=True) - assert result["comp"].shape == sample_frame_rgb.shape - - def test_uint8_input_normalized(self, sample_mask, mock_greenformer): + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch" + )[0] + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, input_is_linear=True, post_process_on_gpu=backend == "torch" + ) + assert result["comp"].shape == sample_frame_rgb.shape[1:] if batched else sample_frame_rgb.shape + + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_uint8_input_normalized(self, sample_mask, mock_greenformer, backend, batched): """uint8 input should be auto-converted to float32 [0, 1].""" img_uint8 = np.random.default_rng(42).integers(0, 256, (64, 64, 3), dtype=np.uint8) engine = _make_engine_with_mock(mock_greenformer) - # Should not crash — uint8 is auto-normalized to float32 - result = engine.process_frame(img_uint8, sample_mask) + if batched: + img_uint8 = np.stack([img_uint8] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(img_uint8, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 - def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """The neural network model must be called exactly once per process_frame() call. Double-inference would double latency and produce incorrect outputs. """ engine = _make_engine_with_mock(mock_greenformer) - engine.process_frame(sample_frame_rgb, sample_mask) + engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert mock_greenformer.call_count == 1 @@ -151,7 +216,9 @@ def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_gre class TestProcessFramePostProcessing: """Verify post-processing: despill, despeckle, premultiply, composite.""" - def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask, backend, batched): """despill_strength=1.0 must reduce green in spill pixels; strength=0.0 must leave it unchanged. The default mock_greenformer returns uniform gray (R=G=B=0.6) which has no @@ -178,8 +245,23 @@ def green_heavy_forward(x): green_mock.use_refiner = False engine = _make_engine_with_mock(green_mock) - result_no_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=0.0) - result_full_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=1.0) + + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result_no_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" + )[0] + result_full_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" + )[0] + else: + result_no_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=0.0, post_process_on_gpu=backend == "torch" + ) + result_full_despill = engine.process_frame( + sample_frame_rgb, sample_mask, despill_strength=1.0, post_process_on_gpu=backend == "torch" + ) rgb_none = result_no_despill["processed"][:, :, :3] rgb_full = result_full_despill["processed"][:, :, :3] @@ -194,13 +276,27 @@ def green_heavy_forward(x): "despill_strength=1.0 should reduce the green channel relative to strength=0.0 when G > (R+B)/2" ) - def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """auto_despeckle=False should skip clean_matte without crashing.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, auto_despeckle=False) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame( + sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" + )[0] + sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, auto_despeckle=False, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] - def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """The 'processed' output should be 4-channel RGBA (linear, premultiplied). This is the EXR-ready output that compositors load into Nuke for @@ -208,7 +304,12 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo means color is already multiplied by alpha). """ engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch")[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") processed = result["processed"] assert processed.shape[2] == 4 @@ -223,22 +324,43 @@ def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mo np.testing.assert_allclose(alpha, 0.8, atol=1e-5) np.testing.assert_allclose(rgb, expected_premul, atol=1e-4) - def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer, backend, batched): """process_frame should accept both [H, W] and [H, W, 1] masks.""" engine = _make_engine_with_mock(mock_greenformer) mask_2d = np.ones((64, 64), dtype=np.float32) * 0.5 mask_3d = mask_2d[:, :, np.newaxis] - result_2d = engine.process_frame(sample_frame_rgb, mask_2d) - result_3d = engine.process_frame(sample_frame_rgb, mask_3d) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + mask_2d = np.stack([mask_2d] * 2, axis=0) + mask_3d = np.stack([mask_3d] * 2, axis=0) + result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch")[0] + result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch")[0] + else: + result_2d = engine.process_frame(sample_frame_rgb, mask_2d, post_process_on_gpu=backend == "torch") + result_3d = engine.process_frame(sample_frame_rgb, mask_3d, post_process_on_gpu=backend == "torch") # Both should produce the same output np.testing.assert_allclose(result_2d["alpha"], result_3d["alpha"], atol=1e-5) - def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """Non-default refiner_scale must not raise — the parameter must be threaded through.""" engine = _make_engine_with_mock(mock_greenformer) - result = engine.process_frame(sample_frame_rgb, sample_mask, refiner_scale=0.5) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame( + sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" + )[0] + sample_frame_rgb = sample_frame_rgb[0] # for the shape assertion below + else: + result = engine.process_frame( + sample_frame_rgb, sample_mask, refiner_scale=0.5, post_process_on_gpu=backend == "torch" + ) assert result["alpha"].shape[:2] == sample_frame_rgb.shape[:2] @@ -249,7 +371,9 @@ def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, m class TestNvidiaGPUProcess: @pytest.mark.gpu - def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer): + @pytest.mark.parametrize("backend", ["openCV", "torch"]) + @pytest.mark.parametrize("batched", [True, False]) + def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer, backend, batched): """ Scenario: Process a frame using a CUDA-configured engine. Expected: Input tensors are moved to CUDA before the model is called, @@ -267,10 +391,15 @@ def spy_forward(x): mock_greenformer.side_effect = spy_forward - engine = _make_engine_with_mock(mock_greenformer) - engine.device = torch.device("cuda") + engine = _make_engine_with_mock(mock_greenformer, device="cuda") - result = engine.process_frame(sample_frame_rgb, sample_mask) + if batched: + sample_frame_rgb = np.stack([sample_frame_rgb] * 2, axis=0) + sample_mask = np.stack([sample_mask] * 2, axis=0) + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") + result = result[0] + else: + result = engine.process_frame(sample_frame_rgb, sample_mask, post_process_on_gpu=backend == "torch") assert result["alpha"].dtype == np.float32 assert len(captured_device) == 1, "Model should be called exactly once" assert captured_device[0].type == "cuda", f"Expected model input on cuda, got {captured_device[0]}"