Non-record: JEPA v3 — span-masked I-JEPA + VICReg, val_bpb 1.2321 (#1581)
Open
aiejvn wants to merge 1 commit into openai:main from
Conversation
Builds on PR #1330 (JEPA v2 — why same-sequence next-k JEPA collapses in causal LMs). Two additions:
- **Span-masked JEPA:** The context encoder sees target spans replaced with a learned mask embedding (`jepa_mask_emb`) rather than the actual tokens; the target encoder sees the full unmasked sequence. This makes prediction genuinely hard: the context encoder cannot recover the target token from its own input and must rely on surrounding context. Bigram hash contributions are explicitly zeroed at masked positions to prevent the Cantor hash from leaking token identity. Span lengths are sampled from Geometric(mean=16) with 4 spans per sequence (~6% masked per step).
- **VICReg anti-collapse regularization:** Variance hinge and off-diagonal covariance penalty (V-JEPA style) are applied to the predictor-side representations at masked positions. This prevents the predictor from collapsing to a single point or low-rank subspace independently of the span masking. Target-side VICReg terms are monitored as diagnostics only — no gradient.
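A minimal sketch of the span-masking scheme, assuming PyTorch-style dense token and bigram-hash embeddings; the function names and tensor shapes are illustrative, not the PR's actual API. It samples 4 spans with Geometric(mean=16) lengths, substitutes the learned mask embedding at masked positions, and zeroes the bigram-hash contribution there:

```python
import torch

def sample_span_mask(seq_len: int, n_spans: int = 4, mean_span: int = 16) -> torch.Tensor:
    """Boolean mask of target spans: n_spans spans with Geometric(mean=mean_span)
    lengths, as in the PR (4 spans, mean length 16, ~6% masked per step)."""
    mask = torch.zeros(seq_len, dtype=torch.bool)
    p = 1.0 / mean_span  # Geometric counts failures; the +1 below gives mean = mean_span
    for _ in range(n_spans):
        length = int(torch.distributions.Geometric(probs=torch.tensor(p)).sample()) + 1
        start = int(torch.randint(0, max(seq_len - length, 1), (1,)))
        mask[start:start + length] = True
    return mask

def mask_context_inputs(tok_emb, bigram_emb, mask, mask_emb):
    """Build the context-encoder input: replace token embeddings at masked
    positions with the learned mask embedding, and zero the bigram-hash
    contribution there so the hash cannot leak target-token identity."""
    x = torch.where(mask[:, None], mask_emb, tok_emb)
    return x + bigram_emb * (~mask)[:, None]
```

The target encoder would consume the plain `tok_emb + bigram_emb` sum, so only the context side is ever blinded to the span contents.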
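The VICReg terms can be sketched as follows, assuming representations are gathered into an `(N, D)` batch of masked-position vectors; coefficients and the exact reduction are assumptions, not the PR's values. The target-side diagnostics would call the same function on detached tensors:

```python
import torch

def vicreg_penalty(z: torch.Tensor, var_target: float = 1.0, eps: float = 1e-4):
    """VICReg-style anti-collapse terms on predictor-side representations z: (N, D).
    The variance hinge pushes each feature's std above var_target; the covariance
    term penalizes squared off-diagonal entries of the feature covariance."""
    z = z - z.mean(dim=0)                      # center features over the batch
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = torch.relu(var_target - std).mean()
    n, d = z.shape
    cov = (z.T @ z) / (n - 1)                  # (D, D) feature covariance
    off_diag = cov - torch.diag(torch.diagonal(cov))
    cov_loss = off_diag.pow(2).sum() / d
    return var_loss, cov_loss
```

A fully collapsed batch (all rows identical) maximizes the variance hinge while a well-spread batch drives it to zero, which is exactly the failure mode the regularizer guards against independently of the span masking.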
**Optimizer bug fix (v2 regression):** In v2, `JEPAPredictor` and `jepa_mask_emb` were absent from all three optimizer groups — only `base_model.blocks` was iterated (verifiable in b4a428b). The predictor was frozen at zero-init for the entire v2 run. Fixed by explicitly routing predictor matrix params to Muon and scalar params to Adam.

Non-record reason: trained ~20 hr on 1× AWS A10G.
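The fixed routing can be sketched as below. Muon itself is external to PyTorch, so only the grouping is shown; `split_param_groups` is a hypothetical helper name, and the ndim threshold reflects the stated matrix-vs-scalar split. The key point is iterating the full model's parameters rather than a single submodule, which is what the v2 code failed to do:

```python
import torch
import torch.nn as nn

def split_param_groups(model: nn.Module):
    """Route every trainable parameter: matrices (ndim >= 2) to the Muon group,
    vectors/scalars (ndim < 2, e.g. biases and jepa_mask_emb) to the Adam group.
    Walking named_parameters() of the whole model — not just base_model.blocks —
    is what prevents the v2 bug of silently skipping the predictor."""
    muon_params, adam_params = [], []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (muon_params if p.ndim >= 2 else adam_params).append(p)
    n_trainable = sum(1 for p in model.parameters() if p.requires_grad)
    assert len(muon_params) + len(adam_params) == n_trainable  # nothing dropped
    return muon_params, adam_params
```

The closing assertion is cheap and would have caught the v2 regression at startup instead of after a full training run.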