diff --git a/chatlearn/models/megatron_module.py b/chatlearn/models/megatron_module.py
index eae3208e..6b8bf4c9 100644
--- a/chatlearn/models/megatron_module.py
+++ b/chatlearn/models/megatron_module.py
@@ -60,6 +60,8 @@ def __init__(self, *args, **kwargs):
 
             self._logger.info(f"{self.name} Overwrite global_batch_size with train_global_batch_size {self.module_args.train_global_batch_size}")
         if not self.model_args.get("tensorboard_dir") and self.runtime_args.output_dir is not None:
             self.model_args['tensorboard_dir'] = f"{self.runtime_args.output_dir}/tensorboard"
+        if not self.model_args.get("wandb_save_dir") and self.runtime_args.output_dir is not None:
+            self.model_args['wandb_save_dir'] = f"{self.runtime_args.output_dir}/wandb"
 
     def add_extra_args(self, parser):
diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py
index 50a74bfe..44105db3 100644
--- a/chatlearn/utils/megatron_import_helper.py
+++ b/chatlearn/utils/megatron_import_helper.py
@@ -75,9 +75,11 @@
 # megatron.global_vars.*
 try:
     from megatron.global_vars import get_tensorboard_writer
+    from megatron.global_vars import get_wandb_writer
     from megatron.global_vars import set_global_variables
 except ImportError:
     from megatron.training.global_vars import get_tensorboard_writer
+    from megatron.training.global_vars import get_wandb_writer
     from megatron.training.global_vars import set_global_variables
 
 # megatron.initialize.*
diff --git a/examples/megatron/models/old_policy_inference.py b/examples/megatron/models/old_policy_inference.py
index f3a1ae57..a398fafa 100644
--- a/examples/megatron/models/old_policy_inference.py
+++ b/examples/megatron/models/old_policy_inference.py
@@ -23,7 +23,7 @@
 from megatron.training import arguments
 from megatron.training import get_args, get_tokenizer
 from megatron.training import print_rank_0
-from megatron.training.global_vars import get_tensorboard_writer
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
 from megatron.inference.text_generation.communication import broadcast_float_list, \
     broadcast_int_list, broadcast_tensor
 from megatron.inference.text_generation.generation import generate_tokens_probs_and_return_on_first_stage
@@ -36,7 +36,7 @@
 from examples.megatron.data.prompt_dataset import PromptPipeline
 from .policy_model import PolicyModel as LegacyPolicyModel
 from .mcore_policy_model import MCorePolicyModel
-from .utils import tensorboard_scalar_dict, get_loss_mask, get_eos_id
+from .utils import tensorboard_scalar_dict, wandb_scalar_dict, get_loss_mask, get_eos_id
 
 
 class PolicyInference(MegatronModule):
@@ -411,6 +411,8 @@ def log_entropy(self, iteration_for_log):
 
         # log
         writer = get_tensorboard_writer()
+        wandb_writer = get_wandb_writer()
+
         # RL related stats: global
         if torch.distributed.is_initialized():
             if torch.distributed.get_rank() == (
@@ -427,6 +429,9 @@ def log_entropy(self, iteration_for_log):
                 }
                 tensorboard_scalar_dict(writer, prefix=f"policy_inference/replica_id{self.replica_id}",
                                         global_step=iteration_for_log, scalar_dict=stats_args)
+                if wandb_writer:
+                    wandb_scalar_dict(wandb_writer, prefix=f"policy_inference/replica_id{self.replica_id}",
+                                      global_step=iteration_for_log, scalar_dict=stats_args)
 
         else:
             # actual log
@@ -440,6 +445,9 @@ def log_entropy(self, iteration_for_log):
             }
             tensorboard_scalar_dict(writer, prefix=f"policy_inference/replica_id{self.replica_id}",
                                     global_step=iteration_for_log, scalar_dict=stats_args)
+            if wandb_writer:
+                wandb_scalar_dict(wandb_writer, prefix=f"policy_inference/replica_id{self.replica_id}",
+                                  global_step=iteration_for_log, scalar_dict=stats_args)
 
         get_args().entropy_sum = 0
         get_args().entropy_num = 0
diff --git a/examples/megatron/models/reward_inference.py b/examples/megatron/models/reward_inference.py
index 3d3385b9..d704cf40 100644
--- a/examples/megatron/models/reward_inference.py
+++ b/examples/megatron/models/reward_inference.py
@@ -27,7 +27,7 @@
 from megatron.training import get_model
 from megatron.training import get_tokenizer
 from megatron.training import print_rank_0
-from megatron.training.global_vars import get_tensorboard_writer
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
 from megatron.training.utils import get_ltor_masks_and_position_ids
 
 import chatlearn
@@ -37,7 +37,7 @@
 from examples.megatron.data.reward_dataset import preprocess
 from .reward_model import RewardModel as LegacyRewardModel
 from .mcore_reward_model import MCoreRewardModel
-from .utils import tensorboard_scalar_dict, get_eos_id
+from .utils import tensorboard_scalar_dict, wandb_scalar_dict, get_eos_id
 from .constants import RunningMoments, get_running_stats, reset_running_stats
 from .forward_step import forward_step_helper
 
@@ -468,6 +468,7 @@ def forward_step(self, data, iteration=None):
 
     def log_each_step(self, iteration):
        writer = get_tensorboard_writer()
+       wandb_writer = get_wandb_writer()
        stats_episode = get_running_stats(self.per_episode_metrics)
 
        stats_episode.update(self.stats)
@@ -479,6 +480,10 @@ def log_each_step(self, iteration):
        tensorboard_scalar_dict(writer, prefix=f"rewards_each/replica_id{self.replica_id}",
                                global_step=iteration,
                                scalar_dict=stats_episode)
+       if wandb_writer:
+           wandb_scalar_dict(wandb_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
+                             global_step=iteration,
+                             scalar_dict=stats_episode)
 
        # reset runnings
        reset_running_stats(self.per_episode_metrics)
diff --git a/examples/megatron/models/reward_math.py b/examples/megatron/models/reward_math.py
index ee437800..f4d27e00 100644
--- a/examples/megatron/models/reward_math.py
+++ b/examples/megatron/models/reward_math.py
@@ -16,12 +16,14 @@
 from collections import defaultdict
 
 import numpy as np
 import torch
-from torch.utils.tensorboard import SummaryWriter
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
+
 from chatlearn import BaseModule
-from .utils import tensorboard_scalar_dict
+from .utils import tensorboard_scalar_dict, wandb_scalar_dict
 from .constants import RunningMoments, get_running_stats, reset_running_stats
 from .rm_sys.math_rule_rm import MathRuleRM
+
 
 class MathReward(BaseModule):
     """Math reward"""
@@ -30,8 +32,8 @@ def setup(self):
         self.stats = {}
         self.running = RunningMoments()
         self.per_episode_metrics = defaultdict(RunningMoments)
-        tensorboard_dir = f"{self.runtime_args.output_dir}/tensorboard"
-        self.tensorboard_writer = SummaryWriter(log_dir=tensorboard_dir)
+        self.writer = get_tensorboard_writer()
+        self.wandb_writer = get_wandb_writer()
 
     def forward_step(self, data, iteration=0):
         answers = data['answer']
@@ -103,7 +105,11 @@ def log_each_step(self):
         stats_episode["exp_scores/running_math_std"] = self.running.std
         print(f"score only/running_math_mean {self.running.mean}", flush=True)
-        tensorboard_scalar_dict(self.tensorboard_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
+        tensorboard_scalar_dict(self.writer, prefix=f"rewards_each/replica_id{self.replica_id}",
                                 global_step=self._iteration,
                                 scalar_dict=stats_episode)
+        if self.wandb_writer:
+            wandb_scalar_dict(self.wandb_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
+                              global_step=self._iteration,
+                              scalar_dict=stats_episode)
 
         reset_running_stats(self.per_episode_metrics)
diff --git a/examples/megatron/models/train_helper.py b/examples/megatron/models/train_helper.py
index 89127015..fac4890d 100644
--- a/examples/megatron/models/train_helper.py
+++ b/examples/megatron/models/train_helper.py
@@ -17,10 +17,10 @@
 import numpy
 import torch
-from torch.utils.tensorboard import SummaryWriter
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
 
 import chatlearn
-from .utils import write_jsonl, read_jsonl, tensorboard_scalar_dict, listdict_to_dictlist
+from .utils import write_jsonl, read_jsonl, tensorboard_scalar_dict, wandb_scalar_dict, listdict_to_dictlist
 
 
 def eval_post_process(results, eval_info):
     """
@@ -39,10 +39,8 @@ def eval_post_process(results, eval_info):
     results = listdict_to_dictlist(results)
     if args.get('eval_data_num_limit') > 0:
         assert len(results['rewards']) == args.get('eval_data_num_limit'), f"expect {len(results['rewards'])} == {args.get('eval_data_num_limit')}"
-    tensorboard_dir = f"{args.output_dir}/tensorboard"
-    writer = SummaryWriter(
-        log_dir=tensorboard_dir,
-        max_queue=99999)
+    writer = get_tensorboard_writer()
+    wandb_writer = get_wandb_writer()
 
     eval_reward_stats = {"eval_reward_mean": numpy.mean(results['rewards'])}
     train_iteration = eval_info["train_iteration"]
@@ -53,11 +51,19 @@ def eval_post_process(results, eval_info):
             tensorboard_scalar_dict(writer, prefix="eval_reward_each/",
                                     global_step=train_iteration,
                                     scalar_dict=eval_reward_stats)
-
+            if wandb_writer:
+                wandb_scalar_dict(wandb_writer, prefix="eval_reward_each/",
+                                  global_step=train_iteration,
+                                  scalar_dict=eval_reward_stats)
     else:
         tensorboard_scalar_dict(writer, prefix="eval_reward_each/",
                                 global_step=train_iteration,
                                 scalar_dict=eval_reward_stats)
+        if wandb_writer:
+            wandb_scalar_dict(wandb_writer, prefix="eval_reward_each/",
+                              global_step=train_iteration,
+                              scalar_dict=eval_reward_stats)
+
     print(f"eval reward stats: {eval_reward_stats} iter: {train_iteration}")
     save_fp = f"{args.output_dir}/eval/{train_iteration}/eval_json_res.json"  # pylint: disable=line-too-long
     write_jsonl(results["eval_jsonl"], save_fp)
diff --git a/examples/megatron/models/utils.py b/examples/megatron/models/utils.py
index fda8f92d..be09212b 100644
--- a/examples/megatron/models/utils.py
+++ b/examples/megatron/models/utils.py
@@ -35,7 +35,7 @@
     from megatron.training import get_num_microbatches
 except ImportError:
     from megatron.core.num_microbatches_calculator import get_num_microbatches
-from megatron.training.global_vars import get_tensorboard_writer
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
 from megatron.training.training import print_datetime
 from torchtyping import TensorType
 
@@ -105,6 +105,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     args = get_args()
     timers = get_timers()
     writer = get_tensorboard_writer()
+    wandb_writer = get_wandb_writer()
 
     # Advanced, skipped, and Nan iterations.
     advanced_iters_key = 'advanced iterations'
@@ -275,6 +276,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                                     scalar_dict=iter_dict)
            tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples,
                                    scalar_dict=consumed_train_samples_dict)
+           if wandb_writer:
+               wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)
+               wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
+                                 scalar_dict=iter_dict)
+               wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
+                                 scalar_dict=consumed_train_samples_dict)
 
 
     else:
@@ -283,6 +290,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples,
                                scalar_dict=consumed_train_samples_dict)
        tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)
+       if wandb_writer:
+           wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=iter_dict)
+           wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
+                             scalar_dict=consumed_train_samples_dict)
+           wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)
 
 
 def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int):
@@ -357,6 +369,15 @@ def tensorboard_scalar_dict(tensorboard_writer, prefix, global_step, scalar_dict):
             name = '{}/{}'.format(prefix, key)
             tensorboard_writer.add_scalar(name, value, global_step)
 
+def wandb_scalar_dict(wandb_writer, prefix, global_step, scalar_dict):
+    if isinstance(scalar_dict, (float, int)):
+        name = prefix
+        value = scalar_dict
+        wandb_writer.log({f"{name}": value}, global_step)
+    else:
+        for key, value in scalar_dict.items():
+            wandb_writer.log({f"{prefix}/{key}": value}, global_step)
+
 
 def get_loss_mask(all_tokens_right_padded, pad_token_id, prompt_sizes):
     '''
diff --git a/examples/megatron/models/value_trainer.py b/examples/megatron/models/value_trainer.py
index ef36c6fe..1aa09624 100644
--- a/examples/megatron/models/value_trainer.py
+++ b/examples/megatron/models/value_trainer.py
@@ -25,14 +25,14 @@
 from megatron.training import get_timers
 from megatron.training import get_tokenizer
 from megatron.training import print_rank_0
-from megatron.training.global_vars import get_tensorboard_writer
+from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
 from megatron.training.utils import average_losses_across_data_parallel_group
 from megatron.training.utils import calc_params_l2_norm
 
 from chatlearn.utils import to_device
 from .value_model import ValueModel as LegacyValueModel
 from .mcore_value_model import MCoreValueModel
-from .utils import tensorboard_scalar_dict, training_log, get_eos_id
+from .utils import tensorboard_scalar_dict, wandb_scalar_dict, training_log, get_eos_id
 from .base_trainer import BaseTrainer
 from .constants import get_ltor_masks_and_position_ids_rlhf, select_actions_from_right_padded, pad_to_max_len
 
@@ -184,12 +184,17 @@ def after_episode(self):
 
         # actual log
         writer = get_tensorboard_writer()
+        wandb_writer = get_wandb_writer()
 
         after_episode_dict = {
             "value/explained_variance_dp": self.stats["value/explained_variance_dp"]
         }
         tensorboard_scalar_dict(writer, prefix="", global_step=self.args.consumed_train_samples,
                                 scalar_dict=after_episode_dict)
+        if wandb_writer:
+            wandb_scalar_dict(wandb_writer, prefix="", global_step=self.args.consumed_train_samples,
+                              scalar_dict=after_episode_dict)
+
 
     def before_episode(self):
         '''
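
Usage note (illustration only, not part of the patch): the new `wandb_scalar_dict` helper added in `examples/megatron/models/utils.py` mirrors `tensorboard_scalar_dict`. Given a dict it logs each entry as `<prefix>/<key>` at `global_step`; given a bare scalar it logs it under `prefix` itself. Because Megatron's `get_wandb_writer()` returns `None` when W&B logging is not configured, every call site in the patch guards with `if wandb_writer:`. The sketch below shows the intended call pattern under those assumptions; the `wandb.init(...)` run, project name, and metric values are hypothetical stand-ins for the writer that `get_wandb_writer()` would return.

```python
# Illustrative sketch only: assumes the `wandb` package is installed and uses a
# locally initialized offline run in place of Megatron's get_wandb_writer().
import wandb

from examples.megatron.models.utils import wandb_scalar_dict

# Hypothetical stand-in for the writer returned by get_wandb_writer().
wandb_writer = wandb.init(project="chatlearn-demo", mode="offline")

if wandb_writer:
    # Dict of scalars: each key is logged as "<prefix>/<key>" at the given step.
    wandb_scalar_dict(wandb_writer,
                      prefix="policy_inference/replica_id0",
                      global_step=10,
                      scalar_dict={"entropy": 1.23, "reward_mean": 0.5})

    # Single scalar: the prefix itself becomes the metric name.
    wandb_scalar_dict(wandb_writer, prefix="eval_reward_mean", global_step=10, scalar_dict=0.42)
```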