diff --git a/comfy/float.py b/comfy/float.py index 521316fd2fac..e638b1ff7ebf 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? +def stochastic_float_to_fp4_e2m1(x, generator): + sign = torch.signbit(x).to(torch.uint8) + x_abs = x.abs() + + exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x_abs = x.abs() + exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x_abs * 2.0) + ).round().to(torch.uint8) + + fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(x.shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + max_abs = torch.amax(torch.abs(x), dim=-1) + block_scale = max_abs / F4_E2M1_MAX + scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) + scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) + + # Handle zero blocks (from padding): avoid 0/0 NaN + zero_scale_mask = (total_scale == 0) + total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) + + x = x / total_scale_safe.unsqueeze(-1) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) + + x = x.view(orig_shape) + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + + blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) + return data_lp, blocked_scales diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 759535501f23..c12ace241e09 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -11,6 +11,69 @@ from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier import comfy.ldm.common_dit +class CompressedTimestep: + """Store video timestep embeddings in compressed form using per-frame indexing.""" + __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') + + def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + """ + tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame + patches_per_frame: Number of spatial patches per frame (height * width in latent space) + """ + self.batch_size, num_tokens, self.feature_dim = tensor.shape + + # Check if compression is valid (num_tokens must be divisible by patches_per_frame) + if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.patches_per_frame = patches_per_frame + self.num_frames = num_tokens // patches_per_frame + + # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame + # All patches in a frame are identical, so we only keep the first one + reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) + self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + else: + # Not divisible or too small - store directly without compression + self.patches_per_frame = 1 + self.num_frames = num_tokens + self.data = tensor + + def expand(self): + """Expand back to original tensor.""" + if self.patches_per_frame == 1: + return self.data + + # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim] + expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim) + return expanded.reshape(self.batch_size, -1, self.feature_dim) + + def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)): + """Compute ada values on compressed per-frame data, then expand spatially.""" + num_ada_params = scale_shift_table.shape[0] + + # No compression - compute directly + if self.patches_per_frame == 1: + num_tokens = self.data.shape[1] + dim_per_param = self.feature_dim // num_ada_params + reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype) + ada_values = (table_values + reshaped).unbind(dim=2) + return ada_values + + # Compressed: compute on per-frame data then expand spatially + # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param] + frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to( + device=self.data.device, dtype=self.data.dtype + ) + frame_ada = (table_values + frame_reshaped).unbind(dim=2) + + # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim] + return tuple( + frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1) + .reshape(batch_size, -1, frame_val.shape[-1]) + for frame_val in frame_ada + ) + class BasicAVTransformerBlock(nn.Module): def __init__( self, @@ -119,6 +182,9 @@ def __init__( def get_ada_values( self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) ): + if isinstance(timestep, CompressedTimestep): + return timestep.expand_for_computation(scale_shift_table, batch_size, indices) + num_ada_params = scale_shift_table.shape[0] ada_values = ( @@ -146,10 +212,7 @@ def get_av_ca_ada_values( gate_timestep, ) - scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] - gate_ada_values = [t.squeeze(2) for t in gate_ada_values] - - return (*scale_shift_chunks, *gate_ada_values) + return (*scale_shift_ada_values, *gate_ada_values) def forward( self, @@ -543,72 +606,80 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) - v_embedded_timestep = v_embedded_timestep.view( - batch_size, -1, v_embedded_timestep.shape[-1] - ) + # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] + # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width + orig_shape = kwargs.get("orig_shape") + v_patches_per_frame = None + if orig_shape is not None and len(orig_shape) == 5: + # orig_shape[3] = height, orig_shape[4] = width (in latent space) + v_patches_per_frame = orig_shape[3] * orig_shape[4] + + # Reshape to [batch_size, num_tokens, dim] and compress for storage + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: - a_timestep = a_timestep * self.timestep_scale_multiplier + a_timestep_scaled = a_timestep * self.timestep_scale_multiplier + a_timestep_flat = a_timestep_scaled.flatten() + timestep_flat = timestep_scaled.flatten() av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep.flatten(), + timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep.flatten() * av_ca_factor, + timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep.flatten() * av_ca_factor, + a_timestep_flat * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Compress cross-attention timesteps (only video side, audio is too small to benefit) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed + av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), + ] + a_timestep, a_embedded_timestep = self.audio_adaln_single( - a_timestep.flatten(), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) + # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) - a_embedded_timestep = a_embedded_timestep.view( - batch_size, -1, a_embedded_timestep.shape[-1] - ) - cross_av_timestep_ss = [ - av_ca_audio_scale_shift_timestep, - av_ca_video_scale_shift_timestep, - av_ca_a2v_gate_noise_timestep, - av_ca_v2a_gate_noise_timestep, - ] - cross_av_timestep_ss = list( - [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] - ) + a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) else: - a_timestep = timestep + a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] @@ -767,6 +838,11 @@ def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): ax = x[1] v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + + # Expand compressed video timestep if needed + if isinstance(v_embedded_timestep, CompressedTimestep): + v_embedded_timestep = v_embedded_timestep.expand() + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) # Process audio output diff --git a/comfy/ops.py b/comfy/ops.py index 9c0b54ff44f0..415c39e9202b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -699,7 +699,7 @@ def convert_weight(self, weight, inplace=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: # dtype is now implicit in the layout class - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8324be42a0a9..7a61203c346a 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -7,7 +7,7 @@ QuantizedTensor, QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, - TensorCoreNVFP4Layout, # Direct import, no wrapper needed + TensorCoreNVFP4Layout as _CKNvfp4Layout, register_layout_op, register_layout_class, get_layout_class, @@ -34,7 +34,7 @@ class QuantizedTensor: class _CKFp8Layout: pass - class TensorCoreNVFP4Layout: + class _CKNvfp4Layout: pass def register_layout_class(name, cls): @@ -84,6 +84,39 @@ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): return qdata, params +class TensorCoreNVFP4Layout(_CKNvfp4Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) + + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params + + class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e4m3fn diff --git a/requirements.txt b/requirements.txt index 6c1cd86d2484..43737056e4fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.36.13 -comfyui-workflow-templates==0.7.69 +comfyui-workflow-templates==0.8.4 comfyui-embedded-docs==0.4.0 torch torchsde @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.5 +comfy-kitchen>=0.2.6 #non essential dependencies: kornia>=0.7.1