Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
409 changes: 295 additions & 114 deletions config.py

Large diffs are not rendered by default.

532 changes: 345 additions & 187 deletions inference.py

Large diffs are not rendered by default.

477 changes: 325 additions & 152 deletions main.py

Large diffs are not rendered by default.

71 changes: 43 additions & 28 deletions nemo_utils/megatron_gpt_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@
from typing import List

import torch
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import \
MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import \
fake_initialize_model_parallel
from nemo.collections.nlp.modules.common.text_generation_server import \
MegatronServer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import (
MegatronGPTModel,
)
from nemo.collections.nlp.modules.common.megatron.megatron_init import (
fake_initialize_model_parallel,
)
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.modules.common.transformer.text_generation import (
LengthParam, SamplingParam)
from nemo.collections.nlp.parts.nlp_overrides import (CustomProgressBar,
NLPDDPStrategy,
NLPSaveRestoreConnector)
LengthParam,
SamplingParam,
)
from nemo.collections.nlp.parts.nlp_overrides import (
CustomProgressBar,
NLPDDPStrategy,
NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank
Expand All @@ -45,10 +50,9 @@
HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_MEGATRON_CORE = False

__all__ = ['init_model', 'pred_by_generation']
__all__ = ["init_model", "pred_by_generation"]

"""
This is the script to run GPT text generation.
Expand All @@ -63,7 +67,9 @@ def __init__(self, sentences):
super().__init__()
self.sentences = sentences

def __len__(self,):
def __len__(
self,
):
return len(self.sentences)

def __getitem__(self, idx):
Expand All @@ -84,13 +90,13 @@ def nemo_init_model(cfg: OmegaConf):
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=[CustomProgressBar()],
# callbacks=[CustomProgressBar()],
)

if (
cfg.tensor_model_parallel_size < 0
or cfg.pipeline_model_parallel_size < 0
or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
or cfg.get("pipeline_model_parallel_split_rank", -1) < 0
):
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.gpt_model_file):
Expand All @@ -103,9 +109,15 @@ def nemo_init_model(cfg: OmegaConf):
)

with open_dict(cfg):
cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1)
cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1)
cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0)
cfg.tensor_model_parallel_size = model_config.get(
"tensor_model_parallel_size", 1
)
cfg.pipeline_model_parallel_size = model_config.get(
"pipeline_model_parallel_size", 1
)
cfg.pipeline_model_parallel_split_rank = model_config.get(
"pipeline_model_parallel_split_rank", 0
)

assert (
cfg.trainer.devices * cfg.trainer.num_nodes
Expand All @@ -128,20 +140,24 @@ def nemo_init_model(cfg: OmegaConf):
pretrained_cfg.activations_checkpoint_granularity = None
pretrained_cfg.activations_checkpoint_method = None
pretrained_cfg.precision = trainer.precision
if pretrained_cfg.get('mcore_gpt', False):
if pretrained_cfg.get("mcore_gpt", False):
# with dist checkpointing we can use the model parallel config specified by the user
pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size
pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
pretrained_cfg.pipeline_model_parallel_size = (
cfg.pipeline_model_parallel_size
)
if trainer.precision == "16":
pretrained_cfg.megatron_amp_O2 = False
elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False):
elif trainer.precision in ["bf16", "bf16-mixed"] and cfg.get(
"megatron_amp_O2", False
):
pretrained_cfg.megatron_amp_O2 = True
model = MegatronGPTModel.restore_from(
restore_path=cfg.gpt_model_file,
trainer=trainer,
override_config_path=pretrained_cfg,
save_restore_connector=save_restore_connector,
map_location=f'cuda:{trainer.local_rank}', # map_location is needed for converted models
map_location=f"cuda:{trainer.local_rank}", # map_location is needed for converted models
)

model.freeze()
Expand All @@ -153,16 +169,15 @@ def nemo_init_model(cfg: OmegaConf):
pass
return model, trainer

def nemo_generate(model, prompts: List[str], batch_size: int, trainer, cfg: OmegaConf) -> List[str]:

def nemo_generate(
    model, prompts: List[str], batch_size: int, trainer, cfg: OmegaConf
) -> List[str]:
    """Run batched text generation over *prompts* with a NeMo Megatron-GPT model.

    :param model: Restored Megatron-GPT model with ``set_inference_config`` support.
    :param prompts: Input prompt strings to generate completions for.
    :param batch_size: DataLoader batch size; also recorded in the inference config.
    :param trainer: PyTorch Lightning trainer used to drive ``predict``.
    :param cfg: OmegaConf config; ``cfg.inference`` holds the sampling parameters.
    :return: Output of ``trainer.predict`` — annotated ``List[str]``; presumably
        the generated texts, but the exact element type depends on the model's
        ``predict_step`` — TODO confirm against the NeMo model implementation.
    """
    # Convert the OmegaConf node to a plain mutable dict before editing it.
    cfg_infer = OmegaConf.to_container(cfg.inference)

    cfg_infer["batch_size"] = batch_size
    ds = RequestDataSet(prompts)
    request_dl = DataLoader(dataset=ds, batch_size=batch_size)
    model.set_inference_config(cfg_infer)
    # Lightning iterates the DataLoader and calls the model's predict_step.
    response = trainer.predict(model, request_dl)
    return response




1 change: 1 addition & 0 deletions nemo_utils/nemo_cfgs/megatron_gpt_inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ trainer:
logger: False # logger provided by exp_manager
precision: bf16 # 16, 32, or bf16
use_distributed_sampler: False
enable_progress_bar: False

tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
Expand Down
33 changes: 33 additions & 0 deletions nemo_utils/run_time_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import time

from nemo_utils.state_manager import SingletonMeta


class RunTimeTracker(metaclass=SingletonMeta):
    """Singleton wall-clock tracker used to enforce a global run-time budget.

    The first instantiation must supply ``time_limit_sec``; subsequent calls
    return the same instance (via ``SingletonMeta``) with arguments ignored.
    """

    def __init__(self, time_limit_sec=None):
        # Guard so the tracker is configured exactly once even if __init__
        # is somehow re-entered (e.g. direct calls or unpickling paths).
        if not hasattr(self, "initialized"):
            if time_limit_sec is None:
                raise ValueError(
                    "time_limit_sec must be provided for the first initialization."
                )
            self.start_time = time.time()
            self.time_limit = time_limit_sec
            self.initialized = True

    def elapsed_time(self):
        """Return seconds elapsed since the tracker was first created."""
        return time.time() - self.start_time

    def __getstate__(self):
        # Pickle as a plain attribute dict so the tracker survives
        # checkpoint/restore round trips.
        return self.__dict__.copy()

    def __setstate__(self, state):
        self.__dict__.update(state)

    def has_sufficient_time(self, buffer_time=30):
        """
        Check if there is sufficient time left before the time limit.
        :param buffer_time: Time in seconds to be reserved for saving state
            (default 30 seconds).  NOTE(review): the original docstring said
            "10 minutes", which contradicted the actual default of 30.
        :return: True if there is enough time left, False otherwise
        """
        return self.elapsed_time() < self.time_limit - buffer_time
75 changes: 75 additions & 0 deletions nemo_utils/state_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
import pickle


def create_logger(log_path):
    """Create (or reconfigure) the module logger that appends INFO records to *log_path*.

    Root-logger handlers are cleared so records are not echoed to stale
    destinations configured elsewhere.

    :param log_path: Path of the log file to write to.
    :return: Configured ``logging.Logger`` instance.
    """
    logging.getLogger().handlers = []

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # BUG FIX: reset this logger's own handlers as well — previously each
    # call attached an additional FileHandler, so calling create_logger
    # twice duplicated every log record.
    logger.handlers = []

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    return logger


class SingletonMeta(type):
    """Metaclass turning each class that uses it into a process-wide singleton.

    The first instantiation is cached; every later call returns the cached
    object and its constructor arguments are ignored.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance


class StateManager(metaclass=SingletonMeta):
    """Singleton that persists a dict of run state to disk via pickle.

    First instantiation requires ``file_path`` (pickle target) and ``log_dir``
    (log file path passed to ``create_logger``); later calls return the same
    instance.
    """

    def __init__(self, file_path=None, log_dir=None):
        # BUG FIX: the original branched on self.initialized *after* setting
        # it to True, so the very first call logged "retrieved" and the
        # "initialized" message was unreachable.  Record whether this call
        # performed the setup before the flag is set.
        first_init = not hasattr(self, "initialized")
        if first_init:
            if file_path is None or log_dir is None:
                raise ValueError(
                    "file_path and log_dir must be provided for the first initialization."
                )
            self.file_path = file_path
            self.state = {}
            self.logger = create_logger(log_dir)
            self.initialized = True
        if first_init:
            self.logger.info(
                f"StateManager initialized with file_path={self.file_path} and log_dir={log_dir}"
            )
        else:
            self.logger.info(
                f"StateManager retrieved from file_path={self.file_path} and log_dir={log_dir}"
            )

    def save_state(self, state=None):
        """Merge *state* (if given) into the held state and pickle it to disk.

        :param state: Optional dict merged into the current state first.
        """
        if state is not None:
            self.state.update(state)
        with open(self.file_path, "wb") as f:
            pickle.dump(self.state, f)
        self.logger.info(f"State saved to {self.file_path}")

    def restore_state(self):
        """Load the pickled state from disk.

        NOTE(review): pickle.load is only safe on files this process wrote;
        never point file_path at untrusted data.

        :return: The restored state dict, or ``None`` if no state file exists.
        """
        try:
            with open(self.file_path, "rb") as f:
                self.state = pickle.load(f)
            return self.state
        except FileNotFoundError:
            return None

    def add_to_state(self, key, value):
        """
        Add a key-value pair to the state.
        """
        self.state[key] = value

    def update_state(self, state):
        """
        Add a dictionary of key-value pairs to the state.
        """
        self.state.update(state)
44 changes: 44 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import hashlib

from nemo_utils.state_manager import StateManager


def serialize_args(args):
    """Serialize the argparse Namespace to a hash string."""
    parts = [f"{name}_{val}" for name, val in vars(args).items()]
    # Hash the joined argument string so the resulting identifier has a
    # short, fixed length suitable for a file name.
    return hashlib.md5("_".join(parts).encode()).hexdigest()


def handle_timeout_error(state_generator_func):
    """
    A decorator that handles TimeoutError exceptions by saving the state and logging a message.
    state_generator_func (function): A function that generates the state.
    """
    # Local import keeps the module's top-level imports unchanged.
    import functools

    def decorator(func):
        # BUG FIX: without functools.wraps the wrapper discarded the wrapped
        # function's __name__/__doc__, breaking introspection and logging.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            state_manager = StateManager()
            try:
                return func(*args, **kwargs)
            except TimeoutError:
                if state_manager:
                    print("Saving state...")
                    state = state_generator_func(*args, **kwargs)
                    state_manager.save_state(state)
                    state_manager.logger.info("Exiting due to time limit.")
                # Re-raise so the caller still observes the timeout.
                raise

        return wrapper

    return decorator


def generate_predict_step(obj, batch_id, gts, preds, model, all_data=None):
    """Package one prediction step into a state dict for checkpointing."""
    state = {
        "batch_id": batch_id,
        "gts": gts,
        "preds": preds,
        "all_data": all_data,
    }
    return state


def generate_semantic_attack_state(
    obj, language, inference_model, results_dir, dataset
):
    """Build the resumable state for a semantic-attack run.

    Only ``language`` is recorded; the remaining parameters mirror the
    caller's signature but are not persisted.
    """
    return dict(language=language)