43 changes: 43 additions & 0 deletions demos/multidataset_per_batch_mixing.sh
@@ -0,0 +1,43 @@
#!/bin/bash
# demos/multidataset_per_batch_mixing.sh
# Demonstrates scheduled per-batch mixing across multiple datasets.

set -e

# Ensure required corpora are prepared.
bash data/shakespeare_char/get_dataset.sh

pushd data/minipile > /dev/null
if [ ! -f "train.bin" ] || [ ! -f "val.bin" ] || [ ! -f "meta.pkl" ]; then
    bash get_dataset.sh
    python3 prepare.py -t input.txt --method tiktoken
else
    echo "train.bin, val.bin, and meta.pkl already exist for minipile."
fi
popd > /dev/null

# Run a small demonstration training job with per-batch mixing enabled.
python3 train.py \
--training_mode multidataset \
--dataset_list shakespeare_char minipile \
--dataset_mixing_per_batch \
--dataset_sampling_probs 3 1 \
--dataset_sampling_probs_final 1 3 \
--dataset_sampling_probs_transition_method linear \
--batch_size 8 \
--block_size 128 \
--n_layer 4 \
--n_head 4 \
--n_embd 256 \
--max_iters 2000 \
--eval_interval 200 \
--eval_iters 50 \
--learning_rate 3e-4 \
--weight_decay 0.1 \
--optimizer adamw \
--use_rotary_embeddings \
--no-use_abs_pos_embeddings \
--compile \
--no-tensorboard_log \
--seed 1337 \
--out_dir out_multidataset_per_batch_demo
30 changes: 30 additions & 0 deletions explorations/multidataset_per_batch_mixing.yaml
@@ -0,0 +1,30 @@
# multidataset_per_batch_mixing.yaml
---
# Demonstrates scheduled per-batch dataset mixing with two corpora.
training_mode: ["multidataset"]
dataset_list:
- "shakespeare_char minipile"
# Start with a 75/25 split and linearly transition to 25/75.
dataset_sampling_probs:
- "3 1"
dataset_sampling_probs_final:
- "1 3"
dataset_sampling_probs_transition_method: ["linear"]
dataset_mixing_per_batch: [true]
# Lightweight model + optimizer settings suitable for quick exploration runs.
block_size: [128]
batch_size: [8]
max_iters: [2000]
eval_interval: [200]
eval_iters: [50]
learning_rate: [3e-4]
weight_decay: [0.1]
optimizer: ["adamw"]
n_layer: [4]
n_head: [4]
n_embd: [256]
# Enable rotary embeddings and torch.compile for parity with other demos.
use_rotary_embeddings: [true]
use_abs_pos_embeddings: [false]
compile: [true]
seed: [1337]
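The "3 1" to "1 3" schedule above is interpolated over training: with the linear transition method, the effective weights at a given fraction of max_iters are a straight blend of the initial and final vectors, normalized at each step. A small standalone sketch mirroring the linear branch of interpolate_probs() in train.py (the step ratios here are illustrative, not values the trainer logs):

# Standalone sketch of the linear sampling-probability transition configured
# above: start at weights [3, 1], end at [1, 3], normalize at each step.
import numpy as np

initial = np.array([3.0, 1.0])   # shakespeare_char, minipile
final = np.array([1.0, 3.0])

for step_ratio in (0.0, 0.25, 0.5, 0.75, 1.0):
    probs = initial + step_ratio * (final - initial)
    probs = probs / probs.sum()
    print(f"{step_ratio:.2f} -> {np.round(probs, 3)}")
# 0.00 -> [0.75 0.25]
# 0.50 -> [0.5  0.5 ]
# 1.00 -> [0.25 0.75]
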
145 changes: 139 additions & 6 deletions train.py
@@ -1,5 +1,6 @@
# train.py
from contextlib import nullcontext
from typing import Dict, Optional
import csv
import json
import math
@@ -134,6 +135,9 @@ def __init__(self, args, model_group, training_group, logging_group):
'abs_max': 0.0,
}

# Track how many samples from each dataset were used in the latest batch
self.current_batch_dataset_counts: Optional[Dict[str, int]] = None

# whether to show all model stats
self.compute_model_stats = self.args.compute_model_stats

@@ -714,6 +718,7 @@ def load_data(self):
def get_batch(self, split, target_dataset=None):
dataset = None
data = None
self.current_batch_dataset_counts = None
Copilot AI Oct 19, 2025

Setting self.current_batch_dataset_counts to None at the start of get_batch() means it will be None for all non-per-batch-mixing code paths. This causes the condition if not batch_counts: at line 1869 to fail when batch_counts is None (it should check if batch_counts is None:), potentially leading to an AttributeError when trying to iterate over None in the else block at line 1860.
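A minimal sketch of the truthiness pitfall this comment describes, using hypothetical stand-in values rather than the trainer's actual state:

# None (attribute reset on non-mixing paths) and {} (no counts recorded) are both
# falsy, so `if not batch_counts:` cannot tell those cases apart; an explicit
# `is None` check keeps the fallback path unambiguous.
batch_counts = None   # what get_batch() leaves behind on non-per-batch-mixing paths

if batch_counts is None:
    print("no per-batch counts recorded; fall back to single-dataset bookkeeping")
else:
    for ds_name, sample_count in batch_counts.items():
        print(ds_name, sample_count)
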
def interpolate_probs(initial_probs, final_probs, method, step_ratio):
if method == 'linear':
return initial_probs + step_ratio * (final_probs - initial_probs)
@@ -764,10 +769,125 @@ def get_transitioned_probs():
return x_dict, y_dict, list(self.args.multicontext_datasets)

elif self.args.training_mode == "multidataset":
def sample_indices_from_dataset(dataset_name, count, data_array):
available = len(data_array) - self.args.block_size
if self.args.sampling_method == "random":
return torch.randint(available, (count,))
elif self.args.sampling_method == "sequential":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)
elif self.args.sampling_method == "without_replacement":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_perm[dataset_name] = np.random.permutation(available)
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)
Comment on lines +776 to +792
Copilot AI Oct 19, 2025

This function duplicates logic that likely exists elsewhere in the codebase for sampling indices. The sequential and without_replacement branches have identical implementations except for line 786 (permutation regeneration). Consider extracting and reusing existing sampling logic or consolidating the duplicated code within this function.

Suggested change
elif self.args.sampling_method == "sequential":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)
elif self.args.sampling_method == "without_replacement":
if self.dataset_ptr[dataset_name] + count > available:
self.dataset_perm[dataset_name] = np.random.permutation(available)
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)
elif self.args.sampling_method in ("sequential", "without_replacement"):
if self.dataset_ptr[dataset_name] + count > available:
if self.args.sampling_method == "without_replacement":
self.dataset_perm[dataset_name] = np.random.permutation(available)
self.dataset_ptr[dataset_name] = 0
start = self.dataset_ptr[dataset_name]
end = start + count
indices = self.dataset_perm[dataset_name][start:end]
self.dataset_ptr[dataset_name] = end
return torch.tensor(indices)

else:
return torch.randint(available, (count,))

def build_tensors(data_array, indices):
x_local = torch.stack([
torch.from_numpy(data_array[i:i+self.args.block_size].astype(np.int64))
for i in indices
])
y_local = torch.stack([
torch.from_numpy(data_array[i+1:i+1+self.args.block_size].astype(np.int64))
for i in indices
])
return x_local, y_local
Comment on lines +796 to +805
Copilot AI Oct 19, 2025

This function duplicates tensor construction logic that likely exists elsewhere for batch creation. Consider extracting this into a reusable helper method to avoid code duplication and ensure consistency across different batching paths.
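One way to act on this suggestion, sketched with a hypothetical helper name (build_xy_tensors) and placement; this is not an existing function in the repository:

# Hypothetical consolidation: a helper both the single-dataset path and the
# per-batch-mixing path could call to slice (x, y) pairs out of a token array.
import numpy as np
import torch

def build_xy_tensors(data_array, indices, block_size):
    x = torch.stack([
        torch.from_numpy(data_array[i:i + block_size].astype(np.int64))
        for i in indices
    ])
    y = torch.stack([
        torch.from_numpy(data_array[i + 1:i + 1 + block_size].astype(np.int64))
        for i in indices
    ])
    return x, y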

# If multi-dataset sampling is enabled, pick a dataset using sampling probabilities
if target_dataset:
dataset = target_dataset
data = self.train_data_dict[dataset] if split == 'train' else self.val_data_dict[dataset]
elif (
self.args.dataset_mixing_per_batch
and not self.args.dataset_interleaving
and target_dataset is None
):
if self.args.multidataset_wte:
raise ValueError("Per-batch dataset mixing is currently incompatible with --multidataset_wte")

if self.args.dataset_sampling_probs:
transitioned_probs = get_transitioned_probs()
else:
transitioned_probs = np.ones(len(self.args.dataset_list), dtype=float)

transitioned_probs = np.asarray(transitioned_probs, dtype=float)
total_prob = transitioned_probs.sum()
if total_prob <= 0:
transitioned_probs = np.ones_like(transitioned_probs)
total_prob = transitioned_probs.sum()
normalized_probs = transitioned_probs / total_prob

ideal_counts = normalized_probs * self.args.batch_size
base_counts = np.floor(ideal_counts).astype(int)
remainders = ideal_counts - base_counts
remaining = self.args.batch_size - base_counts.sum()

if remaining > 0:
order = np.argsort(-remainders)
for idx in order[:remaining]:
base_counts[idx] += 1
elif remaining < 0:
order = np.argsort(remainders)
for idx in order[: -remaining]:
if base_counts[idx] > 0:
base_counts[idx] -= 1

counts = base_counts
if counts.sum() == 0:
counts[np.argmax(normalized_probs)] = self.args.batch_size

x_parts = []
y_parts = []
dataset_counts = {}
for ds_name, count in zip(self.args.dataset_list, counts):
if count <= 0:
continue
data_array = (
self.train_data_dict[ds_name]
if split == 'train'
else self.val_data_dict[ds_name]
)
indices = sample_indices_from_dataset(ds_name, count, data_array)
x_local, y_local = build_tensors(data_array, indices)
x_parts.append(x_local)
y_parts.append(y_local)
dataset_counts[ds_name] = count

if not x_parts:
dataset = np.random.choice(self.args.dataset_list)
data = (
self.train_data_dict[dataset]
if split == 'train'
else self.val_data_dict[dataset]
)
else:
x = torch.cat(x_parts, dim=0) if len(x_parts) > 1 else x_parts[0]
y = torch.cat(y_parts, dim=0) if len(y_parts) > 1 else y_parts[0]
if x.size(0) > 1:
permutation = torch.randperm(x.size(0))
x = x[permutation]
y = y[permutation]
self.current_batch_dataset_counts = dataset_counts
dataset = max(dataset_counts, key=dataset_counts.get)
if self.args.use_lsv:
self.model.set_lsv_index(self.args.dataset_list.index(dataset))
if self.device_type == 'cuda':
x = x.pin_memory().to(self.device, non_blocking=True)
y = y.pin_memory().to(self.device, non_blocking=True)
else:
x, y = x.to(self.device), y.to(self.device)
return x, y, dataset
elif self.args.dataset_interleaving:
# print("using interleaving")
if self.args.dataset_sampling_probs is not None:
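The count-allocation step in the per-batch-mixing branch above turns the normalized sampling probabilities into integer per-dataset sample counts by largest-remainder rounding: floor the ideal fractional counts, then hand any leftover slots to the datasets with the biggest fractional remainders. A standalone sketch with made-up probabilities and batch size (not the trainer's actual values):

# Largest-remainder allocation of a batch across datasets, as in the
# per-batch-mixing branch of get_batch() above.
import numpy as np

def allocate_counts(probs, batch_size):
    probs = np.asarray(probs, dtype=float)
    probs = probs / probs.sum()
    ideal = probs * batch_size
    counts = np.floor(ideal).astype(int)
    remainders = ideal - counts
    # Leftover slots go to the largest fractional remainders first.
    for idx in np.argsort(-remainders)[:batch_size - counts.sum()]:
        counts[idx] += 1
    return counts

print(allocate_counts([3, 1], 8))  # [6 2]  (start of the 3:1 -> 1:3 schedule)
print(allocate_counts([2, 2], 8))  # [4 4]  (midway through the transition)
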
@@ -1726,18 +1846,31 @@ def train(self):
prior_dataset = current_dataset
tokens_trained_this_batch = self.args.batch_size * self.args.block_size
if self.args.dataset_list:
# Update per–dataset count
self.tokens_trained_dict[current_dataset] += tokens_trained_this_batch
self.tokens_trained = self.tokens_trained_dict[current_dataset]
batch_counts = self.current_batch_dataset_counts
if batch_counts:
for ds_name, sample_count in batch_counts.items():
tokens = sample_count * self.args.block_size
self.tokens_trained_dict[ds_name] += tokens
self.epochs_trained_dict[ds_name] = (
self.tokens_trained_dict[ds_name] / self.dataset_size_tokens[ds_name]
)
dominant_dataset = max(batch_counts, key=batch_counts.get)
self.tokens_trained = self.tokens_trained_dict[dominant_dataset]
current_epoch = self.epochs_trained_dict[dominant_dataset]
else:
# Update per–dataset count
self.tokens_trained_dict[current_dataset] += tokens_trained_this_batch
self.tokens_trained = self.tokens_trained_dict[current_dataset]
else:
self.tokens_trained += tokens_trained_this_batch

# Compute epoch for logging:
if self.args.dataset_list:
if self.args.dataset_list:
if not batch_counts:
Copilot AI Oct 19, 2025

This condition evaluates to True when batch_counts is None or an empty dict, but the code at line 1850 checks if batch_counts: which only executes when batch_counts is truthy (non-None and non-empty). These conditions should use explicit is None checks for clarity and correctness. Change to if batch_counts is None: to properly handle the case when per-batch mixing is not used.

Suggested change
if not batch_counts:
if batch_counts is None:
current_epoch = self.tokens_trained_dict[current_dataset] / self.dataset_size_tokens[current_dataset]
self.epochs_trained_dict[current_dataset] = current_epoch
else:
current_epoch = self.tokens_trained / self.dataset_size_tokens
else:
current_epoch = self.tokens_trained / self.dataset_size_tokens

self.scaler.scale(loss).backward()

2 changes: 2 additions & 0 deletions train_args.py
@@ -207,6 +207,8 @@ def parse_args():
training_group.add_argument('--dataset_sampling_probs', action=FlattenListAction, default=None, nargs='+', help="Sampling proportions for each dataset in dataset_list. Probabilities normally but proportions in dataset_interleaving")
training_group.add_argument('--dataset_sampling_probs_final', action=FlattenListAction,default=None, nargs='+', help="If, set final sampling probabilities for each dataset in dataset_list.")
training_group.add_argument('--dataset_sampling_probs_transition_method', default=None, type=str, choices=["linear", "cosine", "exponential"])
training_group.add_argument('--dataset_mixing_per_batch', default=False, action=argparse.BooleanOptionalAction,
help="When enabled in multidataset mode, draw each batch using the per-dataset sampling probabilities instead of sampling a single dataset per iteration.")

# Add GNS settings
training_group.add_argument('--gns_type', type=str, default=None, choices=['sogns', 'exact'], help='Type of gradient norm scaling to use (default: None)')
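For context on the flag style used above (and in the demo script's --no-use_abs_pos_embeddings / --no-tensorboard_log): argparse.BooleanOptionalAction registers both a positive and a --no- prefixed negative form. A minimal standalone illustration, not the project's actual parser setup:

# Toy parser demonstrating argparse.BooleanOptionalAction, the mechanism behind
# flag pairs like --dataset_mixing_per_batch / --no-dataset_mixing_per_batch.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_mixing_per_batch', default=False,
                    action=argparse.BooleanOptionalAction)

print(parser.parse_args(['--dataset_mixing_per_batch']))     # Namespace(dataset_mixing_per_batch=True)
print(parser.parse_args(['--no-dataset_mixing_per_batch']))  # Namespace(dataset_mixing_per_batch=False)
print(parser.parse_args([]))                                 # Namespace(dataset_mixing_per_batch=False)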