Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa committed Sep 20, 2024
1 parent 1b549c5 commit 5464479
Show file tree
Hide file tree
Showing 13 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import AutoConfig, AutoModelForCausalLM
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import AutoConfig, AutoModelForCausalLM
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import GemmaForCausalLM
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes
from nemo.utils import logging
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf

if TYPE_CHECKING:
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from nemo.collections.llm.utils import Config
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.optim import OptimizerModule
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import MistralConfig, MistralForCausalLM
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.optim import OptimizerModule
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'dtype_from_hf' is not used.

if TYPE_CHECKING:
from transformers import MixtralForCausalLM
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import NemotronConfig as HFNemotronConfig
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import AutoModelForCausalLM
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/llm/gpt/model/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@
HAVE_MEGATRON_CORE_OR_TE = False

from megatron.core.transformer.transformer_config import TransformerConfig

from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step
from nemo.lightning import get_vocab_size, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'dtype_from_hf' is not used.


def ssm_forward_step(model, batch) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/starcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import GPTBigCodeConfig as HFStarcoderConfig
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.lightning import OptimizerModule, io, teardown
from nemo.lightning.pytorch.utils import extract_dtypes, dtype_from_hf
from nemo.lightning.pytorch.utils import dtype_from_hf, extract_dtypes

if TYPE_CHECKING:
from transformers import Starcoder2Config as HFStarcoder2Config
Expand Down
2 changes: 2 additions & 0 deletions nemo/lightning/_strategy_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,13 @@ def megatron_lazy_init_context(config) -> Generator[None, None, None]:
from megatron.core.extensions import transformer_engine as _te

original = _te._get_extra_te_kwargs # noqa: SLF001

def _get_extra_te_kwargs_meta(c):
"""Forces device to meta"""
kwargs = original(c)
kwargs['device'] = 'meta'
return kwargs

_te._get_extra_te_kwargs = _get_extra_te_kwargs_meta # noqa: SLF001

_orig_perform_initialization = config.perform_initialization
Expand Down
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch


def extract_dtypes(ckpt):
"""
Extracts dtype from the input iterator
Expand All @@ -13,6 +14,7 @@ def extract_dtypes(ckpt):
continue
return dtypes


def dtype_from_str(dtype):
"""
Convert a str precision to equivalent torch dtype.
Expand All @@ -25,6 +27,7 @@ def dtype_from_str(dtype):
else:
return torch.float32


def dtype_from_hf(config):
"""
Extracts torch dtype from a HF config
Expand Down

0 comments on commit 5464479

Please sign in to comment.