Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/models/gpt_oss/slurm_peft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export WKDIR="${WKDIR:-}"
# After pretrain, use e.g. ${WORKSPACE}/results/${MODEL_NAME}_pretrain_tp2_pp4_ep4_spTrue_cp1
PRETRAINED_CHECKPOINT=${PRETRAINED_CHECKPOINT:-${WORKSPACE}/models/gpt-oss-20b}
MODEL_NAME=gpt_oss_20b
RECIPE_NAME="${RECIPE_NAME:-${MODEL_NAME}_peft_config}"
# RECIPE_NAME="${MODEL_NAME}_peft_fp8_current_scaling_config"
DATASET_NAME=squad
SEQ_LENGTH=2048
TRAIN_ITERS=1000
Expand Down Expand Up @@ -158,10 +160,9 @@ for CONFIG in "${PARALLELISM_CONFIGS[@]}"; do
dataset.seq_length=$SEQ_LENGTH \
model.seq_length=$SEQ_LENGTH
"

CMD="uv run --no-sync python /opt/Megatron-Bridge/scripts/training/run_recipe.py"
CMD="$CMD --mode finetune"
CMD="$CMD --recipe ${MODEL_NAME}_peft_config"
CMD="$CMD --recipe ${RECIPE_NAME}"

# Collapse newlines so bash -c receives a single command
CMD="$CMD $(echo "$CLI_OVERRIDES" | tr '\n' ' ' | sed 's/ \+/ /g')"
Expand Down
5 changes: 3 additions & 2 deletions examples/models/gpt_oss/slurm_pretrain.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ export WKDIR="${WKDIR:-}"

# Model and training configurations
MODEL_NAME=gpt_oss_20b
# Recipe selection: overridable via the RECIPE_NAME environment variable,
# defaulting to the standard bf16 pretrain recipe — consistent with
# slurm_peft.sh / slurm_sft.sh. Uncomment the second line (or export
# RECIPE_NAME) to use the Hopper FP8 current-scaling variant instead.
# NOTE(review): the original hard-coded the FP8 recipe and dropped the env
# override — presumably a leftover from FP8 testing; confirm intent.
RECIPE_NAME="${RECIPE_NAME:-${MODEL_NAME}_pretrain_config}"
# RECIPE_NAME="${MODEL_NAME}_pretrain_fp8_current_scaling_config"
DATASET_NAME=dclm # set to "mock" for mock data; "dclm" uses DCLM when DCLM_DATA_DIR/DCLM_CACHE are set below
SEQ_LENGTH=4096

Expand Down Expand Up @@ -188,9 +190,8 @@ for CONFIG in "${PARALLELISM_CONFIGS[@]}"; do
if [ -n "$DCLM_DATASET_OVERRIDES" ]; then
CLI_OVERRIDES="$CLI_OVERRIDES $DCLM_DATASET_OVERRIDES"
fi

CMD="uv run --no-sync python /opt/Megatron-Bridge/scripts/training/run_recipe.py"
CMD="$CMD --recipe ${MODEL_NAME}_pretrain_config"
CMD="$CMD --recipe ${RECIPE_NAME}"
CMD="$CMD $CLI_OVERRIDES"

echo "Executing command..."
Expand Down
5 changes: 3 additions & 2 deletions examples/models/gpt_oss/slurm_sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export WKDIR="${WKDIR:-}"
# Use base dir (e.g. .../gpt-oss-20b) with latest_checkpointed_iteration.txt, or Bridge dir with latest_train_state.pt
PRETRAINED_CHECKPOINT=${PRETRAINED_CHECKPOINT:-${WORKSPACE}/models/gpt-oss-20b}
MODEL_NAME=gpt_oss_20b
RECIPE_NAME="${RECIPE_NAME:-${MODEL_NAME}_sft_config}"
# RECIPE_NAME="${MODEL_NAME}_sft_fp8_current_scaling_config"
DATASET_NAME=squad
SEQ_LENGTH=2048
TRAIN_ITERS=1000
Expand Down Expand Up @@ -161,10 +163,9 @@ for CONFIG in "${PARALLELISM_CONFIGS[@]}"; do
dataset.seq_length=$SEQ_LENGTH \
model.seq_length=$SEQ_LENGTH
"

