diff --git a/configuration_bibo.py b/configuration_bibo.py
index 7a86ee4..bb64948 100644
--- a/configuration_bibo.py
+++ b/configuration_bibo.py
@@ -104,6 +104,8 @@ def __init__(
         self.kernel_size = kernel_size
         self.norm_topk_prob = norm_topk_prob
         self.output_router_logits = output_router_logits
+        self.conv_router = conv_router
+        self.use_ssmax = use_ssmax
         if mlp_only_layers is None:
             self.mlp_only_layers = [0, num_hidden_layers - 1]
         else:
diff --git a/modeling_bibo.py b/modeling_bibo.py
index 69ad9fe..f016295 100644
--- a/modeling_bibo.py
+++ b/modeling_bibo.py
@@ -16,7 +16,7 @@
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
-from einops import rearrange
+from einops import rearrange, einsum
 
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
@@ -701,9 +701,146 @@ def forward(
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
-
-    def eager_sliding_window_attention(self):
-        pass
+
+    def eager_sliding_window_attention(
+        self,
+        query_states: torch.Tensor,              # (b, h, q_len, d)
+        key_states: torch.Tensor,                # (b, h, kv_len, d)
+        value_states: torch.Tensor,              # (b, h, kv_len, d)
+        attention_mask: Optional[torch.Tensor],  # (b, 1, q_len, kv_len) additive
+        window_size: int,
+        stride: Optional[int] = None
+    ) -> torch.Tensor:
+        """
+        Efficient sliding window attention using tensor unfolding and einops,
+        matching the logic of the provided loop-based implementation.
+
+        Args:
+            query_states: Query tensor [batch_size, num_heads, q_len, head_dim]
+            key_states: Key tensor [batch_size, num_heads, kv_len, head_dim]
+            value_states: Value tensor [batch_size, num_heads, kv_len, head_dim]
+            attention_mask: Optional additive mask tensor (-inf for masked)
+                Shape [batch_size, 1, q_len, kv_len].
+            window_size: Size of the attention window.
+
+        Returns:
+            Attention output tensor [batch_size, num_heads, q_len, head_dim]
+        """
+        batch_size, num_heads, q_len, head_dim = query_states.shape
+        kv_len = key_states.shape[-2]
+        # Ensure kv_len matches query_states if expected by causal window logic
+        # assert q_len == kv_len, "This specific unfold logic assumes q_len == kv_len for simplicity"
+
+        # --- 1. Pad Key and Value tensors (Left Padding) ---
+        # Pad by window_size - 1 on the left of the sequence dimension (dim 2)
+        kv_padding_size = max(0, window_size - 1)
+        # Pad format: (pad_left, pad_right) for the last dim, then second-to-last, etc.
+        # We pad dim 2 (sequence length): (pad_seq_left, pad_seq_right)
+        kv_padding = (0, 0, kv_padding_size, 0)  # (pad_dim3_left, pad_dim3_right, pad_dim2_left, pad_dim2_right)
+
+        padded_key_states = F.pad(key_states, kv_padding)
+        padded_value_states = F.pad(value_states, kv_padding)
+        # Padded shape: [b, h, kv_len + window_size - 1, d]
+
+        # --- 2. Unfold Padded Key/Value tensors ---
+        # Create sliding windows of size `window_size` along dim 2 with step 1
+        unfolded_key = padded_key_states.unfold(dimension=2, size=window_size, step=1)
+        unfolded_value = padded_value_states.unfold(dimension=2, size=window_size, step=1)
+        # Shape after unfold: [b, h, num_windows, d, w]
+        # num_windows = (kv_len + window_size - 1) - window_size + 1 = kv_len
+        # If q_len == kv_len, then num_windows == q_len
+
+        # Handle potential mismatch if q_len != kv_len (unlikely for standard SWA)
+        num_windows = unfolded_key.shape[2]
+        if num_windows != q_len:
+            print(f"Warning: q_len ({q_len}) != kv_len ({num_windows}). Check logic if this is intended. Taking first {q_len} windows.")
+            unfolded_key = unfolded_key[:, :, :q_len, :, :]
+            unfolded_value = unfolded_value[:, :, :q_len, :, :]
+
+        # --- 3. Rearrange with einops ---
+        # Rearrange to [b, h, q_len, window_size, head_dim] for matmul convenience
+        unfolded_key = rearrange(unfolded_key, 'b h q d w -> b h q w d')
+        unfolded_value = rearrange(unfolded_value, 'b h q d w -> b h q w d')
+
+        # --- 4. Compute Attention Scores within Windows ---
+        # Scale query beforehand as in the loop version
+        # query_scaled = query_states * (self.head_dim ** -0.5)  # Option 1: Scale Q
+        # attn_scores_windowed = einsum(query_scaled, unfolded_key, 'b h q d, b h q w d -> b h q w')
+
+        # Option 2: Scale scores after matmul (more common)
+        scale_factor = self.head_dim ** -0.5
+        attn_scores_windowed = einsum(query_states, unfolded_key, 'b h q d, b h q w d -> b h q w') * scale_factor
+        # Shape: [b, h, q, w]
+
+        # --- 5. Apply Masking ---
+        # a) Mask attention to padded key positions introduced by F.pad
+        # Calculate the original key indices corresponding to each window position
+        relative_indices = torch.arange(window_size, device=query_states.device)      # (w,)
+        query_indices = torch.arange(q_len, device=query_states.device).unsqueeze(1)  # (q, 1)
+        # For query i, window position k, the key index in the padded tensor is i + k.
+        # The original key index is (i + k) - kv_padding_size.
+        # More directly: original key index = query_idx - (window_size - 1) + window_relative_idx
+        original_key_indices = query_indices - kv_padding_size + relative_indices  # Shape (q, w)
+        # Mask if the original key index is negative (came from padding)
+        padding_mask_window = (original_key_indices < 0)  # Shape (q, w), boolean
+
+        # Apply this padding mask (additive -inf)
+        attn_scores_windowed = attn_scores_windowed.masked_fill(
+            padding_mask_window.unsqueeze(0).unsqueeze(0),  # Expand to (1, 1, q, w)
+            float('-inf')
+        )
+
+        # b) Apply the external attention_mask if provided.
+        # This requires selecting the mask values corresponding to the keys in each window.
+        # Replicating the loop version's logic by constructing the windowed mask.
+        if attention_mask is not None:
+            # Input mask shape: (b, 1, q, k), additive (-inf where masked)
+            # Output needed: (b, h, q, w) mask values corresponding to windowed keys
+            windowed_attention_mask = torch.zeros_like(attn_scores_windowed)  # Start with 0 (no mask)
+
+            # Ensure mask has correct dimensions for slicing
+            mask_for_loop = attention_mask
+            if mask_for_loop.shape[2] != q_len or mask_for_loop.shape[3] != kv_len:
+                mask_for_loop = mask_for_loop[:, :, :q_len, :kv_len]  # Adjust slice if needed
+
+            for i in range(q_len):
+                # Determine the slice of keys in the *original* sequence for query i's window
+                k_start = max(0, i - window_size + 1)
+                k_end = i + 1  # exclusive end index
+
+                # Extract the relevant slice from the original attention_mask
+                # Mask for query i, attending to keys k_start to k_end - 1
+                # Shape: (b, 1, 1, actual_window_len)
+                mask_slice = mask_for_loop[:, :, i:i+1, k_start:k_end]
+
+                # Pad the mask slice on the left if the window was truncated at the start
+                actual_window_len = k_end - k_start
+                left_padding_needed = window_size - actual_window_len
+                # Pad format (left, right) for the last dimension. Pad with 0 for additive mask.
+                padded_mask_slice = F.pad(mask_slice, (left_padding_needed, 0), value=0.0)
+
+                # Assign to the correct position in the windowed mask tensor
+                # Squeeze the query dim (dim 2) from the slice before assigning
+                windowed_attention_mask[:, :, i, :] = padded_mask_slice.squeeze(2)
+
+            # Add the constructed windowed mask to the scores
+            attn_scores_windowed = attn_scores_windowed + windowed_attention_mask
+
+        # --- 6. Compute Attention Probabilities ---
+        # Softmax is applied over the window dimension (-1)
+        attn_weights = F.softmax(attn_scores_windowed, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # Shape: [b, h, q, w]
+
+        # Apply dropout (as in the loop version)
+        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        # --- 7. Compute Output using einsum ---
+        # Weighted sum of the unfolded values based on the attention weights
+        # weights[b,h,q,w] * values[b,h,q,w,d] -> output[b,h,q,d]
+        attn_output = einsum(attn_weights, unfolded_value, 'b h q w, b h q w d -> b h q d')
+        # Final shape: [batch_size, num_heads, q_len, head_dim]
+
+        return attn_output
 
 
     def eager_standard_attention(
@@ -729,6 +866,7 @@ def eager_standard_attention(
 
         kv_len = key_states.shape[-2]
 
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         if self.use_ssmax:
             log_n = torch.log(torch.clamp(torch.tensor(kv_len, device=query_states.device, dtype=self.ssmax_scale.dtype), min=2.0))  # min=2.0 since log(1) = 0 and negative for <1
@@ -738,11 +876,9 @@ def eager_standard_attention(
 
             # SSMax Ratio: exp(C * z_i) / exp(C * z_k) = exp(C * z_i - C * z_k) = exp(C * (z_i - z_k)) = (exp(z_i - z_k))^C
            # C is scaling factor i.e s*log(seq_len) ;
            # in a gist: a learnable, seq-len adaptive temperature applied per head to control attention sharpness, preventing fading in long contexts.
+            s_scaled = self.s.view(1, self.num_heads, 1, 1) * log_n
-
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights * s_scaled
+            attn_weights = attn_weights * s_scaled
 
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -752,8 +888,3 @@ def eager_standard_attention(
         attn_output = torch.matmul(attn_weights, value_states)
 
         return attn_output
-
-
-
-
-
diff --git a/test_sliding_window_attention.py b/test_sliding_window_attention.py
new file mode 100644
index 0000000..b64cdc0
--- /dev/null
+++ b/test_sliding_window_attention.py
@@ -0,0 +1,247 @@
+import math
+import torch
+import pytest
+from configuration_bibo import BiBoConfig
+from modeling_bibo import BiBoAttention
+
+
+def test_window_size_1():
+    """Test with window_size=1 (each token only attends to itself)."""
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 4, 8
+    hidden_size = num_heads * head_dim
+
+    # Create a proper config
+    config = BiBoConfig()
+    config.hidden_size = hidden_size
+    config.num_attention_heads = num_heads
+    config.num_key_value_heads = num_heads  # Same as num_heads to avoid divisibility issue
+
+    # Create test inputs
+    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    value = torch.ones(batch_size, num_heads, seq_len, head_dim)  # All ones for easy verification
+
+    # Initialize the attention module
+    attn = BiBoAttention(config, layer_idx=0, use_sliding_window=True)
+
+    # Override the sliding window parameter for this test
+    attn.sliding_window = 1
+
+    # With window_size=1, each token should only attend to itself
+    output = attn.eager_sliding_window_attention(
+        query, key, value, attention_mask=None, window_size=1
+    )
+
+    # Check output shape
+    assert output.shape == (batch_size, num_heads, seq_len, head_dim)
+
+    # Compute attention weights manually
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    # Create a diagonal mask
+    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
+    for i in range(seq_len):
+        mask[i, i] = True
+    mask = mask.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, seq_len, seq_len)
+
+    # Apply mask
+    masked_weights = attn_weights.clone()
+    masked_weights = masked_weights.masked_fill(~mask, float('-inf'))
+    expected_weights = torch.nn.functional.softmax(masked_weights, dim=-1)
+
+    # Compute expected output
+    expected_output = torch.matmul(expected_weights, value)
+
+    # Compare outputs
+    torch.testing.assert_close(output, expected_output, rtol=1e-4, atol=1e-4)
+    print("✅ test_window_size_1 passed")
+
+
+def test_window_size_larger_than_sequence():
+    """Test with window_size > seq_len (equivalent to full causal attention)."""
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 4, 8
+    hidden_size = num_heads * head_dim
+
+    # Create a proper config
+    config = BiBoConfig()
+    config.hidden_size = hidden_size
+    config.num_attention_heads = num_heads
+    config.num_key_value_heads = num_heads  # Same as num_heads to avoid divisibility issue
+
+    # Create test inputs
+    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    # Initialize the attention module with sliding window
+    attn = BiBoAttention(config, layer_idx=0, use_sliding_window=True)
+
+    # Create a standard causal mask (-inf for positions not allowed to attend to)
+    attention_mask = torch.zeros((batch_size, 1, seq_len, seq_len))
+    attention_mask = attention_mask.masked_fill(
+        torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().unsqueeze(0).unsqueeze(0),
+        float('-inf')
+    )
+
+    # With window_size > seq_len, this should be equivalent to full causal attention
+    window_output = attn.eager_sliding_window_attention(
+        query, key, value, attention_mask=attention_mask, window_size=seq_len * 2
+    )
+
+    # For comparison, we need to implement the standard causal attention manually
+    # Start with raw attention scores
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    # Apply causal mask (lower triangular)
+    causal_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1)
+    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
+    attn_weights = attn_weights + causal_mask
+
+    # Apply softmax
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    # Get expected output
+    expected_output = torch.matmul(attn_weights, value)
+
+    # Outputs should match
+    torch.testing.assert_close(window_output, expected_output, rtol=1e-4, atol=1e-4)
+    print("✅ test_window_size_larger_than_sequence passed")
+
+
+def test_window_size_3():
+    """Test with a specific window size (window_size=3)."""
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 6, 4  # Using 2 heads to match key_value_heads
+    hidden_size = num_heads * head_dim
+    window_size = 3
+
+    # Create a proper config
+    config = BiBoConfig()
+    config.hidden_size = hidden_size
+    config.num_attention_heads = num_heads
+    config.num_key_value_heads = num_heads  # Same as num_heads to avoid divisibility issue
+    config.sliding_window = window_size
+
+    # Initialize the attention module
+    attn = BiBoAttention(config, layer_idx=0, use_sliding_window=True)
+
+    # Create inputs with a clear pattern
+    # Use all-ones queries/keys so the attention weights are uniform within the window
+    query = torch.zeros(batch_size, num_heads, seq_len, head_dim)
+    key = torch.zeros(batch_size, num_heads, seq_len, head_dim)
+
+    # Set each position to attend equally to all positions
+    # (makes the attention weights uniform within the allowed window)
+    for i in range(seq_len):
+        for j in range(num_heads):
+            for d in range(head_dim):
+                query[0, j, i, d] = 1.0
+                key[0, j, i, d] = 1.0
+
+    # Create values where each position has a unique identifier
+    value = torch.zeros(batch_size, num_heads, seq_len, head_dim)
+    for i in range(seq_len):
+        for j in range(num_heads):
+            value[0, j, i] = i + 1  # Position i has value i+1
+
+    # Apply sliding window attention
+    output = attn.eager_sliding_window_attention(
+        query, key, value, attention_mask=None, window_size=window_size
+    )
+
+    # Verify each position's attention pattern
+    # Each position should attend to the previous window_size-1 positions and itself
+    for pos in range(seq_len):
+        # Start of the window for this position (respecting sequence boundaries)
+        window_start = max(0, pos - window_size + 1)
+        # End of the window for this position (inclusive)
+        window_end = pos
+
+        # Expected output is the average of values in the window
+        # (since query and key are the same, attention weights are uniform within the window)
+        window_size_actual = window_end - window_start + 1
+        expected_value = sum(i + 1 for i in range(window_start, window_end + 1)) / window_size_actual
+
+        # Check if the output matches the expected value (for the first head)
+        actual_value = output[0, 0, pos].mean().item()
+        assert abs(actual_value - expected_value) < 1e-5, \
+            f"Position {pos}: Expected {expected_value}, got {actual_value}"
+
+    print("✅ test_window_size_3 passed")
+
+
+def test_with_attention_mask():
+    """Test interaction with an additional attention mask."""
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 5, 4
+    hidden_size = num_heads * head_dim
+    window_size = 3
+
+    # Create a proper config
+    config = BiBoConfig()
+    config.hidden_size = hidden_size
+    config.num_attention_heads = num_heads
+    config.num_key_value_heads = num_heads  # Same as num_heads to avoid divisibility issue
+
+    # Initialize the attention module
+    attn = BiBoAttention(config, layer_idx=0, use_sliding_window=True)
+
+    # Create inputs
+    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    value = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    # Create a custom attention mask that blocks attention to position 2
+    # We'll use a very negative number (-1e9) to represent attention that's blocked
+    attention_mask = torch.zeros((batch_size, 1, seq_len, seq_len))
+    attention_mask[:, :, :, 2] = -1e9  # Block all attention to position 2
+
+    # Apply sliding window attention with the additional mask
+    output_with_mask = attn.eager_sliding_window_attention(
+        query, key, value, attention_mask=attention_mask, window_size=window_size
+    )
+
+    # Now manually verify the attention pattern for a specific position
+    # For example, position 3 should attend to positions 1 and 3 but not 2 (blocked by mask)
+    # First compute attention scores
+    scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    # Apply window mask for position 3
+    window_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool)
+    for i in range(seq_len):
+        start = max(0, i - window_size + 1)
+        for j in range(start, i + 1):
+            window_mask[i, j] = True
+
+    # Position 3 could attend to positions 1, 2, 3 due to window_size=3,
+    # but position 2 is blocked by attention_mask
+    pos = 3
+    valid_attention = window_mask[pos].clone()
+    valid_attention[2] = False  # Position 2 is blocked
+
+    # Get the weights after applying masks
+    pos_scores = scores[0, 0, pos].clone()
+    pos_scores = pos_scores.masked_fill(~valid_attention, float('-inf'))
+    pos_scores = pos_scores + attention_mask[0, 0, pos]
+    pos_weights = torch.nn.functional.softmax(pos_scores, dim=-1)
+
+    # Expected output for position 3
+    expected_output_pos3 = torch.matmul(
+        pos_weights.unsqueeze(0),
+        value[0, 0].clone()
+    ).squeeze(0)
+
+    # Verify position 3 output
+    torch.testing.assert_close(
+        output_with_mask[0, 0, pos],
+        expected_output_pos3,
+        rtol=1e-5, atol=1e-5
+    )
+    print("✅ test_with_attention_mask passed")
+
+
+if __name__ == "__main__":
+    # Run all tests
+    test_window_size_1()
+    test_window_size_larger_than_sequence()
+    test_window_size_3()
+    test_with_attention_mask()
+    print("All tests passed! 🎉")
\ No newline at end of file
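
Reference sketch (not part of the diff): the new eager_sliding_window_attention left-pads K/V by window_size - 1, unfolds them into one window per query, masks the padded slots, and softmaxes over the window dimension, which reproduces causal sliding-window attention without a Python loop over positions. Below is a minimal, self-contained PyTorch version of that unfold idea with a naive loop as a cross-check; the names swa_unfold/swa_loop are illustrative only, and the sketch assumes q_len == kv_len, no GQA, no dropout, and no external mask.

import torch
import torch.nn.functional as F

def swa_unfold(q, k, v, window):
    # q, k, v: (b, h, n, d); causal sliding-window attention via unfold.
    b, h, n, d = q.shape
    pad = window - 1
    k_pad = F.pad(k, (0, 0, pad, 0))    # left-pad the sequence dimension
    v_pad = F.pad(v, (0, 0, pad, 0))
    k_win = k_pad.unfold(2, window, 1)  # (b, h, n, d, w)
    v_win = v_pad.unfold(2, window, 1)  # (b, h, n, d, w)
    scores = torch.einsum('bhnd,bhndw->bhnw', q, k_win) * d ** -0.5
    # Window slot w of query i maps to key index i - (window - 1) + w;
    # negative indices came from the padding and must be masked out.
    idx = torch.arange(n).unsqueeze(1) - pad + torch.arange(window)
    scores = scores.masked_fill((idx < 0).unsqueeze(0).unsqueeze(0), float('-inf'))
    probs = scores.softmax(dim=-1)
    return torch.einsum('bhnw,bhndw->bhnd', probs, v_win)

def swa_loop(q, k, v, window):
    # Naive per-query reference: attend to keys [i - window + 1, i].
    b, h, n, d = q.shape
    out = torch.zeros_like(q)
    for i in range(n):
        s = max(0, i - window + 1)
        att = (q[:, :, i:i+1] @ k[:, :, s:i+1].transpose(-1, -2)) * d ** -0.5
        out[:, :, i] = (att.softmax(dim=-1) @ v[:, :, s:i+1]).squeeze(2)
    return out

if __name__ == "__main__":
    q, k, v = (torch.randn(1, 2, 7, 4) for _ in range(3))
    torch.testing.assert_close(swa_unfold(q, k, v, 3), swa_loop(q, k, v, 3))
    print("unfold SWA matches loop SWA")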
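
Reference sketch (not part of the diff): the use_ssmax branch in eager_standard_attention multiplies the pre-softmax logits by a learnable per-head factor s times log(kv_len), keeping attention sharp as the context grows; the clamp to min=2.0 avoids a zero or negative factor for very short contexts. A toy standalone illustration of that scaling, using a plain s argument rather than the module's parameters:

import torch

def ssmax(scores: torch.Tensor, s: torch.Tensor, kv_len: int) -> torch.Tensor:
    # scores: (b, h, q, k) pre-softmax logits; s: (h,) learnable per-head scale.
    # Scale the logits by s * log(kv_len) before the softmax, so the effective
    # temperature shrinks as the context length grows.
    log_n = torch.log(torch.clamp(torch.tensor(float(kv_len)), min=2.0))
    return torch.softmax(scores * s.view(1, -1, 1, 1) * log_n, dim=-1)

# With s = 1 and a long context, the distribution is sharper than plain softmax.
scores = torch.randn(1, 2, 1, 4096)
plain = torch.softmax(scores, dim=-1)
sharp = ssmax(scores, torch.ones(2), kv_len=4096)
print(plain.max().item(), sharp.max().item())  # the scaled version peaks higher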