Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions comfy/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,44 @@ def ceil_div(a, b):
return rearranged.reshape(padded_rows, padded_cols)


def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0

orig_shape = x.shape

block_size = 16

x = x.reshape(orig_shape[0], -1, block_size)
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)

x = x.view(orig_shape).nan_to_num()
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
return data_lp, scaled_block_scales_fp8


def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple

generator = torch.Generator(device=x.device)
generator.manual_seed(seed)

# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))

x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
return x, to_blocked(blocked_scaled, flatten=False)


def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
Expand All @@ -158,16 +192,20 @@ def roundup(x: int, multiple: int) -> int:
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape

block_size = 16
orig_shape = list(orig_shape)

x = x.reshape(orig_shape[0], -1, block_size)
scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)

generator = torch.Generator(device=x.device)
generator.manual_seed(seed)

x = x.view(orig_shape).nan_to_num()
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales
num_slices = max(1, (x.numel() / block_size))
slice_size = max(1, (round(x.shape[0] / num_slices)))

for i in range(0, x.shape[0], slice_size):
fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
output_fp4[i:i + slice_size].copy_(fp4)
output_block[i:i + slice_size].copy_(block)

return output_fp4, to_blocked(output_block, flatten=False)
2 changes: 1 addition & 1 deletion comfy/quant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
needs_padding = padded_shape != orig_shape

if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)

Expand Down
2 changes: 1 addition & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ class ZImage(Lumina2):
"shift": 3.0,
}

memory_usage_factor = 2.0
memory_usage_factor = 2.8

supported_inference_dtypes = [torch.bfloat16, torch.float32]

Expand Down
Loading