@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
 
 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.
 
 
-################################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with the ``FlexAttention`` module. This is a generalization
-# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
-# to the attention score. The example below takes the ``alibi_mod``
-# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
 ###############################################################################
 # Packed Projection
 # -----------------