CMD="uv run --no-sync python /opt/Megatron-Bridge/scripts/training/run_recipe.py"
CMD="$CMD --mode finetune"
CMD="$CMD --recipe ${MODEL_NAME}_sft_config"
CMD="$CMD --recipe ${RECIPE_NAME}"
CMD="$CMD --peft_scheme none"
# Collapse newlines so bash -c receives a single command
CMD="$CMD $(echo "$CLI_OVERRIDES" | tr '\n' ' ' | sed 's/ \+/ /g')"
Expand Down
6 changes: 6 additions & 0 deletions src/megatron/bridge/recipes/gpt_oss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

from .gpt_oss import (
gpt_oss_20b_peft_config,
gpt_oss_20b_peft_fp8_current_scaling_config,
gpt_oss_20b_pretrain_config,
gpt_oss_20b_pretrain_fp8_current_scaling_config,
gpt_oss_20b_sft_config,
gpt_oss_20b_sft_fp8_current_scaling_config,
gpt_oss_120b_peft_config,
gpt_oss_120b_pretrain_config,
gpt_oss_120b_sft_config,
Expand All @@ -24,9 +27,12 @@

# Public API of the gpt_oss recipe package: pretrain/SFT/PEFT config factories
# for the 20B and 120B variants, plus the 20B Hopper FP8 current-scaling
# variants added alongside each base recipe.
__all__ = [
    "gpt_oss_20b_pretrain_config",
    "gpt_oss_20b_pretrain_fp8_current_scaling_config",
    "gpt_oss_120b_pretrain_config",
    "gpt_oss_20b_sft_config",
    "gpt_oss_20b_sft_fp8_current_scaling_config",
    "gpt_oss_120b_sft_config",
    "gpt_oss_20b_peft_config",
    "gpt_oss_20b_peft_fp8_current_scaling_config",
    "gpt_oss_120b_peft_config",
]
27 changes: 27 additions & 0 deletions src/megatron/bridge/recipes/gpt_oss/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
from megatron.bridge.training.config import ConfigContainer


def _enable_gpt_oss_hopper_fp8_current_scaling(cfg: ConfigContainer) -> ConfigContainer:
    """Mutate ``cfg`` in place to enable Hopper FP8 current scaling and return it.

    Shared helper for the GPT-OSS FP8 recipe variants: switches the mixed
    precision recipe and enables MoE router padding.
    """
    # NOTE(review): presumably pads per-expert token counts to FP8-friendly
    # alignment — confirm against Megatron-Core's moe_router_padding_for_fp8 docs.
    cfg.model.moe_router_padding_for_fp8 = True
    cfg.mixed_precision = "bf16_with_fp8_current_scaling_mixed"
    return cfg


def gpt_oss_20b_pretrain_config() -> ConfigContainer:
"""Return a pre-training config for GPT-OSS 20B variant.
Expand Down Expand Up @@ -254,6 +261,12 @@ def gpt_oss_120b_pretrain_config() -> ConfigContainer:
return cfg


def gpt_oss_20b_pretrain_fp8_current_scaling_config() -> ConfigContainer:
    """Return a pre-training config for GPT-OSS 20B with Hopper FP8 current scaling."""
    # Derive from the standard 20B pretrain recipe, then flip on FP8 current scaling.
    return _enable_gpt_oss_hopper_fp8_current_scaling(gpt_oss_20b_pretrain_config())


# =============================================================================
# SFT Configs
# =============================================================================
Expand Down Expand Up @@ -511,6 +524,12 @@ def gpt_oss_120b_sft_config() -> ConfigContainer:
return cfg


def gpt_oss_20b_sft_fp8_current_scaling_config() -> ConfigContainer:
    """Return a full SFT config for GPT-OSS 20B with Hopper FP8 current scaling."""
    # Derive from the standard 20B SFT recipe, then flip on FP8 current scaling.
    return _enable_gpt_oss_hopper_fp8_current_scaling(gpt_oss_20b_sft_config())


# =============================================================================
# PEFT Configs
# =============================================================================
Expand Down Expand Up @@ -784,3 +803,11 @@ def gpt_oss_120b_peft_config(
cfg.rng.seed = 5678

return cfg


def gpt_oss_20b_peft_fp8_current_scaling_config(
    peft_scheme: str | PEFT = "lora",
) -> ConfigContainer:
    """Return a PEFT config for GPT-OSS 20B with Hopper FP8 current scaling.

    ``peft_scheme`` is forwarded unchanged to the base PEFT recipe.
    """
    # Derive from the standard 20B PEFT recipe, then flip on FP8 current scaling.
    return _enable_gpt_oss_hopper_fp8_current_scaling(
        gpt_oss_20b_peft_config(peft_scheme=peft_scheme)
    )