Making the converter additionally support the deepseek-coder dense HF model #167
base: main
Changes from 15 commits
c7c7d7a
f7f34f0
1371b61
0979f01
f8dd383
5ef600e
3998fb2
6aeed33
64868d6
76c250f
2876384
dec0c46
29b18c0
a6f7c9f
c85bccc
a307d88
aeddddf
[Diff 1 of 2: example script that converts a Hugging Face checkpoint to OLMo-core format]
@@ -16,30 +16,36 @@
 from olmo_core.data.tokenizer import TokenizerConfig
 from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict
 from olmo_core.io import clear_directory, dir_is_empty
-from olmo_core.nn.rope import RoPEScalingConfig
+from olmo_core.nn.rope import RoPELlamaScalingConfig, RoPELinearScalingConfig
 from olmo_core.nn.transformer import TransformerConfig
 from olmo_core.utils import get_default_device, prepare_cli_environment

 log = logging.getLogger(__name__)

-HF_MODEL = "allenai/OLMo-2-1124-7B"
 # HF_MODEL = "allenai/OLMo-2-1124-7B-Instruct"
 # HF_MODEL = "allenai/OLMo-2-1124-13B-Instruct"
 # HF_MODEL = "meta-llama/Llama-3.2-1B"
 # HF_MODEL = "meta-llama/Llama-3.2-8B"
+HF_MODEL = f"{os.environ['SHARE_RES_DIR']}/models/deepseek/deepseek-coder-1.3b-base"

-SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
-SAVE_OVERWRITE = False
+SAVE_PATH = f"{os.environ['SHARE_RES_DIR']}/models/deepseek/olmo/deepseek-coder-1.3b-base"
+SAVE_OVERWRITE = True

 TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)

 MODEL_CONFIG: TransformerConfig
-if HF_MODEL == "meta-llama/Llama-3.2-1B":
+if "Llama-3.2-1B" in HF_MODEL:
Reviewer: Do we also need this for bigger Llama models?
Author: I changed it so that it works with a local cache dir. We are also encountering an additional, weird bug related to conversion that requires some help from you; we will submit a separate issue about it. Do you prefer that we incorporate the fix for that issue into this PR, or should I keep it separate?
Reviewer: Many small PRs are always better, if possible.
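To make the local-cache-dir point concrete, here is a tiny illustration of why the substring check matters (the path below is a made-up example, not from the PR):

```python
# Illustrative only: an assumed local snapshot path rather than a hub ID.
HF_MODEL = "/share/models/meta-llama/Llama-3.2-1B"

# The old exact-match check no longer fires for a local path...
assert HF_MODEL != "meta-llama/Llama-3.2-1B"

# ...while the new substring check still selects the Llama-3.2-1B config.
assert "Llama-3.2-1B" in HF_MODEL
```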
     MODEL_CONFIG = TransformerConfig.llama3_1B(
         TOKENIZER_CONFIG.vocab_size,
         fused_ops=False,
         use_flash=False,
-        rope_scaling=RoPEScalingConfig(),
+        rope_scaling=RoPELlamaScalingConfig(),
     )

+elif "deepseek-coder-1.3b-base" in HF_MODEL:
Reviewer: Does this have to happen for all DeepSeek models?
Author: Yes, we defined a class for this DeepSeek version following the OLMo-core setup. For other versions, users would need to define their own class in the same way. Also, I can modify this to …
+    MODEL_CONFIG = TransformerConfig.deepseek_1B(
+        TOKENIZER_CONFIG.vocab_size,
+        fused_ops=False,
+        use_flash=False,
+        rope_scaling=RoPELinearScalingConfig(factor=4.0),
+    )
+
 elif HF_MODEL.startswith("allenai/OLMo-2-1124-7B"):
     MODEL_CONFIG = TransformerConfig.olmo2_7B(
         TOKENIZER_CONFIG.vocab_size,
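As discussed in the review thread above, other DeepSeek variants would need their own config defined and hooked into the same dispatch. A rough sketch under those assumptions: `TransformerConfig.deepseek_1B` and `RoPELinearScalingConfig` come from this PR, while the `build_model_config` helper and the idea of a second variant are hypothetical illustrations, not part of the change.

```python
from olmo_core.nn.rope import RoPELinearScalingConfig
from olmo_core.nn.transformer import TransformerConfig


def build_model_config(hf_model: str, vocab_size: int) -> TransformerConfig:
    """Mirror the if/elif dispatch in this script for DeepSeek checkpoints."""
    if "deepseek-coder-1.3b-base" in hf_model:
        return TransformerConfig.deepseek_1B(
            vocab_size,
            fused_ops=False,
            use_flash=False,
            rope_scaling=RoPELinearScalingConfig(factor=4.0),
        )
    # Any other DeepSeek variant would need its own TransformerConfig
    # classmethod defined in olmo_core first, following deepseek_1B.
    raise ValueError(f"No TransformerConfig defined for {hf_model}")
```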
@@ -105,7 +111,7 @@ def convert_checkpoint() -> AutoModelForCausalLM:
         )

         # Layer norms.
-        if "Llama" in HF_MODEL:
+        if "Llama" in HF_MODEL or "deepseek" in HF_MODEL:
             new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
                 f"model.layers.{block}.post_attention_layernorm.weight"
             )
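To make the renaming above concrete, here is a tiny self-contained example of the key move for one block (the tensor size is an arbitrary placeholder):

```python
import torch

# One HF-style layer-norm key for block 0; the weight values are placeholders.
state_dict = {"model.layers.0.post_attention_layernorm.weight": torch.ones(2048)}
new_state_dict = {}

block = 0
# The same pop-and-rename the converter performs for Llama/DeepSeek checkpoints.
new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
    f"model.layers.{block}.post_attention_layernorm.weight"
)

assert not state_dict
assert "blocks.0.feed_forward_norm.weight" in new_state_dict
```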
@@ -152,7 +158,7 @@ def validate_conversion(hf_model):
     model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()

     log.info("Loading converted checkpoint for validation...")
-    load_model_and_optim_state(SAVE_PATH, model)
+    load_model_and_optim_state(os.path.join(SAVE_PATH, "model_and_optim"), model)

     with torch.no_grad():
         logits = model(input_ids=input_ids)
[Diff 2 of 2: example script that converts an OLMo-core checkpoint to Hugging Face format]
@@ -11,13 +11,15 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Generator

+import os
 import torch
 from safetensors.torch import load_file
-from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, Olmo2Config
-
-from olmo_core.distributed.checkpoint import unshard_checkpoint
-from olmo_core.utils import prepare_cli_environment
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, Olmo2Config, LlamaConfig, OlmoForCausalLM
+from olmo_core.nn.transformer import Transformer
+from olmo_core.nn.transformer import TransformerConfig, TransformerBlockConfig
+from olmo_core.nn.rope import RoPELlamaScalingConfig, RoPELinearScalingConfig
+from olmo_core.distributed.checkpoint import unshard_checkpoint, load_model_and_optim_state
+from olmo_core.utils import prepare_cli_environment, get_default_device

 log = logging.getLogger(__name__)
@@ -33,7 +35,9 @@ def init_empty_weights(include_buffers: bool = False) -> Generator[None, None, None]:
         yield None


-def load_state_dict(checkpoint_path: str | Path) -> dict[str, torch.Tensor]:
+from typing import Union
+
+def load_state_dict(checkpoint_path: Union[str, Path]) -> dict[str, torch.Tensor]:
     """
     Load a state dict from either a PyTorch checkpoint or safetensors file.
@@ -64,8 +68,8 @@ def load_state_dict(checkpoint_path: str | Path) -> dict[str, torch.Tensor]:


 def convert_to_hf_checkpoint(
-    olmo_checkpoint_path: str | Path,
-    output_path: str | Path,
+    olmo_checkpoint_path: Union[str, Path],
+    output_path: Union[str, Path],
     olmo_core_config: dict,
     tokenizer: GPT2Tokenizer,
     max_sequence_length: int = -1,
@@ -117,6 +121,7 @@ def convert_to_hf_checkpoint(

     for block in range(n_layers):
+        # Attention

         hf_state_dict[f"model.layers.{block}.self_attn.q_proj.weight"] = olmo_state_dict.pop(
             f"blocks.{block}.attention.w_q.weight"
         )
@@ -159,7 +164,7 @@ def convert_to_hf_checkpoint(
                 f"blocks.{block}.attention.k_norm.weight"
             )
         else:
-            # Llama model
+            # Llama or deepseek model
             hf_state_dict[
                 f"model.layers.{block}.post_attention_layernorm.weight"
             ] = olmo_state_dict.pop(f"blocks.{block}.feed_forward_norm.weight")
@@ -196,25 +201,66 @@ def convert_to_hf_checkpoint(
         raise ValueError(f"Missing or invalid sequence length: {max_sequence_length}")

     # Create HF model instance and load state dict
-    huggingface_config = Olmo2Config(
-        vocab_size=olmo_core_config["model"]["vocab_size"],
-        hidden_size=olmo_core_config["model"]["d_model"],
-        intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
-        num_hidden_layers=olmo_core_config["model"]["n_layers"],
-        num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
-        num_key_value_heads=(
-            olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
-        ),
-        hidden_act="silu",
-        max_position_embeddings=max_sequence_length,
-        rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
-        attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
-        pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
-        bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
-        eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
-        rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
-        tie_word_embeddings=False,
-    )
+    if has_qk_norm:
+        huggingface_config = Olmo2Config(
+            vocab_size=olmo_core_config["model"]["vocab_size"],
+            hidden_size=olmo_core_config["model"]["d_model"],
+            intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
+            num_hidden_layers=olmo_core_config["model"]["n_layers"],
+            num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
+            num_key_value_heads=(
+                olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
+            ),
+            hidden_act="silu",
+            max_position_embeddings=max_sequence_length,
+            rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
+            attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
+            pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
+            bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
+            eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
+            rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
+            tie_word_embeddings=False,
+        )
+    else:
+        # Extract RoPE scaling config
+        # from pdb import set_trace; set_trace()
Reviewer: Leftover debug code?
Author: Will remove in next commit.
+        rope_config = olmo_core_config["model"]["block"]["attention"]["rope"].get("scaling", None)
+
+        if rope_config is None:
+            rope_scaling = None
+        elif "high_freq_factor" in rope_config:
+            # llama3 scaling
+            rope_scaling = {
+                "rope_type": "llama3",
+                "factor": rope_config["factor"],
+                "high_freq_factor": rope_config["high_freq_factor"],
+                "low_freq_factor": rope_config["low_freq_factor"],
+                "original_max_position_embeddings": rope_config["old_context_len"],
+            }
+        else:
+            # linear scaling (e.g. deepseek-coder)
+            rope_scaling = {"type": "linear", "factor": rope_config["factor"]}
+
+        huggingface_config = LlamaConfig(
+            vocab_size=olmo_core_config["model"]["vocab_size"],
+            hidden_size=olmo_core_config["model"]["d_model"],
+            intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"],
+            num_hidden_layers=olmo_core_config["model"]["n_layers"],
+            num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]),
+            num_key_value_heads=(
+                olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads
+            ),
+            hidden_act="silu",
+            max_position_embeddings=max_sequence_length,
+            rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"],
+            attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False,
+            pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None),
+            bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None),
+            eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None),
+            rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"],
+            tie_word_embeddings=False,
+            rope_scaling=rope_scaling,
+        )

     with init_empty_weights():
         log.info("Initializing HF model with empty weights...")
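For reference, the branch above produces HF `rope_scaling` dicts of the following shapes. Only the linear case with `factor=4.0` actually occurs for deepseek-coder-1.3b-base in this PR; the llama3 numbers below are illustrative assumptions:

```python
# Linear RoPE scaling, as configured for deepseek-coder-1.3b-base in this PR.
rope_scaling_linear = {"type": "linear", "factor": 4.0}

# Llama-3-style scaling; every value here is an assumed example, not from the PR.
rope_scaling_llama3 = {
    "rope_type": "llama3",
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
}
```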
@@ -233,6 +279,7 @@ def convert_to_hf_checkpoint(
     log.info(f"Successfully saved HF tokenizer to '{output_path}'")



 def load_config(checkpoint_input_dir: Path) -> dict:
     assert (
         checkpoint_input_dir / "config.json"
@@ -287,15 +334,74 @@ def main():
         logging.info("No sharded checkpoint found, using input directory as unsharded")
         unsharded_dir = args.checkpoint_input_dir

-    convert_to_hf_checkpoint(
-        olmo_checkpoint_path=unsharded_dir / "model.pt",
-        output_path=args.huggingface_output_dir,
-        olmo_core_config=experiment_config,
-        max_sequence_length=args.max_sequence_length,
-        tokenizer=tokenizer_config,  # type: ignore
-    )
+    convert_to_hf_checkpoint(
+        olmo_checkpoint_path=unsharded_dir / "model.pt",
+        output_path=args.huggingface_output_dir,
+        olmo_core_config=experiment_config,
+        max_sequence_length=args.max_sequence_length,
+        tokenizer=tokenizer_config,
+    )
+
+    validate_conversion(args.huggingface_output_dir, unsharded_dir, experiment_config)
+
+
+def validate_conversion(hf_model_path, olmo_checkpoint_path, olmo_config):
+    """
+    Validate that the converted Hugging Face model produces the same output as the original OLMo model.
+
+    Args:
+        hf_model_path: Path to the converted Hugging Face model directory.
+        olmo_checkpoint_path: Path to the original OLMo checkpoint directory.
+        olmo_config: Transformer configuration for the original OLMo model.
+    """
+    log.info("Starting conversion validation...")
+
+    if not os.path.isdir(olmo_checkpoint_path):
+        raise ValueError(f"Expected a directory for `olmo_checkpoint_path`, but got {olmo_checkpoint_path}")
+
+    device = get_default_device()
+
+    config_path = os.path.join(olmo_checkpoint_path, "config.json")
+    log.info(f"Loading OLMo model config from {config_path}")
+
+    with open(config_path, "r") as f:
+        olmo_config_dict = json.load(f)["model"]
+    # import pdb; pdb.set_trace()
Reviewer: Leftover?
Author: Will remove in next commit.
Author: Resolved all of the above in a new commit @dirkgr
+    if "rope" in olmo_config_dict["block"]["attention"] and "scaling" in olmo_config_dict["block"]["attention"]["rope"]:
+        scaling_config = olmo_config_dict["block"]["attention"]["rope"]["scaling"]
+        assert isinstance(scaling_config, dict)
+
+        if "high_freq_factor" in scaling_config:
+            # llama3 scaling
+            olmo_config_dict["block"]["attention"]["rope"]["scaling"] = RoPELlamaScalingConfig(**scaling_config)
+        else:
+            # linear scaling (e.g. deepseek-coder)
+            olmo_config_dict["block"]["attention"]["rope"]["scaling"] = RoPELinearScalingConfig(**scaling_config)
+
+    olmo_config = TransformerConfig.from_dict(olmo_config_dict)
+
+    hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
+
+    B, T = 1, 120
+    input_ids = torch.randint(0, tokenizer.vocab_size, (B, T)).to(device)
+
+    with torch.no_grad():
+        hf_logits = hf_model(input_ids=input_ids).logits
+
+    model = olmo_config.build(device=device, max_seq_len=131072).eval()
+
+    log.info("Loading converted checkpoint for validation...")
+    load_model_and_optim_state(os.path.join(os.path.dirname(olmo_checkpoint_path), "model_and_optim"), model)
+
+    with torch.no_grad():
+        original_logits = model(input_ids=input_ids)
+
+    torch.testing.assert_close(hf_logits, original_logits, atol=1e-5, rtol=1e-3)
+    log.info("Conversion validation successful! Outputs match.")


 if __name__ == "__main__":
     prepare_cli_environment()
     main()
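A minimal sketch of how the two pieces above fit together when driven from Python rather than the CLI. All paths are placeholders, and `experiment_config` / `tokenizer_config` are assumed to have been prepared the same way `main()` prepares them (not shown here):

```python
from pathlib import Path

# Placeholder locations; in main() these come from the CLI arguments.
unsharded_dir = Path("/tmp/olmo-ckpt-unsharded")
hf_output_dir = Path("/tmp/hf-output")

# experiment_config (dict) and tokenizer_config are assumed to exist already,
# loaded exactly as main() loads them.
convert_to_hf_checkpoint(
    olmo_checkpoint_path=unsharded_dir / "model.pt",
    output_path=hf_output_dir,
    olmo_core_config=experiment_config,
    max_sequence_length=4096,
    tokenizer=tokenizer_config,
)

# Re-runs the numerical check: HF logits vs. the rebuilt OLMo-core model's logits.
validate_conversion(hf_output_dir, unsharded_dir, experiment_config)
```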
Reviewer: This is hardcoded to deepseek, and the SAVE_OVERWRITE was probably a leftover statement? It should default to the safe setting, so people don't overwrite their stuff by accident.
Author: Sure, will fix.
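On that last point, a minimal sketch of what a safer default could look like in the converter script. The environment-variable override is an assumption for illustration, not something proposed in the PR:

```python
import os

# Keep the hub ID as the default and allow a local path to be supplied externally.
HF_MODEL = os.environ.get("HF_MODEL", "allenai/OLMo-2-1124-7B")
SAVE_PATH = os.environ.get("SAVE_PATH", f"/tmp/checkpoints/{HF_MODEL}")

# Default to False so an existing checkpoint is never clobbered by accident;
# overwriting has to be opted into explicitly.
SAVE_OVERWRITE = os.environ.get("SAVE_OVERWRITE", "0") == "1"
```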