Make the converter additionally support the DeepSeek-Coder dense HF model #167

Open
wants to merge 17 commits into base: main
Changes from 11 commits
@@ -6,38 +6,50 @@
HuggingFace.
"""

import json
import logging
import os

import torch
from transformers import AutoModelForCausalLM

from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict
from olmo_core.io import clear_directory, dir_is_empty
-from olmo_core.nn.rope import RoPEScalingConfig
+from olmo_core.nn.rope import RoPELlamaScalingConfig, RoPELinearScalingConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.utils import get_default_device, prepare_cli_environment

log = logging.getLogger(__name__)

HF_MODEL = "allenai/OLMo-2-1124-7B"
# HF_MODEL = "allenai/OLMo-2-1124-7B-Instruct"
# HF_MODEL = "allenai/OLMo-2-1124-13B-Instruct"
# HF_MODEL = "meta-llama/Llama-3.2-1B"
# HF_MODEL = "meta-llama/Llama-3.2-8B"
HF_MODEL = f"{os.environ['SHARE_RES_DIR']}/models/deepseek/deepseek-coder-1.3b-base"
# HF_MODEL = "/home/zliu/shared_resources/models/llama3/hf/Llama-3.2-1B"
# HF_MODEL = ""

SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
SAVE_OVERWRITE = False
SAVE_PATH = f"{os.environ['SHARE_RES_DIR']}/models/deepseek/olmo/deepseek-coder-1.3b-base"
# SAVE_PATH = "/home/zliu/shared_resources/models/llama3/olmo/Llama-3.2-1B"
SAVE_OVERWRITE = True

-TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)
+# TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)
+TOKENIZER_CONFIG = TokenizerConfig.from_hf("deepseek-ai/deepseek-coder-1.3b-base")
+# TOKENIZER_CONFIG = TokenizerConfig.from_hf("meta-llama/Llama-3.2-1B")
MODEL_CONFIG: TransformerConfig
if HF_MODEL == "meta-llama/Llama-3.2-1B":
if "Llama-3.2-1B" in HF_MODEL:
Member:

Do we also need this for bigger Llama models?

Author:

I changed it so that it works with a local cache directory. We are also running into an additional odd bug related to conversion and will need some help from you; we will open a separate issue about it. Would you prefer that we fold the fix for that issue into this PR, or keep it separate from this one?
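For reference, the hard-coded `os.environ['SHARE_RES_DIR']` lookup in the diff raises a KeyError when that variable is unset. A minimal sketch of a fallback to the Hub model ID in that case; the fallback logic below is illustrative only and not part of this PR:

```python
import os

# Illustrative fallback (not in this PR): prefer a local shared-resources copy
# when SHARE_RES_DIR is set, otherwise pull the model from the Hugging Face Hub.
_share_dir = os.environ.get("SHARE_RES_DIR")
if _share_dir:
    HF_MODEL = f"{_share_dir}/models/deepseek/deepseek-coder-1.3b-base"
else:
    HF_MODEL = "deepseek-ai/deepseek-coder-1.3b-base"
```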

Member:

Many small PRs are always better, if possible.

MODEL_CONFIG = TransformerConfig.llama3_1B(
TOKENIZER_CONFIG.vocab_size,
fused_ops=False,
use_flash=False,
-rope_scaling=RoPEScalingConfig(),
+rope_scaling=RoPELlamaScalingConfig(),
)

elif "deepseek-coder-1.3b-base" in HF_MODEL:
Member:

Does this have to happen for all DeepSeek models?

Author:

Yes, we defined a class for this DeepSeek version following the OLMo-core setup. For other versions, users would need to define their own class in the same way. Also, I can change this to `elif HF_MODEL.startswith("deepseek-coder-1.3b-base"):` to follow the same format the repo uses.

+MODEL_CONFIG = TransformerConfig.deepseek_1B(
+TOKENIZER_CONFIG.vocab_size,
+fused_ops=False,
+use_flash=False,
+rope_scaling=RoPELinearScalingConfig(factor=4.0),
+)
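Per the comment above, supporting another DeepSeek size would mean extending the same if/elif chain with a config classmethod defined for that size. A minimal sketch of such a branch; `deepseek_6p7B` is hypothetical and would have to be defined by the user, as it exists neither in olmo-core nor in this PR:

```python
elif "deepseek-coder-6.7b-base" in HF_MODEL:
    # Hypothetical: "deepseek_6p7B" would be a user-defined TransformerConfig
    # classmethod, written the same way as deepseek_1B in this PR.
    MODEL_CONFIG = TransformerConfig.deepseek_6p7B(
        TOKENIZER_CONFIG.vocab_size,
        fused_ops=False,
        use_flash=False,
        rope_scaling=RoPELinearScalingConfig(factor=4.0),
    )
```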

elif HF_MODEL.startswith("allenai/OLMo-2-1124-7B"):
MODEL_CONFIG = TransformerConfig.olmo2_7B(
TOKENIZER_CONFIG.vocab_size,
@@ -103,7 +115,7 @@ def convert_checkpoint() -> AutoModelForCausalLM:
)

# Layer norms.
if "Llama" in HF_MODEL:
if "Llama" or "deepseek" in HF_MODEL:
new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
f"model.layers.{block}.post_attention_layernorm.weight"
)
@@ -127,7 +139,10 @@ def convert_checkpoint() -> AutoModelForCausalLM:
assert len(state_dict) == 0

log.info(f"Saving converted model checkpoint '{SAVE_PATH}'...")
-save_state_dict(SAVE_PATH, {"model": new_state_dict})
+save_state_dict(os.path.join(SAVE_PATH, "model_and_optim"), {"model": new_state_dict})
+
+with open(os.path.join(SAVE_PATH, "config.json"), "w") as f:
+json.dump({"model": MODEL_CONFIG.as_dict()}, f)

return hf_model

@@ -147,7 +162,7 @@ def validate_conversion(hf_model):
model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()

log.info("Loading converted checkpoint for validation...")
-load_model_and_optim_state(SAVE_PATH, model)
+load_model_and_optim_state(os.path.join(SAVE_PATH, "model_and_optim"), model)

with torch.no_grad():
logits = model(input_ids=input_ids)
@@ -159,5 +174,7 @@ def validate_conversion(hf_model):

if __name__ == "__main__":
prepare_cli_environment()

+config = MODEL_CONFIG.as_dict()
hf_model = convert_checkpoint()
validate_conversion(hf_model)
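Since convert_checkpoint() now writes a config.json next to the model_and_optim directory, the converted checkpoint is self-describing. A rough sketch of how a downstream script might consume that layout, assuming TransformerConfig.from_dict round-trips the dict produced by as_dict (an assumption, not shown in this diff):

```python
import json
import os

from olmo_core.distributed.checkpoint import load_model_and_optim_state
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.utils import get_default_device

# Same layout the conversion script writes to; adjust to your environment.
SAVE_PATH = f"{os.environ['SHARE_RES_DIR']}/models/deepseek/olmo/deepseek-coder-1.3b-base"

# Rebuild the model config from config.json (assumes TransformerConfig.from_dict
# accepts the dict written via as_dict), then load the converted weights.
with open(os.path.join(SAVE_PATH, "config.json")) as f:
    model_config = TransformerConfig.from_dict(json.load(f)["model"])

model = model_config.build(device=get_default_device(), max_seq_len=131072).eval()
load_model_and_optim_state(os.path.join(SAVE_PATH, "model_and_optim"), model)
```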