113 changes: 113 additions & 0 deletions comfy/float.py
@@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0):
return output

return value.to(dtype=dtype)


# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
sign = torch.signbit(x).to(torch.uint8)
x_abs = x.abs()

exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25

x_abs = x.abs()
exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3)

mantissa = torch.where(
exp > 0,
(x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x_abs * 2.0)
).round().to(torch.uint8)

fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa

fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(x.shape)[:-1] + [-1])
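
For sanity-checking the packed bytes above, here is a minimal decoder sketch; it is not part of this diff, and unpack_fp4_e2m1 is a hypothetical helper that simply reverses the sign/exponent/mantissa layout used by stochastic_float_to_fp4_e2m1.

def unpack_fp4_e2m1(packed):
    # Each uint8 holds two FP4 codes, high nibble first, matching the packing order above.
    hi = (packed >> 4) & 0xF
    lo = packed & 0xF
    codes = torch.stack((hi, lo), dim=-1).flatten(-2)
    sign = 1.0 - 2.0 * ((codes >> 3) & 0x1).to(torch.float32)
    exp = ((codes >> 1) & 0x3).to(torch.float32)
    mantissa = (codes & 0x1).to(torch.float32)
    # exp == 0 encodes the subnormals 0 and 0.5; otherwise value = 2**(exp - 1) * (1 + mantissa / 2).
    magnitude = torch.where(exp > 0, (2.0 ** (exp - 1.0)) * (1.0 + 0.5 * mantissa), 0.5 * mantissa)
    return sign * magnitude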


def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""

def ceil_div(a, b):
return (a + b - 1) // b

rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4

padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix

# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()

return rearranged.reshape(padded_rows, padded_cols)
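
As a quick shape check (a sketch with made-up sizes, not part of this diff): a 256x8 scale matrix splits into 2 row blocks and 2 column blocks, so both return modes carry 2 * 2 * 32 * 16 = 2048 values.

scales = torch.arange(256 * 8, dtype=torch.float32).reshape(256, 8)
flat = to_blocked(scales)                  # 1D, 2048 elements in the cuBLAS blocked order
tiled = to_blocked(scales, flatten=False)  # same data viewed in the padded (256, 8) footprint
print(flat.shape, tiled.shape)             # torch.Size([2048]) torch.Size([256, 8])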


def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0

def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple

orig_shape = x.shape

# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: update orig_shape so the reshape below and the final view back to orig_shape
# operate on the padded tensor; when padding is applied, the caller receives the padded output.
orig_shape = x.shape

block_size = 16

x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)

# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)

x = x / total_scale_safe.unsqueeze(-1)

generator = torch.Generator(device=x.device)
generator.manual_seed(seed)

x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)

x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)

blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales
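
A hedged usage sketch (the weight shape and per-tensor scale choice are assumptions, not taken from this diff): quantizing a 128x64 weight yields packed uint8 data with two FP4 values per byte, plus float8_e4m3fn block scales laid out by to_blocked.

w = torch.randn(128, 64)
per_tensor_scale = w.abs().amax() / (448.0 * 6.0)  # F8_E4M3_MAX * F4_E2M1_MAX
packed, block_scales = stochastic_round_quantize_nvfp4(w, per_tensor_scale, pad_16x=False, seed=1234)
print(packed.shape)        # torch.Size([128, 32]) - 64 FP4 values per row packed pairwise into uint8
print(block_scales.shape)  # torch.Size([128, 4])  - one scale per 16-element block
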
136 changes: 106 additions & 30 deletions comfy/ldm/lightricks/av_model.py
@@ -11,6 +11,69 @@
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit

class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')

def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space); may be None, in which case the tensor is stored uncompressed
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape

# Compression is valid only when patches_per_frame is known and divides num_tokens
if patches_per_frame and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame

# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = num_tokens
self.data = tensor

def expand(self):
"""Expand back to original tensor."""
if self.patches_per_frame == 1:
return self.data

# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
return expanded.reshape(self.batch_size, -1, self.feature_dim)

def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
"""Compute ada values on compressed per-frame data, then expand spatially."""
num_ada_params = scale_shift_table.shape[0]

# No compression - compute directly
if self.patches_per_frame == 1:
num_tokens = self.data.shape[1]
dim_per_param = self.feature_dim // num_ada_params
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
ada_values = (table_values + reshaped).unbind(dim=2)
return ada_values

# Compressed: compute on per-frame data then expand spatially
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
device=self.data.device, dtype=self.data.dtype
)
frame_ada = (table_values + frame_reshaped).unbind(dim=2)

# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
return tuple(
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
.reshape(batch_size, -1, frame_val.shape[-1])
for frame_val in frame_ada
)
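
A minimal round-trip sketch (batch/frame/patch sizes are invented) of the invariant the class relies on: when every spatial patch in a frame carries the same timestep embedding, expand() reproduces the original token-level tensor.

batch, frames, patches, dim = 2, 8, 64, 128
per_frame = torch.randn(batch, frames, 1, dim)
tokens = per_frame.expand(batch, frames, patches, dim).reshape(batch, frames * patches, dim)
compressed = CompressedTimestep(tokens, patches_per_frame=patches)
print(compressed.data.shape)                     # torch.Size([2, 8, 128]) - one row kept per frame
print(torch.equal(compressed.expand(), tokens))  # True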

class BasicAVTransformerBlock(nn.Module):
def __init__(
self,
@@ -119,6 +182,9 @@ def __init__(
def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
):
if isinstance(timestep, CompressedTimestep):
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)

num_ada_params = scale_shift_table.shape[0]

ada_values = (
@@ -146,10 +212,7 @@ def get_av_ca_ada_values(
gate_timestep,
)

scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]

return (*scale_shift_chunks, *gate_ada_values)
return (*scale_shift_ada_values, *gate_ada_values)

def forward(
self,
@@ -543,72 +606,80 @@ def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
if grid_mask is not None:
timestep = timestep[:, grid_mask]

timestep = timestep * self.timestep_scale_multiplier
timestep_scaled = timestep * self.timestep_scale_multiplier

v_timestep, v_embedded_timestep = self.adaln_single(
timestep.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)

# Second dimension is 1 or number of tokens (if timestep_per_token)
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
v_embedded_timestep = v_embedded_timestep.view(
batch_size, -1, v_embedded_timestep.shape[-1]
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
v_patches_per_frame = None
if orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]

# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)

# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
a_timestep = a_timestep * self.timestep_scale_multiplier
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
timestep_flat = timestep_scaled.flatten()
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier

# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep.flatten(),
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep.flatten(),
timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep.flatten() * av_ca_factor,
timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep.flatten() * av_ca_factor,
a_timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)

# Compress cross-attention timesteps (only video side, audio is too small to benefit)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]

a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep.flatten(),
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(
batch_size, -1, a_embedded_timestep.shape[-1]
)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
]
cross_av_timestep_ss = list(
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
)
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
else:
a_timestep = timestep
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []

@@ -767,6 +838,11 @@ def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
ax = x[1]
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]

# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()

vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)

# Process audio output
2 changes: 1 addition & 1 deletion comfy/ops.py
@@ -699,7 +699,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
37 changes: 35 additions & 2 deletions comfy/quant_ops.py
@@ -7,7 +7,7 @@
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@@ -34,7 +34,7 @@ class QuantizedTensor:
class _CKFp8Layout:
pass

class TensorCoreNVFP4Layout:
class _CKNvfp4Layout:
pass

def register_layout_class(name, cls):
@@ -84,6 +84,39 @@ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
return qdata, params


class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")

orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)

if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)

if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)

padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape

if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)

params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
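
For context, a hedged sketch of exercising this layout directly (the weight shape is an assumption, and the call requires the comfy kernels backend to be importable): a positive stochastic_rounding value is used as the seed and routes through comfy.float.stochastic_round_quantize_nvfp4, while 0 falls back to ck.quantize_nvfp4.

weight = torch.randn(256, 512, dtype=torch.bfloat16)
qdata, params = TensorCoreNVFP4Layout.quantize(weight, scale="recalculate", stochastic_rounding=1234)
print(qdata.shape)  # torch.Size([256, 256]) - two FP4 values per uint8 along the last dim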


class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn
