Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 102 additions & 37 deletions src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,19 @@

from ..attention_dispatch import npu_fusion_attention
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None, cal_q=True):
if cal_q:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None and cal_q:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
if cal_q:
return query, key, value, encoder_query, encoder_key, encoder_value
else:
return value, encoder_query, encoder_key, encoder_value
return query, key, value

def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
Expand Down Expand Up @@ -117,6 +116,7 @@ class FluxAttnProcessor:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
self.double_stream = bool(int(os.environ.get("DOUBLE_STREAM", 1)))

def __call__(
self,
Expand Down Expand Up @@ -261,14 +261,15 @@ def _context_parallel_forward(
torch_npu._npu_flash_attention_unpad(query_all, key_all, value_all, seq_len, 1/math.sqrt(D), N, N, out)

out = out.view(B, S, N, D).contiguous()
out = out.to(query.dtype)
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

if encoder_hidden_states is not None:
out = _all_to_all_single(out, group)
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
Expand All @@ -278,6 +279,9 @@ def _context_parallel_forward(

return hidden_states, encoder_hidden_states
else:
out = out.flatten()
out = funcol.all_to_all_single(out, None, None, group)
hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3)
return hidden_states


Expand Down Expand Up @@ -529,13 +533,19 @@ def forward(

residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
# mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
mlp_hidden_states = self.proj_mlp(norm_hidden_states)
attn_output = _wait_tensor(attn_output)
attn_output = attn_output.contiguous()
if attn_output.ndim == 4:
attn_output = attn_output.flatten(2, 3)
mlp_hidden_states = self.act_mlp(mlp_hidden_states)

hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
Expand Down Expand Up @@ -576,6 +586,8 @@ def __init__(
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

self.double_stream = bool(int(os.environ.get("DOUBLE_STREAM", 1)))

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -584,21 +596,44 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
joint_attention_kwargs = joint_attention_kwargs or {}

norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}

# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)


if self.double_stream:
emb = self.norm1.linear(self.norm1.silu(temb))
current_event.record(current_stream)

with torch.npu.stream(stream2):
stream2.wait_event(current_event)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=emb, skip_matmul=True)
event2.record(stream2)

pre_encoder_query = self.attn.add_q_proj(norm_encoder_hidden_states)
pre_encoder_key = self.attn.add_k_proj(norm_encoder_hidden_states)
pre_encoder_value = self.attn.add_v_proj(norm_encoder_hidden_states)
current_stream.wait_event(event2)

attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
pre_encoder_query=pre_encoder_query,
pre_encoder_key=pre_encoder_key,
pre_encoder_value=pre_encoder_value,
cal_q=False,
**joint_attention_kwargs,
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
Expand All @@ -611,26 +646,56 @@ def forward(
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
if self.double_stream:
current_event.record(current_stream)
with torch.npu.stream(stream2):
stream2.wait_event(current_event)
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output

hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
event2.record(stream2)

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

ff_output = self.ff(norm_hidden_states)
current_stream.wait_event(event2)

context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output

return encoder_hidden_states, hidden_states
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

return encoder_hidden_states, hidden_states

else:
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output

hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output

# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
Expand Down