Commit 02ff3f5

Merge branch 'main' into RC-TEST-2.9
2 parents 2aef2a9 + 702e642 commit 02ff3f5

File tree

4 files changed: +1 -197 lines changed

intermediate_source/transformer_building_blocks.py
redirects.py
unstable_index.rst
unstable_source/skip_param_init.rst

intermediate_source/transformer_building_blocks.py

Lines changed: 0 additions & 61 deletions
@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection

 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.


-################################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with the ``FlexAttention`` module. This is a generalization
-# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
-# to the attention score. The example below takes the ``alibi_mod``
-# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
 ###############################################################################
 # Packed Projection
 # -----------------
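Note: the removed example depends on tutorial-local helpers (``gen_batch``, ``N``, ``E_q``, ``E_k``, ``E_v``) defined earlier in the script. For reference, below is a minimal self-contained sketch of the same ALiBi ``score_mod`` technique applied through ``flex_attention`` on ordinary dense ``[B, H, S, D]`` tensors; the shapes are illustrative assumptions and the snippet is not part of this commit.

# Illustrative sketch only, not part of this commit or the tutorial sources.
import torch
from torch.nn.attention.flex_attention import flex_attention


def generate_alibi_bias(H: int):
    """Return an ALiBi score_mod for H heads (mirrors the removed helper)."""

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        return score + (q_idx - kv_idx) * scale

    return alibi_mod


B, H, S, D = 2, 8, 128, 64  # hypothetical batch / heads / seq len / head dim
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
out = flex_attention(q, k, v, score_mod=generate_alibi_bias(H))
print(out.shape)  # torch.Size([2, 8, 128, 64])

Eager execution is enough for a correctness check; wrapping ``flex_attention`` in ``torch.compile`` is the usual route for performance.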

redirects.py

Lines changed: 1 addition & 0 deletions
@@ -38,4 +38,5 @@
     "recipes/recipes_index.html": "../recipes_index.html",
     "recipes/torchserve_vertexai_tutorial.html": "../index.html",
     "unstable_source/vulkan_workflow.rst": "../index.html",
+    "unstable/skip_param_init.html": "https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html",
 }
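A quick local sanity check for the new entry could look like the sketch below; it assumes the dictionary in ``redirects.py`` is importable under the name ``redirects``, which the hunk does not show.

# Hypothetical check; assumes the dict in redirects.py is named `redirects`.
from redirects import redirects

target = redirects["unstable/skip_param_init.html"]
assert target.startswith("https://docs.pytorch.org/"), target
print("redirect target:", target)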

unstable_index.rst

Lines changed: 0 additions & 9 deletions
@@ -45,15 +45,6 @@ decide if we want to upgrade the level of commitment or to fail fast.
    :link: unstable/semi_structured_sparse.html
    :tags: Model-Optimiziation

-.. Modules
-
-.. customcarditem::
-   :header: Skipping Module Parameter Initialization in PyTorch 1.10
-   :card_description: Describes skipping parameter initialization during module construction in PyTorch 1.10, avoiding wasted computation.
-   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
-   :link: unstable/skip_param_init.html
-   :tags: Modules
-
 .. vmap

 .. customcarditem::

unstable_source/skip_param_init.rst

Lines changed: 0 additions & 127 deletions
This file was deleted.
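For context, the deleted recipe documented skipping parameter initialization during module construction (``torch.nn.utils.skip_init``). A minimal reference sketch of that API, written here as an illustration rather than recovered from the deleted file:

# Minimal reference sketch of the API the deleted recipe covered.
import torch
from torch import nn

# Construct the module without running its default parameter initialization.
m = nn.utils.skip_init(nn.Linear, 10, 5)
# Parameters exist but hold uninitialized memory; initialize them explicitly.
nn.init.orthogonal_(m.weight)
nn.init.zeros_(m.bias)
print(m.weight.shape)  # torch.Size([5, 10])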
