-
Notifications
You must be signed in to change notification settings - Fork 28
Allow multidataset and multitokenization training to co-occur within a single iteration mixture #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Allow multidataset and multitokenization training to co-occur within a single iteration mixture #658
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| #!/bin/bash | ||
| # demos/multidataset_per_batch_mixing.sh | ||
| # Demonstrates scheduled per-batch mixing across multiple datasets. | ||
|
|
||
| set -e | ||
|
|
||
| # Ensure required corpora are prepared. | ||
| bash data/shakespeare_char/get_dataset.sh | ||
|
|
||
| pushd data/minipile > /dev/null | ||
| if [ ! -f "train.bin" ] || [ ! -f "val.bin" ] || [ ! -f "meta.pkl" ]; then | ||
| bash get_dataset.sh | ||
| python3 prepare.py -t input.txt --method tiktoken | ||
| else | ||
| echo "train.bin, val.bin, and meta.pkl already exist for minipile." | ||
| fi | ||
| popd > /dev/null | ||
|
|
||
| # Run a small demonstration training job with per-batch mixing enabled. | ||
| python3 train.py \ | ||
| --training_mode multidataset \ | ||
| --dataset_list shakespeare_char minipile \ | ||
| --dataset_mixing_per_batch \ | ||
| --dataset_sampling_probs 3 1 \ | ||
| --dataset_sampling_probs_final 1 3 \ | ||
| --dataset_sampling_probs_transition_method linear \ | ||
| --batch_size 8 \ | ||
| --block_size 128 \ | ||
| --n_layer 4 \ | ||
| --n_head 4 \ | ||
| --n_embd 256 \ | ||
| --max_iters 2000 \ | ||
| --eval_interval 200 \ | ||
| --eval_iters 50 \ | ||
| --learning_rate 3e-4 \ | ||
| --weight_decay 0.1 \ | ||
| --optimizer adamw \ | ||
| --use_rotary_embeddings \ | ||
| --no-use_abs_pos_embeddings \ | ||
| --compile \ | ||
| --no-tensorboard_log \ | ||
| --seed 1337 \ | ||
| --out_dir out_multidataset_per_batch_demo |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # multidataset_per_batch_mixing.yaml | ||
| --- | ||
| # Demonstrates scheduled per-batch dataset mixing with two corpora. | ||
| training_mode: ["multidataset"] | ||
| dataset_list: | ||
| - "shakespeare_char minipile" | ||
| # Start with a 75/25 split and linearly transition to 25/75. | ||
| dataset_sampling_probs: | ||
| - "3 1" | ||
| dataset_sampling_probs_final: | ||
| - "1 3" | ||
| dataset_sampling_probs_transition_method: ["linear"] | ||
| dataset_mixing_per_batch: [true] | ||
| # Lightweight model + optimizer settings suitable for quick exploration runs. | ||
| block_size: [128] | ||
| batch_size: [8] | ||
| max_iters: [2000] | ||
| eval_interval: [200] | ||
| eval_iters: [50] | ||
| learning_rate: [3e-4] | ||
| weight_decay: [0.1] | ||
| optimizer: ["adamw"] | ||
| n_layer: [4] | ||
| n_head: [4] | ||
| n_embd: [256] | ||
| # Enable rotary embeddings and torch.compile for parity with other demos. | ||
| use_rotary_embeddings: [true] | ||
| use_abs_pos_embeddings: [false] | ||
| compile: [true] | ||
| seed: [1337] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # train.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from contextlib import nullcontext | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Dict, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import csv | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -134,6 +135,9 @@ def __init__(self, args, model_group, training_group, logging_group): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'abs_max': 0.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Track how many samples from each dataset were used in the latest batch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.current_batch_dataset_counts: Optional[Dict[str, int]] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # whether to show all model stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.compute_model_stats = self.args.compute_model_stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -714,6 +718,7 @@ def load_data(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_batch(self, split, target_dataset=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dataset = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.current_batch_dataset_counts = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def interpolate_probs(initial_probs, final_probs, method, step_ratio): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if method == 'linear': | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return initial_probs + step_ratio * (final_probs - initial_probs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -764,10 +769,125 @@ def get_transitioned_probs(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return x_dict, y_dict, list(self.args.multicontext_datasets) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.args.training_mode == "multidataset": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def sample_indices_from_dataset(dataset_name, count, data_array): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| available = len(data_array) - self.args.block_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.args.sampling_method == "random": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.randint(available, (count,)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.args.sampling_method == "sequential": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.dataset_ptr[dataset_name] + count > available: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dataset_ptr[dataset_name] = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = self.dataset_ptr[dataset_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end = start + count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| indices = self.dataset_perm[dataset_name][start:end] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dataset_ptr[dataset_name] = end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.tensor(indices) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.args.sampling_method == "without_replacement": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.dataset_ptr[dataset_name] + count > available: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dataset_perm[dataset_name] = np.random.permutation(available) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dataset_ptr[dataset_name] = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = self.dataset_ptr[dataset_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end = start + count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| indices = self.dataset_perm[dataset_name][start:end] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.dataset_ptr[dataset_name] = end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.tensor(indices) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+776
to
+792
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif self.args.sampling_method == "sequential": | |
| if self.dataset_ptr[dataset_name] + count > available: | |
| self.dataset_ptr[dataset_name] = 0 | |
| start = self.dataset_ptr[dataset_name] | |
| end = start + count | |
| indices = self.dataset_perm[dataset_name][start:end] | |
| self.dataset_ptr[dataset_name] = end | |
| return torch.tensor(indices) | |
| elif self.args.sampling_method == "without_replacement": | |
| if self.dataset_ptr[dataset_name] + count > available: | |
| self.dataset_perm[dataset_name] = np.random.permutation(available) | |
| self.dataset_ptr[dataset_name] = 0 | |
| start = self.dataset_ptr[dataset_name] | |
| end = start + count | |
| indices = self.dataset_perm[dataset_name][start:end] | |
| self.dataset_ptr[dataset_name] = end | |
| return torch.tensor(indices) | |
| elif self.args.sampling_method in ("sequential", "without_replacement"): | |
| if self.dataset_ptr[dataset_name] + count > available: | |
| if self.args.sampling_method == "without_replacement": | |
| self.dataset_perm[dataset_name] = np.random.permutation(available) | |
| self.dataset_ptr[dataset_name] = 0 | |
| start = self.dataset_ptr[dataset_name] | |
| end = start + count | |
| indices = self.dataset_perm[dataset_name][start:end] | |
| self.dataset_ptr[dataset_name] = end | |
| return torch.tensor(indices) |
Copilot
AI
Oct 19, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function duplicates tensor construction logic that likely exists elsewhere for batch creation. Consider extracting this into a reusable helper method to avoid code duplication and ensure consistency across different batching paths.
Copilot
AI
Oct 19, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This condition evaluates to True when batch_counts is None or an empty dict, but the code at line 1850 checks if batch_counts: which only executes when batch_counts is truthy (non-None and non-empty). These conditions should use explicit is None checks for clarity and correctness. Change to if batch_counts is None: to properly handle the case when per-batch mixing is not used.
| if not batch_counts: | |
| if batch_counts is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setting
self.current_batch_dataset_countstoNoneat the start ofget_batch()means it will beNonefor all non-per-batch-mixing code paths. This causes the conditionif not batch_counts:at line 1869 to fail whenbatch_countsisNone(it should checkif batch_counts is None:), potentially leading to an AttributeError when trying to iterate overNonein theelseblock at line 1860.