making converter additionally support deepseek-coder dense hf model #167

Open · wants to merge 17 commits into base: main
21 changes: 16 additions & 5 deletions src/examples/huggingface/convert_checkpoint_from_hf.py
@@ -16,7 +16,7 @@
from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict
from olmo_core.io import clear_directory, dir_is_empty
from olmo_core.nn.rope import RoPEScalingConfig
from olmo_core.nn.rope import RoPELlamaScalingConfig, RoPELinearScalingConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.utils import get_default_device, prepare_cli_environment

@@ -27,19 +27,30 @@
# HF_MODEL = "allenai/OLMo-2-1124-13B-Instruct"
# HF_MODEL = "meta-llama/Llama-3.2-1B"
# HF_MODEL = "meta-llama/Llama-3.2-8B"
# HF_MODEL = "deepseek-ai/deepseek-coder-1.3b-base"

SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
SAVE_OVERWRITE = False

TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)

MODEL_CONFIG: TransformerConfig
if HF_MODEL == "meta-llama/Llama-3.2-1B":
if "Llama-3.2-1B" in HF_MODEL:

Member: Do we also need this for bigger Llama models?

Author: I changed it so that it works with a local cache dir. We are also running into an additional odd bug related to conversion and need some help from you; we will file a separate issue about it. Do you prefer that we fold the solution to that issue into this PR, or should I keep it separate?

Member: Many small PRs are always better, if possible.
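
(For illustration, a minimal sketch of why the substring check tolerates a local cache directory where the previous exact comparison would not; the local path below is hypothetical.)

# Hypothetical local snapshot directory, e.g. a pre-downloaded HF cache.
HF_MODEL = "/data/hf_cache/models--meta-llama--Llama-3.2-1B/snapshots/abc123"

assert HF_MODEL != "meta-llama/Llama-3.2-1B"  # exact equality misses local paths
assert "Llama-3.2-1B" in HF_MODEL             # substring check still selects the 1B config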

MODEL_CONFIG = TransformerConfig.llama3_1B(
TOKENIZER_CONFIG.vocab_size,
fused_ops=False,
use_flash=False,
rope_scaling=RoPEScalingConfig(),
rope_scaling=RoPELlamaScalingConfig(),
)

elif "deepseek-coder-1.3b-base" in HF_MODEL:
MODEL_CONFIG = TransformerConfig.deepseek_1B(
TOKENIZER_CONFIG.vocab_size,
fused_ops=False,
use_flash=False,
rope_scaling=RoPELinearScalingConfig(factor=4.0),
)

elif HF_MODEL.startswith("allenai/OLMo-2-1124-7B"):
MODEL_CONFIG = TransformerConfig.olmo2_7B(
TOKENIZER_CONFIG.vocab_size,
@@ -105,7 +116,7 @@ def convert_checkpoint() -> AutoModelForCausalLM:
)

# Layer norms.
if "Llama" in HF_MODEL:
if "Llama" or "deepseek" in HF_MODEL:
new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
f"model.layers.{block}.post_attention_layernorm.weight"
)
@@ -152,7 +163,7 @@ def validate_conversion(hf_model):
model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()

log.info("Loading converted checkpoint for validation...")
load_model_and_optim_state(SAVE_PATH, model)
load_model_and_optim_state(os.path.join(SAVE_PATH, "model_and_optim"), model)

with torch.no_grad():
logits = model(input_ids=input_ids)
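
As a standalone illustration of the layer-norm renaming this converter performs for Llama- and deepseek-style checkpoints, here is a minimal sketch; the hidden size is a placeholder.

import torch

block = 0
# Toy HF-style entry for one decoder block (hidden size is illustrative).
state_dict = {f"model.layers.{block}.post_attention_layernorm.weight": torch.ones(2048)}

new_state_dict = {}
# Llama and deepseek-coder keep this norm under post_attention_layernorm;
# OLMo-core stores the same tensor as feed_forward_norm, so only the key changes.
new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
    f"model.layers.{block}.post_attention_layernorm.weight"
)
assert not state_dict and f"blocks.{block}.feed_forward_norm.weight" in new_state_dict
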
175 changes: 140 additions & 35 deletions src/examples/huggingface/convert_checkpoint_to_hf.py
@@ -11,13 +11,15 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator

import os
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, Olmo2Config

from olmo_core.distributed.checkpoint import unshard_checkpoint
from olmo_core.utils import prepare_cli_environment
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, Olmo2Config, LlamaConfig, OlmoForCausalLM
from olmo_core.nn.transformer import Transformer
from olmo_core.nn.transformer import TransformerConfig, TransformerBlockConfig
from olmo_core.nn.rope import RoPELlamaScalingConfig, RoPELinearScalingConfig
from olmo_core.distributed.checkpoint import unshard_checkpoint, load_model_and_optim_state
from olmo_core.utils import prepare_cli_environment, get_default_device

log = logging.getLogger(__name__)

@@ -33,7 +35,9 @@ def init_empty_weights(include_buffers: bool = False) -> Generator[None, None, N
yield None


def load_state_dict(checkpoint_path: str | Path) -> dict[str, torch.Tensor]:
from typing import Union

def load_state_dict(checkpoint_path: Union[str, Path]) -> dict[str, torch.Tensor]:
"""
Load a state dict from either a PyTorch checkpoint or safetensors file.

@@ -64,8 +68,8 @@ def load_state_dict(checkpoint_path: str | Path) -> dict[str, torch.Tensor]:


def convert_to_hf_checkpoint(
olmo_checkpoint_path: str | Path,
output_path: str | Path,
olmo_checkpoint_path: Union[str, Path],
output_path: Union[str, Path],
olmo_core_config: dict,
tokenizer: GPT2Tokenizer,
max_sequence_length: int = -1,
@@ -117,6 +121,7 @@ def convert_to_hf_checkpoint(

for block in range(n_layers):
# Attention

hf_state_dict[f"model.layers.{block}.self_attn.q_proj.weight"] = olmo_state_dict.pop(
f"blocks.{block}.attention.w_q.weight"
)
@@ -159,7 +164,7 @@ def convert_to_hf_checkpoint(
f"blocks.{block}.attention.k_norm.weight"
)
else:
# Llama model
# Llama or deepseek model
hf_state_dict[
f"model.layers.{block}.post_attention_layernorm.weight"
] = olmo_state_dict.pop(f"blocks.{block}.feed_forward_norm.weight")
@@ -196,25 +201,65 @@ def convert_to_hf_checkpoint(
raise ValueError(f"Missing or invalid sequence length: {max_sequence_length}")

# Create HF model instance and load state dict
huggingface_config = Olmo2Config(
vocab_size=olmo_core_config["model"]["vocab_size"],
hidden_size=olmo_core_config["model"]["d_model"],
intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
num_hidden_layers=olmo_core_config["model"]["n_layers"],
num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
num_key_value_heads=(
olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
),
hidden_act="silu",
max_position_embeddings=max_sequence_length,
rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
tie_word_embeddings=False,
)
if has_qk_norm:
huggingface_config = Olmo2Config(
vocab_size=olmo_core_config["model"]["vocab_size"],
hidden_size=olmo_core_config["model"]["d_model"],
intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
num_hidden_layers=olmo_core_config["model"]["n_layers"],
num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
num_key_value_heads=(
olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
),
hidden_act="silu",
max_position_embeddings=max_sequence_length,
rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
tie_word_embeddings=False,
)
else:

rope_config = olmo_core_config["model"]["block"]["attention"]["rope"].get("scaling", None)

if rope_config is None:
rope_scaling = None
elif "high_freq_factor" in rope_config:
# llama3 scaling
rope_scaling = {
"rope_type": "llama3",
"factor": rope_config["factor"],
"high_freq_factor": rope_config["high_freq_factor"],
"low_freq_factor": rope_config["low_freq_factor"],
"original_max_position_embeddings": rope_config["old_context_len"]
}
else:
# linear scaling (e.g. deepseek-coder)
rope_scaling = {"type": "linear", "factor": rope_config["factor"]}

huggingface_config = LlamaConfig(
vocab_size=olmo_core_config["model"]["vocab_size"],
hidden_size=olmo_core_config["model"]["d_model"],
intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
num_hidden_layers=olmo_core_config["model"]["n_layers"],
num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
num_key_value_heads=(
olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
),
hidden_act="silu",
max_position_embeddings=max_sequence_length,
rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
tie_word_embeddings=False,
rope_scaling=rope_scaling,
)

with init_empty_weights():
log.info("Initializing HF model with empty weights...")
Expand All @@ -233,6 +278,7 @@ def convert_to_hf_checkpoint(
log.info(f"Successfully saved HF tokenizer to '{output_path}'")



def load_config(checkpoint_input_dir: Path) -> dict:
assert (
checkpoint_input_dir / "config.json"
@@ -287,15 +333,74 @@ def main():
logging.info("No sharded checkpoint found, using input directory as unsharded")
unsharded_dir = args.checkpoint_input_dir

convert_to_hf_checkpoint(
olmo_checkpoint_path=unsharded_dir / "model.pt",
output_path=args.huggingface_output_dir,
olmo_core_config=experiment_config,
max_sequence_length=args.max_sequence_length,
tokenizer=tokenizer_config, # type: ignore
)
convert_to_hf_checkpoint(
olmo_checkpoint_path=unsharded_dir / "model.pt",
output_path=args.huggingface_output_dir,
olmo_core_config=experiment_config,
max_sequence_length=args.max_sequence_length,
tokenizer=tokenizer_config,
)

validate_conversion(args.huggingface_output_dir, unsharded_dir, experiment_config)


def validate_conversion(hf_model_path, olmo_checkpoint_path, olmo_config):
"""
Validate that the converted Hugging Face model produces the same output as the original OLMo model.

Args:
hf_model_path: Path to the converted Hugging Face model directory.
olmo_checkpoint_path: Path to the original OLMo checkpoint directory.
olmo_config: Transformer configuration for the original OLMo model.
"""
log.info("Starting conversion validation...")

if not os.path.isdir(olmo_checkpoint_path):
raise ValueError(f"Expected a directory for `olmo_checkpoint_path`, but got {olmo_checkpoint_path}")

device = get_default_device()

config_path = os.path.join(olmo_checkpoint_path, "config.json")
log.info(f"Loading OLMo model config from {config_path}")

with open(config_path, "r") as f:
olmo_config_dict = json.load(f)["model"]

if "rope" in olmo_config_dict["block"]["attention"] and "scaling" in olmo_config_dict["block"]["attention"]["rope"]:
scaling_config = olmo_config_dict["block"]["attention"]["rope"]["scaling"]
assert isinstance(scaling_config, dict)

if "high_freq_factor" in scaling_config:
# llama3 scaling
olmo_config_dict["block"]["attention"]["rope"]["scaling"] = RoPELlamaScalingConfig(**scaling_config)
else:
# linear scaling (e.g. deepseek-coder)
olmo_config_dict["block"]["attention"]["rope"]["scaling"] = RoPELinearScalingConfig(**scaling_config)

olmo_config = TransformerConfig.from_dict(olmo_config_dict)

hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

B, T = 1, 120
input_ids = torch.randint(0, tokenizer.vocab_size, (B, T)).to(device)

with torch.no_grad():
hf_logits = hf_model(input_ids=input_ids).logits

model = olmo_config.build(device=device, max_seq_len=131072).eval()

log.info("Loading converted checkpoint for validation...")
load_model_and_optim_state(os.path.join(os.path.dirname(olmo_checkpoint_path), "model_and_optim"), model)

with torch.no_grad():
original_logits = model(input_ids=input_ids)

torch.testing.assert_close(hf_logits, original_logits, atol=1e-5, rtol=1e-3)
log.info("Conversion validation successful! Outputs match.")


if __name__ == "__main__":
prepare_cli_environment()
main()
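
The rope-scaling branch added above can be summarized as a small standalone helper. The sketch below mirrors the logic in the diff; the input dict is assumed to look like a serialized OLMo-core scaling config.

from typing import Optional

def map_rope_scaling(rope_config: Optional[dict]) -> Optional[dict]:
    # Mirrors the branch in convert_to_hf_checkpoint (sketch only).
    if rope_config is None:
        return None
    if "high_freq_factor" in rope_config:
        # Llama-3-style scaling (RoPELlamaScalingConfig).
        return {
            "rope_type": "llama3",
            "factor": rope_config["factor"],
            "high_freq_factor": rope_config["high_freq_factor"],
            "low_freq_factor": rope_config["low_freq_factor"],
            "original_max_position_embeddings": rope_config["old_context_len"],
        }
    # Linear scaling (RoPELinearScalingConfig, e.g. deepseek-coder).
    return {"type": "linear", "factor": rope_config["factor"]}

print(map_rope_scaling({"factor": 4.0}))  # {'type': 'linear', 'factor': 4.0}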

25 changes: 25 additions & 0 deletions src/olmo_core/nn/rope.py
@@ -14,6 +14,8 @@
"RoPEType",
"RoPEConfig",
"RoPEScalingConfig",
"RoPELlamaScalingConfig",
"RoPELinearScalingConfig",
"RotaryEmbeddingBase",
"RotaryEmbedding",
"FusedRotaryEmbedding",
@@ -42,6 +44,14 @@ class RoPEType(StrEnum):

@dataclass
class RoPEScalingConfig(Config):
def scale_inv_freq(
self,
inv_freq: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError("Not implemented")

@dataclass
class RoPELlamaScalingConfig(RoPEScalingConfig):
"""
Defines how to scale RoPE to longer sequence lengths.
"""
@@ -71,6 +81,21 @@ def scale_inv_freq(
return torch.where(is_medium_freq, smoothed_inv_freq, inv_freq)


@dataclass
class RoPELinearScalingConfig(RoPEScalingConfig):
"""
Defines how to scale RoPE to longer sequence lengths.
"""

factor: float = 4.0

def scale_inv_freq(
self,
inv_freq: torch.Tensor,
) -> torch.Tensor:
return inv_freq / self.factor


@dataclass
class RoPEConfig(Config):
"""
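
A short numeric sketch of what RoPELinearScalingConfig does: dividing the inverse frequencies by factor stretches every rotary period, so position factor * p with scaled frequencies produces the same rotation angles as position p did without scaling (the dimensions below are illustrative).

import torch

dim, theta = 8, 10_000
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

factor = 4.0
scaled = inv_freq / factor  # what RoPELinearScalingConfig(factor=4.0).scale_inv_freq returns

p = 100
# Angles at position factor * p with scaled frequencies match angles at p without scaling.
assert torch.allclose(p * inv_freq, (factor * p) * scaled)
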
23 changes: 21 additions & 2 deletions src/olmo_core/nn/transformer/config.py
@@ -23,7 +23,7 @@
from ..feed_forward import FeedForwardConfig, FeedForwardType
from ..layer_norm import LayerNormConfig, LayerNormType
from ..lm_head import LMHeadConfig, LMHeadType
from ..rope import RoPEConfig, RoPEScalingConfig, RoPEType
from ..rope import RoPEConfig, RoPEScalingConfig, RoPELlamaScalingConfig, RoPELinearScalingConfig, RoPEType
from .block import TransformerBlockConfig, TransformerBlockType
from .init import InitMethod
from .model import (
@@ -647,6 +647,25 @@ def llama3_70B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
hidden_size_multiple_of=4096,
**kwargs,
)
@classmethod
def deepseek_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
"""
A ~1.3B deepseek-coder (dense) model config.
"""
return cls.llama_like(
d_model=2048,
vocab_size=vocab_size,
n_layers=kwargs.pop("n_layers", 24),
n_heads=kwargs.pop("n_heads", 16),
layer_norm_eps=kwargs.pop("rms_norm_eps", 1e-06),
n_kv_heads=kwargs.pop("n_kv_heads", 16),
rope_theta=kwargs.pop("rope_theta", 100000),
# rope_scaling=RoPELinearScalingConfig(factor=4.0),
hidden_size_multiple_of=128,
hidden_size_multiplier=None,
rope_type="default",
**kwargs,
)

@classmethod
def llama3_405B(
@@ -740,7 +759,7 @@ def llama_like(
n_heads=n_heads,
n_kv_heads=n_kv_heads,
bias=False,
rope=RoPEConfig(name=rope_type, theta=rope_theta, scaling=rope_scaling),
rope=RoPEConfig(name=rope_type, theta=rope_theta, scaling=rope_scaling, full_precision=False),
qk_norm=layer_norm if qk_norm else None,
use_flash=use_flash,
dtype=dtype,
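
For reference, a hedged sketch of how the new deepseek_1B preset is wired up by the HF-import example in this PR; the max_seq_len passed to build() is an assumption rather than something this PR fixes.

from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.nn.rope import RoPELinearScalingConfig
from olmo_core.nn.transformer import TransformerConfig

tokenizer_config = TokenizerConfig.from_hf("deepseek-ai/deepseek-coder-1.3b-base")

model_config = TransformerConfig.deepseek_1B(
    tokenizer_config.vocab_size,
    fused_ops=False,
    use_flash=False,
    rope_scaling=RoPELinearScalingConfig(factor=4.0),  # deepseek-coder's linear RoPE scaling
)
model = model_config.build(max_seq_len=16_384)  # assumed context length (4096 * 4.0)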