propagate ckpt precision changes
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Sep 20, 2024
1 parent f357b83 commit 9e0e89d
Showing 10 changed files with 62 additions and 8 deletions.
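
Every importer touched by this commit applies the same three-part pattern: load the Hugging Face checkpoint in its stored precision via `torch_dtype='auto'` (rather than transformers' float32 default), snapshot the target module's parameter dtypes before conversion, and assert afterwards that `convert_state` did not change them. A minimal sketch of the pattern, with the NeMo importer plumbing elided behind a hypothetical `importer` object and `extract_dtypes` assumed to behave as its call sites suggest:

```python
# Minimal sketch of the dtype-preservation pattern this commit adds to each
# importer. The importer plumbing (init, nemo_setup, convert_state, nemo_save)
# is elided; `importer` is a stand-in, not a real NeMo class.
from transformers import AutoModelForCausalLM

from nemo.lightning.pytorch.utils import extract_dtypes


def import_checkpoint(importer, hf_path: str, output_path: str) -> str:
    # 'auto' keeps the precision recorded in the checkpoint's config
    # (e.g. bfloat16) instead of upcasting everything to float32.
    source = AutoModelForCausalLM.from_pretrained(hf_path, torch_dtype='auto')

    target = importer.init()
    trainer = importer.nemo_setup(target)

    # Snapshot dtypes before conversion ...
    target_dtypes = extract_dtypes(target.module.named_parameters())
    importer.convert_state(source, target)
    # ... and fail loudly if conversion silently changed any parameter dtype.
    assert target_dtypes == extract_dtypes(target.module.named_parameters())

    importer.nemo_save(output_path, trainer)
    return output_path
```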
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/baichuan.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import AutoConfig, AutoModelForCausalLM
@@ -76,10 +77,12 @@ def init(self) -> Baichuan2Model:
     def apply(self, output_path: Path) -> Path:
         from transformers import AutoModelForCausalLM

-        source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True)
+        source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Baichuan model to Nemo, model saved to {output_path}")
@@ -131,6 +134,9 @@ def make_vocab_size_divisible_by(vocab_size):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             share_embeddings_and_output_weights=False,
             position_embedding_type="rope" if source.num_hidden_layers == 32 else "alibi",
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
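
The two helpers imported above live in nemo/lightning/pytorch/utils.py, which is not part of this diff. The following sketches are assumptions inferred purely from the call sites: `extract_dtypes` receives `named_parameters()` and its results are compared with `==`, while `dtype_from_hf` receives an HF config object and is compared against torch dtypes.

```python
import torch


def extract_dtypes(named_parameters):
    # Hypothetical sketch: map each parameter name to its dtype so that two
    # snapshots taken before and after conversion can be compared with ==.
    return {name: param.dtype for name, param in named_parameters}


def dtype_from_hf(config):
    # Hypothetical sketch: read the precision recorded in a Hugging Face
    # config, where torch_dtype may be a torch.dtype or a string like "bfloat16".
    torch_dtype = getattr(config, 'torch_dtype', None)
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype.removeprefix('torch.'))
    raise ValueError(f"Could not determine dtype from {torch_dtype!r}")
```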
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/chatglm.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import AutoConfig, AutoModelForCausalLM
@@ -82,10 +83,12 @@ def init(self) -> ChatGLMModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import AutoModelForCausalLM

-        source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True)
+        source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted ChatGLM model to Nemo, model saved to {output_path}")
@@ -128,6 +131,9 @@ def config(self) -> ChatGLMConfig:
             seq_length=source.seq_length,
             num_query_groups=source.multi_query_group_num,
             make_vocab_size_divisible_by=source.padded_vocab_size,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/gemma.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import GemmaForCausalLM
@@ -105,10 +106,12 @@ def init(self) -> GemmaModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import GemmaForCausalLM

-        source = GemmaForCausalLM.from_pretrained(str(self))
+        source = GemmaForCausalLM.from_pretrained(str(self), torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Gemma model to Nemo, model saved to {output_path}")
@@ -160,6 +163,9 @@ def make_vocab_size_divisible_by(vocab_size):
             gated_linear_unit=True,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/mistral.py
@@ -26,6 +26,7 @@
 from nemo.collections.llm.utils import Config
 from nemo.lightning import io, teardown
 from nemo.lightning.pytorch.optim import OptimizerModule
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import MistralConfig, MistralForCausalLM
@@ -112,10 +113,12 @@ def init(self) -> MistralModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import MistralForCausalLM

-        source = MistralForCausalLM.from_pretrained(str(self))
+        source = MistralForCausalLM.from_pretrained(str(self), torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Mistral 7B model to Nemo, model saved to {output_path}")
@@ -175,6 +178,9 @@ def make_vocab_size_divisible_by(mistral_vocab_size):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             window_size=window_size,
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.lightning import io, teardown
 from nemo.lightning.pytorch.optim import OptimizerModule
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

[Code scanning / CodeQL notice on the added import: Import of 'dtype_from_hf' is not used.]

 if TYPE_CHECKING:
     from transformers import MixtralForCausalLM
@@ -139,7 +140,9 @@ def apply(self, output_path: Path) -> Path:
         source = MixtralForCausalLM.from_pretrained(str(self), torch_dtype='auto', use_safetensors=True)
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         teardown(trainer, target)
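
The CodeQL notice is consistent with the hunks above: mixtral.py already loaded the checkpoint with `torch_dtype='auto'` and gains only the dtype snapshot and assertion, so `dtype_from_hf` is never called in this file. A follow-up (not part of this commit) could simply trim the import:

```python
from nemo.lightning.pytorch.utils import extract_dtypes
```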
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/nemotron.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import NemotronConfig as HFNemotronConfig
@@ -139,10 +140,12 @@ def init(self) -> NemotronModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import NemotronForCausalLM

-        source = NemotronForCausalLM.from_pretrained(str(self))
+        source = NemotronForCausalLM.from_pretrained(str(self), torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Nemotron model to Nemo, model saved to {output_path}")
@@ -200,6 +203,9 @@ def make_vocab_size_divisible_by(vocab_size):
             rotary_percent=source.partial_rotary_factor,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/qwen2.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import AutoModelForCausalLM
@@ -109,10 +110,12 @@ def init(self) -> Qwen2Model:
     def apply(self, output_path: Path) -> Path:
         from transformers import AutoModelForCausalLM

-        source = AutoModelForCausalLM.from_pretrained(str(self), trust_remote_code=True)
+        source = AutoModelForCausalLM.from_pretrained(str(self), torch_dtype='auto', trust_remote_code=True)
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Qwen model to Nemo, model saved to {output_path}")
@@ -161,6 +164,9 @@ def config(self) -> Qwen2Config:
             make_vocab_size_divisible_by=128,
             rotary_base=source.rope_theta,
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/model/ssm.py
@@ -34,6 +34,7 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step
 from nemo.lightning import get_vocab_size, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf


 def ssm_forward_step(model, batch) -> torch.Tensor:
@@ -130,7 +131,9 @@ def state_dict(self):
         source = ModelState(source)
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         logging.info(f"Converted SSM model to Nemo, model saved to {output_path}")
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/starcoder.py
@@ -22,6 +22,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import GPTBigCodeConfig as HFStarcoderConfig
@@ -81,10 +82,12 @@ def init(self) -> StarcoderModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import GPTBigCodeForCausalLM

-        source = GPTBigCodeForCausalLM.from_pretrained(str(self))
+        source = GPTBigCodeForCausalLM.from_pretrained(str(self), torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Starcoder model to Nemo, model saved to {output_path}")
@@ -146,6 +149,9 @@ def make_vocab_size_divisible_by(vocab_size):
             num_query_groups=1,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
8 changes: 7 additions & 1 deletion nemo/collections/llm/gpt/model/starcoder2.py
@@ -23,6 +23,7 @@
 from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
+from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

 if TYPE_CHECKING:
     from transformers import Starcoder2Config as HFStarcoder2Config
@@ -108,10 +109,12 @@ def init(self) -> Starcoder2Model:
     def apply(self, output_path: Path) -> Path:
         from transformers import Starcoder2ForCausalLM

-        source = Starcoder2ForCausalLM.from_pretrained(str(self))
+        source = Starcoder2ForCausalLM.from_pretrained(str(self), torch_dtype='auto')
         target = self.init()
         trainer = self.nemo_setup(target)
+        target_dtypes = extract_dtypes(target.module.named_parameters())
         self.convert_state(source, target)
+        assert target_dtypes == extract_dtypes(target.module.named_parameters())
         self.nemo_save(output_path, trainer)

         print(f"Converted Starcoder2 model to Nemo, model saved to {output_path}")
@@ -171,6 +174,9 @@ def make_vocab_size_divisible_by(vocab_size):
             rotary_base=source.rope_theta,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
             share_embeddings_and_output_weights=False,
+            fp16=(dtype_from_hf(source) == torch.float16),
+            bf16=(dtype_from_hf(source) == torch.bfloat16),
+            params_dtype=dtype_from_hf(source),
         )

         return output
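
In every config() hunk above, the same three fields are derived from the checkpoint dtype. As a worked illustration of the mapping, assuming a checkpoint whose HF config records bfloat16 (`dtype_from_hf(source)` is stood in for by a plain variable):

```python
import torch

# Worked illustration of the fp16/bf16/params_dtype mapping used in each
# config() property above, for a hypothetical bfloat16 checkpoint.
checkpoint_dtype = torch.bfloat16  # stands in for dtype_from_hf(source)

config_kwargs = dict(
    fp16=(checkpoint_dtype == torch.float16),   # False for a bf16 checkpoint
    bf16=(checkpoint_dtype == torch.bfloat16),  # True for a bf16 checkpoint
    params_dtype=checkpoint_dtype,              # parameters created in bf16
)
assert config_kwargs == {"fp16": False, "bf16": True, "params_dtype": torch.bfloat16}
```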
