From 54644799d6c7ad0bafd576ec870bfe054bfbd673 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 20 Sep 2024 18:54:14 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/llm/gpt/model/baichuan.py | 2 +- nemo/collections/llm/gpt/model/chatglm.py | 2 +- nemo/collections/llm/gpt/model/gemma.py | 2 +- nemo/collections/llm/gpt/model/llama.py | 2 +- nemo/collections/llm/gpt/model/mistral.py | 2 +- nemo/collections/llm/gpt/model/mixtral.py | 2 +- nemo/collections/llm/gpt/model/nemotron.py | 2 +- nemo/collections/llm/gpt/model/qwen2.py | 2 +- nemo/collections/llm/gpt/model/ssm.py | 3 ++- nemo/collections/llm/gpt/model/starcoder.py | 2 +- nemo/collections/llm/gpt/model/starcoder2.py | 2 +- nemo/lightning/_strategy_lib.py | 2 ++ nemo/lightning/pytorch/utils.py | 3 +++ 13 files changed, 17 insertions(+), 11 deletions(-) diff --git a/nemo/collections/llm/gpt/model/baichuan.py b/nemo/collections/llm/gpt/model/baichuan.py index 5bddb1b324a6..202d5169c8f0 100644 --- a/nemo/collections/llm/gpt/model/baichuan.py +++ b/nemo/collections/llm/gpt/model/baichuan.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import AutoConfig, AutoModelForCausalLM diff --git a/nemo/collections/llm/gpt/model/chatglm.py b/nemo/collections/llm/gpt/model/chatglm.py index 12f5ea23c7e9..e0b493fef27c 100644 --- a/nemo/collections/llm/gpt/model/chatglm.py +++ b/nemo/collections/llm/gpt/model/chatglm.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import AutoConfig, AutoModelForCausalLM diff --git a/nemo/collections/llm/gpt/model/gemma.py b/nemo/collections/llm/gpt/model/gemma.py index ca226a318d74..0b669bad503b 100644 --- a/nemo/collections/llm/gpt/model/gemma.py +++ b/nemo/collections/llm/gpt/model/gemma.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import GemmaForCausalLM diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 154b58c69e94..5d52a3aaa3f5 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -24,8 +24,8 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes from nemo.utils import logging -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf if TYPE_CHECKING: from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index 91787de816e1..380f5eb59e97 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -26,7 +26,7 @@ from nemo.collections.llm.utils import Config from nemo.lightning import io, teardown from nemo.lightning.pytorch.optim import OptimizerModule -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import MistralConfig, MistralForCausalLM diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index 41d409f708cc..e8ae4734868a 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.lightning import io, teardown from nemo.lightning.pytorch.optim import OptimizerModule -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import MixtralForCausalLM diff --git a/nemo/collections/llm/gpt/model/nemotron.py b/nemo/collections/llm/gpt/model/nemotron.py index 8851d7ae9e15..58b69863bc09 100644 --- a/nemo/collections/llm/gpt/model/nemotron.py +++ b/nemo/collections/llm/gpt/model/nemotron.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import NemotronConfig as HFNemotronConfig diff --git a/nemo/collections/llm/gpt/model/qwen2.py b/nemo/collections/llm/gpt/model/qwen2.py index 44e130ca2338..9b36481e5aab 100644 --- a/nemo/collections/llm/gpt/model/qwen2.py +++ b/nemo/collections/llm/gpt/model/qwen2.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import AutoModelForCausalLM diff --git a/nemo/collections/llm/gpt/model/ssm.py b/nemo/collections/llm/gpt/model/ssm.py index 8236eb1a13b1..c2e67f24b71c 100644 --- a/nemo/collections/llm/gpt/model/ssm.py +++ b/nemo/collections/llm/gpt/model/ssm.py @@ -32,9 +32,10 @@ HAVE_MEGATRON_CORE_OR_TE = False from megatron.core.transformer.transformer_config import TransformerConfig + from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step from nemo.lightning import get_vocab_size, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes def ssm_forward_step(model, batch) -> torch.Tensor: diff --git a/nemo/collections/llm/gpt/model/starcoder.py b/nemo/collections/llm/gpt/model/starcoder.py index d22f3f8db7b4..f1af8c43654a 100644 --- a/nemo/collections/llm/gpt/model/starcoder.py +++ b/nemo/collections/llm/gpt/model/starcoder.py @@ -22,7 +22,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import GPTBigCodeConfig as HFStarcoderConfig diff --git a/nemo/collections/llm/gpt/model/starcoder2.py b/nemo/collections/llm/gpt/model/starcoder2.py index 641b238b50e3..ef6ebd259cf1 100644 --- a/nemo/collections/llm/gpt/model/starcoder2.py +++ b/nemo/collections/llm/gpt/model/starcoder2.py @@ -23,7 +23,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import OptimizerModule, io, teardown -from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf +from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes if TYPE_CHECKING: from transformers import Starcoder2Config as HFStarcoder2Config diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index 13cc93e9a271..ec061e4bf143 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -164,11 +164,13 @@ def megatron_lazy_init_context(config) -> Generator[None, None, None]: from megatron.core.extensions import transformer_engine as _te original = _te._get_extra_te_kwargs # noqa: SLF001 + def _get_extra_te_kwargs_meta(c): """Forces device to meta""" kwargs = original(c) kwargs['device'] = 'meta' return kwargs + _te._get_extra_te_kwargs = _get_extra_te_kwargs_meta # noqa: SLF001 _orig_perform_initialization = config.perform_initialization diff --git a/nemo/lightning/pytorch/utils.py b/nemo/lightning/pytorch/utils.py index 0dc8395249f4..37efa1b68c5f 100644 --- a/nemo/lightning/pytorch/utils.py +++ b/nemo/lightning/pytorch/utils.py @@ -1,5 +1,6 @@ import torch + def extract_dtypes(ckpt): """ Extracts dtype from the input iterator @@ -13,6 +14,7 @@ def extract_dtypes(ckpt): continue return dtypes + def dtype_from_str(dtype): """ Convert a str precision to equivalent torch dtype. @@ -25,6 +27,7 @@ def dtype_from_str(dtype): else: return torch.float32 + def dtype_from_hf(config): """ Extracts torch dtype from a HF config