PyTorch implementation of MergeDNA based on the paper:
- Hierarchical architecture: Local Encoder -> Latent Encoder/Decoder -> Local Decoder
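- Dynamic token merging with source-matrix tracking (S, S')
- Three pretraining losses (paper Eq. 8):
  $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{MTR}}(\theta) + \lambda\,\mathcal{L}_{\text{MTR}}(\theta \setminus \{\phi\}) + \mathcal{L}_{\text{AMTM}}(\theta)$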
- Optimization defaults:
  - AdamW (lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-8)
  - Linear warmup + cosine annealing scheduler
    - Challenge default (config.py and CLI): warmup_steps=100 (with steps=1000)
    - Paper preset (--preset paper): automatically uses warmup_steps=10000 unless you explicitly pass --warmup-steps
- W&B logging and periodic validation
- Best-checkpoint saving by val_mtr
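The warmup + cosine schedule listed under the optimization defaults can be expressed as a multiplier on the base learning rate. This is a minimal sketch, not the repository's scheduler: the helper name warmup_cosine_factor and the assumption that the cosine phase decays all the way to zero are illustrative choices.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Hypothetical helper: multiplier for the base lr (e.g. 1e-4) at `step`.

    Linear warmup from ~0 to 1 over `warmup_steps`, then cosine annealing
    from 1 down to 0 at `total_steps` (a floor of 0 is an assumption).
    """
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Challenge defaults: steps=1000, warmup_steps=100.
    for s in (0, 99, 100, 550, 1000):
        print(s, round(warmup_cosine_factor(s, 100, 1000), 4))
```

In PyTorch, such a function would typically be wrapped as torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: warmup_cosine_factor(s, 100, 1000)) around torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-8), with scheduler.step() called once per optimizer step.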
Repository layout:
src/mergedna/
    config.py
    blocks.py
    scoring.py
    merge_ops.py
    data.py
    model.py          # MergeDNA model assembly + forward paths
    losses.py
    train.py
    eval/
        __init__.py
        data.py       # task loading, synthetic data, dataloaders
        models.py     # LoRA adapters + sequence classifier
        train_eval.py # train/eval loops + HP search helpers
    __init__.py
scripts/
    main.py                  # training entrypoint
    eval_genomics.py         # downstream SFT+LoRA evaluator (thin CLI)
    eval_protein_fitness.py  # frozen latent linear-probe evaluator
    run_train_sample.sh      # sample launcher
Installation with conda:
conda create -n mergedna python=3.11 -y
conda activate mergedna
pip install -r requirements.txt
pip install -e .

Or with a plain virtual environment:
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
pip install -e .

If an editable install is not needed, run the scripts with PYTHONPATH=src instead:
PYTHONPATH=src python scripts/main.py --help

To use the sample launcher with conda (the default environment name is mergedna):
bash scripts/run_train_sample.sh

Override the environment name:
CONDA_ENV=my_env bash scripts/run_train_sample.sh

Bypass conda activation:
SKIP_CONDA=1 bash scripts/run_train_sample.sh

Minimal CPU run:
PYTHONPATH=src python scripts/main.py \
--steps 20 \
--batch-size 4 \
--seq-len 256 \
--device cpu

To disable the scheduler (optional):
PYTHONPATH=src python scripts/main.py --lr-scheduler none

Paper-like scheduler example:
PYTHONPATH=src python scripts/main.py \
--steps 100000 \
--warmup-steps 10000 \
--lr-scheduler cosine

Local merge modes (--local-merge-mode):
- adjacent (default): merge adjacent pairs only (best preserves biological contiguity)
- bipartite: ToMe-style bipartite matching within each window
- full_pairwise: score all pairs in each local window
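To make the adjacent mode and the source matrix concrete, here is a toy single merge step. It is a sketch under assumptions, not the code in scoring.py or merge_ops.py: neighbour pairs are scored by cosine similarity, the top-r non-overlapping pairs are merged greedily, and S is a binary matrix whose row k marks the original positions covered by current token k (starting from the identity).

```python
import torch
import torch.nn.functional as F


def merge_adjacent(x: torch.Tensor, source: torch.Tensor, r: int):
    """Hypothetical sketch: merge up to `r` non-overlapping adjacent token pairs.

    x:      (N, D) current token embeddings
    source: (N, L) binary source matrix (start with torch.eye(L))
    Returns merged tokens (N', D) and the updated source matrix (N', L).
    """
    n = x.shape[0]
    if n < 2 or r <= 0:
        return x, source
    # Similarity of each neighbouring pair (k, k+1); cosine is an assumed scorer.
    scores = F.cosine_similarity(x[:-1], x[1:], dim=-1)  # shape (n - 1,)
    taken = [False] * n
    absorbs_next = [False] * n  # absorbs_next[k]: token k merges with token k + 1
    merged = 0
    for k in torch.argsort(scores, descending=True).tolist():
        if merged == r:
            break
        if not taken[k] and not taken[k + 1]:
            taken[k] = taken[k + 1] = True
            absorbs_next[k] = True
            merged += 1

    sizes = source.sum(dim=-1, keepdim=True)  # original positions covered per token
    new_x, new_src = [], []
    k = 0
    while k < n:
        if absorbs_next[k]:
            # Size-weighted mean, so each token stays the mean of its source positions.
            total = sizes[k] + sizes[k + 1]
            new_x.append((x[k] * sizes[k] + x[k + 1] * sizes[k + 1]) / total)
            new_src.append(source[k] + source[k + 1])
            k += 2
        else:
            new_x.append(x[k])
            new_src.append(source[k])
            k += 1
    return torch.stack(new_x), torch.stack(new_src)
```

Because merges are size-weighted, repeated calls keep the invariant tokens == (S @ x0) / S.sum(-1, keepdim=True), which is what lets S map merged tokens back to base positions. The Python loop favours readability over speed.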
Example:
PYTHONPATH=src python scripts/main.py \
--local-merge-mode adjacent

Block styles (--block-style):
- llama (default): RMSNorm + SwiGLU with pre-norm residual blocks
- standard: LayerNorm + GELU feed-forward with post-norm residual blocks
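For reference, the llama-style block combines RMSNorm, a SwiGLU feed-forward, and pre-norm residuals. This is a generic sketch of that pattern, not the code in blocks.py; the eps value, the hidden size, and the use of nn.MultiheadAttention for the attention sublayer are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale by the root-mean-square over features (no mean subtraction).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-gated linear unit: down(silu(gate(x)) * up(x)).
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class PreNormBlock(nn.Module):
    """x + Attn(Norm(x)), then x + FFN(Norm(x)): the pre-norm residual layout."""

    def __init__(self, dim: int, n_heads: int, hidden: int):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, hidden)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask, need_weights=False)
        x = x + attn_out
        return x + self.ffn(self.norm2(x))
```

The standard style would instead use nn.LayerNorm, a Linear-GELU-Linear feed-forward, and apply each norm after its residual addition.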
Example:
PYTHONPATH=src python scripts/main.py \
--block-style llama

Paper preset:
PYTHONPATH=src python scripts/main.py \
--preset paper \
--seq-len 4096 \
--batch-size 1 \
--device cuda

Note: this preset is heavy and may require substantial GPU memory. It also applies warmup_steps=10000 by default; pass --warmup-steps to override.

Training on your own FASTA files:
PYTHONPATH=src python scripts/main.py \
--train-fasta /path/to/train.fasta \
--val-fasta /path/to/val.fasta \
--steps 2000 \
--seq-len 4096 \
--device cuda

Both plain and gzip-compressed (.gz) FASTA files are supported.
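Reading both plain and gzip-compressed FASTA needs only the standard library. The sketch below is an illustration, not necessarily what data.py does; the function name read_fasta and the choice to upper-case sequences are assumptions. It picks the opener from the file suffix and yields (header, sequence) records.

```python
import gzip
from pathlib import Path
from typing import Iterator, Tuple


def read_fasta(path) -> Iterator[Tuple[str, str]]:
    """Hypothetical helper: yield (header, sequence) from a FASTA file or its .gz variant."""
    path = Path(path)
    opener = gzip.open if path.suffix == ".gz" else open
    header, chunks = None, []
    with opener(path, "rt") as handle:  # text mode for both openers
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None:
                    yield header, "".join(chunks).upper()
                header, chunks = line[1:], []
            else:
                chunks.append(line)  # multi-line records are concatenated
    if header is not None:
        yield header, "".join(chunks).upper()
```

Being a generator, it streams large genomes record by record instead of loading the whole file.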
Weights & Biases logging:
PYTHONPATH=src python scripts/main.py \
--steps 200 \
--device cuda \
--wandb \
--wandb-project mergedna \
--wandb-run-name mergedna-challenge

Modes: --wandb-mode online, --wandb-mode offline, or --wandb-mode disabled.
Run Genomics Benchmark-style evaluation (frozen encoder + LoRA + MLP head):
PYTHONPATH=src python scripts/eval_genomics.py \
--checkpoint checkpoints/mergedna_best_val_mtr.pt \
--task-group enhancer \
--data-root /path/to/genomic-benchmark

Synthetic smoke-test mode (no CSV files needed):
PYTHONPATH=src python scripts/eval_genomics.py \
--checkpoint /path/to/current-compatible-checkpoint.pt \
--task-group species \
--synthetic \
--synthetic-train-size 8 \
--synthetic-val-size 4 \
--synthetic-test-size 4 \
--epochs 1 \
--lr-grid 1e-4 \
--wd-grid 0.0 \
--batch-size 4 \
--device cpu

The evaluator logic is split across these modules (import paths):
from mergedna.eval.data import GENOMICS_TASK_GROUPS, load_task_raw, load_task_synthetic, make_loaders, infer_num_classes_from_labels
from mergedna.eval.models import LoRALinear, SequenceClassifier, attach_lora_adapters, build_frozen_lora_backbone
from mergedna.eval.train_eval import set_seed, parse_float_grid, evaluate_loader, train_one_setting, select_best_setting
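mergedna.eval.models exposes LoRALinear, but its signature is not documented here. As a generic illustration of the technique, a LoRA adapter wraps a frozen nn.Linear and adds a trainable low-rank update. The class name LoRALinearSketch, the rank r, the alpha / r scaling, and the zero-initialised B matrix are common conventions assumed for this sketch, not taken from the repository.

```python
import math

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Hypothetical: frozen base Linear plus trainable update, W x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        self.lora_a = nn.Parameter(torch.empty(r, base.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # B starts at zero, so the adapter initially reproduces the base layer exactly.
        return self.base(x) + (x @ self.lora_a.t() @ self.lora_b.t()) * self.scaling
```

Only lora_a and lora_b receive gradients, so an optimizer built from the parameters that still require grad trains just the adapters (plus whatever classifier head sits on top).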
For protein fitness, following the paper-style protocol, you can freeze the pretrained backbone, extract latent embeddings, train a linear regressor, and average metrics across runs:
PYTHONPATH=src python scripts/eval_protein_fitness.py \
--checkpoint /path/to/current-compatible-checkpoint.pt \
--task-name protein_fitness \
--data-root /path/to/protein_fitness_data \
--alpha-grid 0.0,1e-6,1e-4,1e-2,1.0 \
--n-runs 3 \
--batch-size 32 \
--device cpu

Synthetic smoke-test mode:
PYTHONPATH=src python scripts/eval_protein_fitness.py \
--checkpoint /path/to/current-compatible-checkpoint.pt \
--synthetic \
--n-runs 3 \
--batch-size 8 \
--device cpu

Expected file layout for real data:
<data_root>/<task_name>/train.csv
<data_root>/<task_name>/val.csv
<data_root>/<task_name>/test.csv
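With separate train/val/test splits, the linear probe can be sketched as closed-form ridge regression on frozen embeddings. This illustration rests on assumptions: the --alpha-grid values are treated as ridge penalties, alpha is selected on the validation split by Spearman correlation, and embeddings are already extracted into NumPy arrays. eval_protein_fitness.py may use a different regressor or metric.

```python
import numpy as np


def fit_ridge(X: np.ndarray, y: np.ndarray, alpha: float):
    """Ridge with an unpenalised intercept, solved as a stacked least-squares problem."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean
    d = X.shape[1]
    # min ||Xc w - yc||^2 + alpha ||w||^2  <=>  lstsq on [Xc; sqrt(alpha) I] w = [yc; 0]
    A = np.vstack([Xc, np.sqrt(alpha) * np.eye(d)])
    t = np.concatenate([yc, np.zeros(d)])
    w, *_ = np.linalg.lstsq(A, t, rcond=None)  # also handles alpha = 0
    return w, y_mean - x_mean @ w


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Rank correlation as Pearson on ranks (ties are not averaged in this sketch)."""
    ra = np.argsort(np.argsort(a)).astype(float)
    rb = np.argsort(np.argsort(b)).astype(float)
    return float(np.corrcoef(ra, rb)[0, 1])


def probe(splits: dict, alphas) -> dict:
    """Pick alpha on 'val', then report the test Spearman for that alpha."""
    best = None
    for alpha in alphas:
        w, b = fit_ridge(*splits["train"], alpha)
        X_val, y_val = splits["val"]
        score = spearman(X_val @ w + b, y_val)
        if best is None or score > best[0]:
            best = (score, alpha, w, b)
    _, alpha, w, b = best
    X_test, y_test = splits["test"]
    return {"alpha": alpha, "test_spearman": spearman(X_test @ w + b, y_test)}
```

Repeating probe over --n-runs runs and averaging the returned scores mirrors the "average metrics across runs" step described above.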
Expected columns in each CSV:
- sequence column: one of sequence, seq, dna, protein, text
- fitness/target column: one of fitness, target, y, label
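Resolving those candidate column names in a split CSV might look like the following. The helper names pick_column and load_split, the first-match precedence in the listed order, and case-insensitive matching are assumptions for illustration.

```python
import csv
from typing import Dict, List, Sequence

SEQUENCE_COLUMNS = ("sequence", "seq", "dna", "protein", "text")
TARGET_COLUMNS = ("fitness", "target", "y", "label")


def pick_column(fieldnames: Sequence[str], candidates: Sequence[str]) -> str:
    """Return the header matching the earliest candidate, ignoring case and spaces."""
    lookup = {name.strip().lower(): name for name in fieldnames}
    for cand in candidates:
        if cand in lookup:
            return lookup[cand]
    raise KeyError(f"none of {list(candidates)} found in columns {list(fieldnames)}")


def load_split(csv_path) -> Dict[str, List]:
    """Hypothetical: read one split CSV into sequences and float targets."""
    with open(csv_path, newline="") as handle:
        reader = csv.DictReader(handle)
        seq_col = pick_column(reader.fieldnames or [], SEQUENCE_COLUMNS)
        tgt_col = pick_column(reader.fieldnames or [], TARGET_COLUMNS)
        rows = list(reader)
    return {
        "sequences": [row[seq_col] for row in rows],
        "targets": [float(row[tgt_col]) for row in rows],
    }
```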
Checkpoints written during training:
- Periodic checkpoints: checkpoints/mergedna_step_*.pt
- Best checkpoint by validation MTR: checkpoints/mergedna_best_val_mtr.pt
- Final checkpoint: checkpoints/mergedna_final.pt
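Best-checkpoint selection by validation MTR amounts to tracking a running minimum. A minimal sketch, assuming lower val_mtr is better (it is a reconstruction loss) and that the actual write is delegated to a callback; BestTracker is a hypothetical name, and what the repository stores inside the checkpoint is not specified here.

```python
import math
from typing import Callable


class BestTracker:
    """Hypothetical: call save_fn(path) whenever a new lowest val_mtr is seen."""

    def __init__(self, path: str, save_fn: Callable[[str], None]):
        self.path = path
        self.save_fn = save_fn
        self.best = math.inf

    def update(self, val_mtr: float) -> bool:
        # NaN never compares smaller, so a diverged validation pass is never saved.
        if val_mtr < self.best:
            self.best = val_mtr
            self.save_fn(self.path)
            return True
        return False
```

For example, BestTracker("checkpoints/mergedna_best_val_mtr.pt", lambda p: torch.save(model.state_dict(), p)) would overwrite the best file only on improvement; saving a bare state_dict is an assumption.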
If more time or compute is available, the following extensions are the highest-impact next steps:
- Full paper-scale pretraining runs (e.g., 100k steps with larger hardware budgets and long-context settings)
- Automated data retrieval + canonical preprocessing scripts for pretraining/evaluation datasets (download, normalize, split, and manifest generation)
- End-to-end reproduction of paper benchmark tables on full real datasets instead of smoke/synthetic-only checks
- Protein-fitness support with checkpoint/tokenizer setups that natively cover full amino-acid vocabularies
- Broader ablation and hyperparameter sweeps across merge modes, latent selective modes, block styles, and optimization settings
- Multi-seed statistical reporting with confidence intervals for more robust comparisons