Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion docs/train/data-prep.md
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,39 @@ format=JsonlOutputConfig(num_shards=64)

Supported size formats: `"256MB"`, `"1G"`, `"500MiB"`, etc.

## Per-Dataset Shard Allocation

When a blend includes multiple datasets, shard counts are allocated per dataset
rather than using a single global count. The total shard budget comes from
`num_shards` or `shard_size`, and is distributed proportionally by dataset weight
and estimated size, with at least one shard per dataset and a cap based on the
number of input files (to avoid empty shards). Dataset weights still control
training mixture ratios; shard allocation is only a sizing heuristic.

> **Note: Weight vs Shard Counts**
>
> - **Dataset.weight** (e.g., 0.7, 0.3): Controls *training-time sampling* in
> Megatron-Bridge. A blend with weights [0.7, 0.3] means 70% of training
> samples are drawn from the first dataset.
>
> - **Shard counts**: Controlled by `shard_size` (recommended) or explicit
> `num_shards`. These determine how many physical output files are created
> during data preparation, independent of weights.
>
> For blends with datasets of different sizes, use `shard_size="256MB"` instead
> of explicit `num_shards` to let each dataset get an appropriate shard count
> based on its size.

`blend.json` includes a `num_shards` map with the effective per-dataset shard counts:

```json
{
"data_paths": ["1.0", "/path/to/ds1/shard", "0.3", "/path/to/ds2/shard"],
"num_shards": {"ds1": 120, "ds2": 8},
"split": "99990,8,2"
}
```

## Per-Split Output

Generate separate train/valid/test outputs using `PerSplitConfig`:
Expand Down Expand Up @@ -485,7 +518,12 @@ output/
{
"train": [["1.0", "/path/to/train/shard_000000"], ["1.0", "/path/to/train/shard_000001"]],
"valid": [["1.0", "/path/to/valid/shard_000000"]],
"test": [["1.0", "/path/to/test/shard_000000"]]
"test": [["1.0", "/path/to/test/shard_000000"]],
"num_shards": {
"train": {"train_ds": 128},
"valid": {"valid_ds": 2},
"test": {"test_ds": 2}
}
}
```

Expand Down
10 changes: 10 additions & 0 deletions src/nemotron/cli/nano3/data/import_/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import json
from pathlib import Path

import typer
Expand Down Expand Up @@ -61,12 +62,21 @@ def pretrain(
# Initialize W&B
init_wandb_if_configured(wandb_config, job_type="data-import", tags=["pretrain", "import"])

dataset_shards = None
try:
with open(data_path) as f:
blend_data = json.load(f)
dataset_shards = blend_data.get("num_shards")
except Exception:
dataset_shards = None

# Create artifact with minimal required fields
artifact_name = name or "nano3/pretrain/data"
artifact = DataBlendsArtifact(
path=data_path,
total_tokens=0,
total_sequences=0,
dataset_shards=dataset_shards,
name=artifact_name,
)

Expand Down
10 changes: 10 additions & 0 deletions src/nemotron/cli/nano3/data/import_/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import json
from pathlib import Path

import typer
Expand Down Expand Up @@ -71,12 +72,21 @@ def sft(
# Initialize W&B
init_wandb_if_configured(wandb_config, job_type="data-import", tags=["sft", "import"])

dataset_shards = None
try:
with open(blend_path) as f:
blend_data = json.load(f)
dataset_shards = blend_data.get("num_shards")
except Exception:
dataset_shards = None

# Create artifact with minimal required fields
artifact_name = name or "nano3/sft/data"
artifact = DataBlendsArtifact(
path=blend_path,
total_tokens=0,
total_sequences=0,
dataset_shards=dataset_shards,
name=artifact_name,
)

Expand Down
30 changes: 19 additions & 11 deletions src/nemotron/data_prep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,25 @@ def run_data_prep(

# Build output artifact - path points to output directory, blend_path to blend.json
blend_json_path = result.output_dir / "blend.json"
artifact = artifact_class(
path=result.output_dir,
blend_path=str(blend_json_path),
total_tokens=result.total_tokens,
total_sequences=result.total_sequences,
elapsed_sec=result.elapsed_sec,
num_shards=num_shards,
source_datasets=source_datasets,
tokenizer_uri=tok_uri,
name=config.artifact_name, # Semantic name for W&B artifact naming
)
artifact_kwargs = {
"path": result.output_dir,
"blend_path": str(blend_json_path),
"total_tokens": result.total_tokens,
"total_sequences": result.total_sequences,
"elapsed_sec": result.elapsed_sec,
"num_shards": num_shards,
"source_datasets": source_datasets,
"tokenizer_uri": tok_uri,
"name": config.artifact_name, # Semantic name for W&B artifact naming
}
# Optionally include per-dataset shard counts if supported by the artifact schema
if hasattr(artifact_class, "model_fields") and "dataset_shards" in artifact_class.model_fields:
all_split = result.splits.get("all") if result.splits else None
artifact_kwargs["dataset_shards"] = (
all_split.dataset_shards if all_split is not None else None
)

artifact = artifact_class(**artifact_kwargs)
artifact.save()

# Mark wandb run as successful (before Ray shutdown to avoid socket noise)
Expand Down
6 changes: 5 additions & 1 deletion src/nemotron/data_prep/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ class Dataset(BaseModel):
Attributes:
name: Unique identifier for this dataset
path: Data location (hf://repo/name, s3://bucket/prefix, /local/path)
weight: Relative weight in the blend (default: 1.0)
weight: Training-time sampling weight (default: 1.0). Controls how
Megatron-Bridge samples from datasets during training, NOT how
many shards are created during data prep. For example, weights
[0.7, 0.3] mean 70% of training samples come from the first dataset.
Shard counts are determined by dataset size and shard_size config.
split: HuggingFace split name (required for hf:// paths)
subset: HuggingFace config/subset name
text_field: Field containing text to tokenize (default: "text")
Expand Down
9 changes: 7 additions & 2 deletions src/nemotron/data_prep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,13 @@ class BinIdxOutputConfig:

Attributes:
format: Format identifier (always "binidx")
shard_size: Target size per shard (e.g., "256MB"). Mutually exclusive with num_shards.
num_shards: Exact number of output shards. Mutually exclusive with shard_size.
shard_size: Target size per shard (e.g., "256MB"). When set, shard count
is computed per-dataset based on individual dataset sizes. This is
recommended for blends with datasets of varying sizes, as it prevents
empty shard files for small datasets. Mutually exclusive with num_shards.
num_shards: Exact number of output shards applied to ALL datasets.
Use shard_size instead when datasets have very different sizes to
avoid empty shards. Mutually exclusive with shard_size.
dtype: Token dtype (int32, int64, uint16)
"""

Expand Down
Loading