
Conversation

@GeorgiosSmyrnis
Collaborator

This adds a flag that prevents attention from crossing document boundaries, which are identified by the EOT token.

The loss for the token right after the EOT token is ignored.

TODO: add some tests for the shape of the mask.
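
As background, here is a minimal sketch of how such a block-diagonal, causal attention mask could be derived from EOT positions. The function name, the EOT_TOKEN constant, and the True-means-attend convention are assumptions for illustration, not the PR's actual implementation:

import torch

EOT_TOKEN = 0  # hypothetical token id; the PR exposes this via a flag

def document_attention_mask(tokens: torch.Tensor) -> torch.Tensor:
    """Return a (batch, seq, seq) boolean mask, True where attention is allowed."""
    # Assign a document id to each position; it increments right after every EOT,
    # so the EOT itself stays with the document it terminates.
    is_eot = (tokens == EOT_TOKEN).long()
    doc_ids = torch.cumsum(is_eot, dim=1) - is_eot
    # Attention is allowed only within the same document...
    same_doc = doc_ids.unsqueeze(1) == doc_ids.unsqueeze(2)
    # ...and only to the current or earlier positions (causal).
    seq_len = tokens.shape[1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device))
    return same_doc & causal

A boolean mask in this convention could be passed as attn_mask to torch.nn.functional.scaled_dot_product_attention; the masking mechanism actually used in the PR may differ.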

@GeorgiosSmyrnis changed the title from "Attention across documents." to "[WIP] Attention across documents." on Jan 31, 2024
help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.",
)
parser.add_argument(
"--mask-across-documents",
Collaborator

I think this should be an int, not a bool, so that a user can specify their EOT token.

Collaborator Author

Makes sense - will update the parameter.
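
For concreteness, the updated argument might look roughly like this (a sketch based on the suggestion above; the default and help text are assumptions):

parser.add_argument(
    "--mask-across-documents",
    type=int,
    default=None,
    help="If set, treat this token id as the EOT token and prevent attention from crossing it.",
)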

if args.mask_across_documents:
    # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
    # should not contribute to the loss.
    ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
Collaborator

I prefer not to hard-code our EOT, to keep open_lm tokenizer-agnostic.

Collaborator Author

Agreed - I'll change it so that it uses the user-defined EOT token.
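
For illustration, the lookup could then use the user-supplied id instead of the hard-coded constant (assuming the flag carries the EOT token id, as suggested above):

# EOT token id comes from the user-facing flag rather than SpecialTokens.
ignore_indices = torch.nonzero(inputs == args.mask_across_documents, as_tuple=True)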

@GeorgiosSmyrnis force-pushed the gsmyrnis/document_attention branch from 4c322d1 to 7234b31 on February 2, 2024, 23:29
    # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
    # should not contribute to the loss.
    ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
    targets = targets.detach().clone()  # Clone this because it shares mem with input!
Collaborator

Interesting, is the detach necessary here? When args.mask_across_documents is False, should we also add a detach()?

Collaborator Author

Detach is not necessary, but clone is: because the targets and the inputs share the underlying tensor, explicitly setting values in the targets would also modify the inputs.

When args.mask_across_documents is False, this is not an issue - neither the targets nor the inputs are explicitly changed.
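
A small standalone repro of that shared-memory behavior, assuming inputs and targets are built as shifted views of the same buffer (as is common in LM dataloaders):

import torch

tokens = torch.tensor([[1, 2, 3, 4]])
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # two views over the same storage

safe = targets.clone()
safe[0, 0] = -100      # fine: the clone has its own storage
targets[0, 0] = -100   # writes into the shared storage...
print(inputs)          # ...so inputs is now tensor([[1, -100, 3]])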
