2 changes: 2 additions & 0 deletions chatlearn/models/megatron_module.py
@@ -60,6 +60,8 @@ def __init__(self, *args, **kwargs):
self._logger.info(f"{self.name} Overwrite global_batch_size with train_global_batch_size {self.module_args.train_global_batch_size}")
if not self.model_args.get("tensorboard_dir") and self.runtime_args.output_dir is not None:
self.model_args['tensorboard_dir'] = f"{self.runtime_args.output_dir}/tensorboard"
if not self.model_args.get("wandb_save_dir") and self.runtime_args.output_dir is not None:
self.model_args['wandb_save_dir'] = f"{self.runtime_args.output_dir}/wandb"


def add_extra_args(self, parser):
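
For context, a minimal sketch of the fallback this hunk adds (the standalone function is illustrative; in the module the logic runs inline in __init__): when the user has not set wandb_save_dir, it is derived from runtime_args.output_dir, mirroring the existing tensorboard_dir default.

def resolve_default_log_dirs(model_args, output_dir):
    # mirror of the inline logic above: only fill in a default when the user
    # left the field unset and an output_dir is available
    if not model_args.get("tensorboard_dir") and output_dir is not None:
        model_args["tensorboard_dir"] = f"{output_dir}/tensorboard"
    if not model_args.get("wandb_save_dir") and output_dir is not None:
        model_args["wandb_save_dir"] = f"{output_dir}/wandb"
    return model_args

# e.g. resolve_default_log_dirs({}, "/mnt/exp1") ->
#   {"tensorboard_dir": "/mnt/exp1/tensorboard", "wandb_save_dir": "/mnt/exp1/wandb"}
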
2 changes: 2 additions & 0 deletions chatlearn/utils/megatron_import_helper.py
@@ -75,9 +75,11 @@
# megatron.global_vars.*
try:
from megatron.global_vars import get_tensorboard_writer
from megatron.global_vars import get_wandb_writer
from megatron.global_vars import set_global_variables
except ImportError:
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_wandb_writer
from megatron.training.global_vars import set_global_variables

# megatron.initialize.*
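
A short usage sketch of the shim (the consumer function below is illustrative): importing get_wandb_writer through chatlearn.utils.megatron_import_helper keeps callers working across both the megatron.global_vars and megatron.training.global_vars layouts.

from chatlearn.utils.megatron_import_helper import get_wandb_writer

def log_metrics(metrics, step):
    # get_wandb_writer() is expected to return None when W&B is not configured,
    # so the call degrades to a no-op without extra plumbing
    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.log(metrics, step)
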
12 changes: 10 additions & 2 deletions examples/megatron/models/old_policy_inference.py
@@ -23,7 +23,7 @@
from megatron.training import arguments
from megatron.training import get_args, get_tokenizer
from megatron.training import print_rank_0
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
from megatron.inference.text_generation.communication import broadcast_float_list, \
broadcast_int_list, broadcast_tensor
from megatron.inference.text_generation.generation import generate_tokens_probs_and_return_on_first_stage
@@ -36,7 +36,7 @@
from examples.megatron.data.prompt_dataset import PromptPipeline
from .policy_model import PolicyModel as LegacyPolicyModel
from .mcore_policy_model import MCorePolicyModel
from .utils import tensorboard_scalar_dict, get_loss_mask, get_eos_id
from .utils import tensorboard_scalar_dict, wandb_scalar_dict, get_loss_mask, get_eos_id


class PolicyInference(MegatronModule):
@@ -411,6 +411,8 @@ def log_entropy(self, iteration_for_log):
# log

writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()

# RL related stats: global
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
@@ -427,6 +429,9 @@
}
tensorboard_scalar_dict(writer, prefix=f"policy_inference/replica_id{self.replica_id}",
global_step=iteration_for_log, scalar_dict=stats_args)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix=f"policy_inference/replica_id{self.replica_id}",
global_step=iteration_for_log, scalar_dict=stats_args)

else:
# actual log
@@ -440,6 +445,9 @@
}
tensorboard_scalar_dict(writer, prefix=f"policy_inference/replica_id{self.replica_id}",
global_step=iteration_for_log, scalar_dict=stats_args)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix=f"policy_inference/replica_id{self.replica_id}",
global_step=iteration_for_log, scalar_dict=stats_args)

get_args().entropy_sum = 0
get_args().entropy_num = 0
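
A note on the recurring `if wandb_writer:` guard: Megatron only creates the global W&B handle when a project is configured at initialization, so get_wandb_writer() returns None otherwise and these logging blocks are skipped. A hedged sketch of checking whether W&B logging is active (the flag names in the comment come from Megatron-LM's logging arguments and are an assumption, not ChatLearn-specific API):

from megatron.training.global_vars import get_wandb_writer

wandb_writer = get_wandb_writer()
# typically enabled via Megatron args such as --wandb-project / --wandb-exp-name
# / --wandb-save-dir (assumed flags); otherwise wandb_writer stays None
print("wandb logging enabled:", wandb_writer is not None)
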
9 changes: 7 additions & 2 deletions examples/megatron/models/reward_inference.py
@@ -27,7 +27,7 @@
from megatron.training import get_model
from megatron.training import get_tokenizer
from megatron.training import print_rank_0
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
from megatron.training.utils import get_ltor_masks_and_position_ids

import chatlearn
@@ -37,7 +37,7 @@
from examples.megatron.data.reward_dataset import preprocess
from .reward_model import RewardModel as LegacyRewardModel
from .mcore_reward_model import MCoreRewardModel
from .utils import tensorboard_scalar_dict, get_eos_id
from .utils import tensorboard_scalar_dict, wandb_scalar_dict, get_eos_id
from .constants import RunningMoments, get_running_stats, reset_running_stats
from .forward_step import forward_step_helper

@@ -468,6 +468,7 @@ def forward_step(self, data, iteration=None):

def log_each_step(self, iteration):
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()
stats_episode = get_running_stats(self.per_episode_metrics)
stats_episode.update(self.stats)

@@ -479,6 +480,10 @@
tensorboard_scalar_dict(writer, prefix=f"rewards_each/replica_id{self.replica_id}",
global_step=iteration,
scalar_dict=stats_episode)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
global_step=iteration,
scalar_dict=stats_episode)
# reset runnings
reset_running_stats(self.per_episode_metrics)

16 changes: 11 additions & 5 deletions examples/megatron/models/reward_math.py
@@ -16,12 +16,14 @@
from collections import defaultdict
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer

from chatlearn import BaseModule
from .utils import tensorboard_scalar_dict
from .utils import tensorboard_scalar_dict, wandb_scalar_dict
from .constants import RunningMoments, get_running_stats, reset_running_stats
from .rm_sys.math_rule_rm import MathRuleRM


class MathReward(BaseModule):
"""Math reward"""

@@ -30,8 +32,8 @@ def setup(self):
self.stats = {}
self.running = RunningMoments()
self.per_episode_metrics = defaultdict(RunningMoments)
tensorboard_dir = f"{self.runtime_args.output_dir}/tensorboard"
self.tensorboard_writer = SummaryWriter(log_dir=tensorboard_dir)
self.writer = get_tensorboard_writer()
self.wandb_writer = get_wandb_writer()

def forward_step(self, data, iteration=0):
answers = data['answer']
@@ -103,7 +105,11 @@ def log_each_step(self):
stats_episode["exp_scores/running_math_std"] = self.running.std

print(f"score only/running_math_mean {self.running.mean}", flush=True)
tensorboard_scalar_dict(self.tensorboard_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
tensorboard_scalar_dict(self.writer, prefix=f"rewards_each/replica_id{self.replica_id}",
global_step=self._iteration,
scalar_dict=stats_episode)
if self.wandb_writer:
wandb_scalar_dict(self.wandb_writer, prefix=f"rewards_each/replica_id{self.replica_id}",
global_step=self._iteration,
scalar_dict=stats_episode)
reset_running_stats(self.per_episode_metrics)
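
The change above is worth calling out: MathReward previously built its own SummaryWriter under output_dir/tensorboard; it now reuses the writers Megatron created at initialization, so math-reward metrics land in the same TensorBoard directory and W&B run as the training metrics. A minimal sketch of the new wiring (the mixin class is illustrative):

from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer

class SharedWriterMixin:
    def setup_writers(self):
        # both handles are shared, process-wide writers; either may be None
        # when the corresponding backend is not configured
        self.writer = get_tensorboard_writer()
        self.wandb_writer = get_wandb_writer()
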
20 changes: 13 additions & 7 deletions examples/megatron/models/train_helper.py
@@ -17,10 +17,10 @@
import numpy
import torch

from torch.utils.tensorboard import SummaryWriter
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer

import chatlearn
from .utils import write_jsonl, read_jsonl, tensorboard_scalar_dict, listdict_to_dictlist
from .utils import write_jsonl, read_jsonl, tensorboard_scalar_dict, wandb_scalar_dict, listdict_to_dictlist

def eval_post_process(results, eval_info):
"""
@@ -39,10 +39,8 @@ def eval_post_process(results, eval_info):
results = listdict_to_dictlist(results)
if args.get('eval_data_num_limit') > 0:
assert len(results['rewards']) == args.get('eval_data_num_limit'), f"expect {len(results['rewards'])} == {args.get('eval_data_num_limit')}"
tensorboard_dir = f"{args.output_dir}/tensorboard"
writer = SummaryWriter(
log_dir=tensorboard_dir,
max_queue=99999)
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()

eval_reward_stats = {"eval_reward_mean": numpy.mean(results['rewards'])}
train_iteration = eval_info["train_iteration"]
@@ -53,11 +51,19 @@
tensorboard_scalar_dict(writer, prefix="eval_reward_each/",
global_step=train_iteration,
scalar_dict=eval_reward_stats)

if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix="eval_reward_each/",
global_step=train_iteration,
scalar_dict=eval_reward_stats)
else:
tensorboard_scalar_dict(writer, prefix="eval_reward_each/",
global_step=train_iteration,
scalar_dict=eval_reward_stats)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix="eval_reward_each/",
global_step=train_iteration,
scalar_dict=eval_reward_stats)

print(f"eval reward stats: {eval_reward_stats} iter: {train_iteration}")
save_fp = f"{args.output_dir}/eval/{train_iteration}/eval_json_res.json" # pylint: disable=line-too-long
write_jsonl(results["eval_jsonl"], save_fp)
23 changes: 22 additions & 1 deletion examples/megatron/models/utils.py
@@ -35,7 +35,7 @@
from megatron.training import get_num_microbatches
except ImportError:
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
from megatron.training.training import print_datetime
from torchtyping import TensorType

@@ -105,6 +105,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()

# Advanced, skipped, and Nan iterations.
advanced_iters_key = 'advanced iterations'
@@ -275,6 +276,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
scalar_dict=iter_dict)
tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples,
scalar_dict=consumed_train_samples_dict)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
scalar_dict=iter_dict)
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
scalar_dict=consumed_train_samples_dict)


else:
@@ -283,6 +290,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples,
scalar_dict=consumed_train_samples_dict)
tensorboard_scalar_dict(writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=iter_dict)
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples,
scalar_dict=consumed_train_samples_dict)
wandb_scalar_dict(wandb_writer, prefix="", global_step=args.consumed_train_samples, scalar_dict=stats)


def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int):
@@ -357,6 +369,15 @@ def tensorboard_scalar_dict(tensorboard_writer, prefix, global_step, scalar_dict
name = '{}/{}'.format(prefix, key)
tensorboard_writer.add_scalar(name, value, global_step)

def wandb_scalar_dict(wandb_writer, prefix, global_step, scalar_dict):
if isinstance(scalar_dict, (float, int)):
name = prefix
value = scalar_dict
wandb_writer.log({f"{name}": value}, global_step)
else:
for key, value in scalar_dict.items():
wandb_writer.log({f"{prefix}/{key}": value}, global_step)


def get_loss_mask(all_tokens_right_padded, pad_token_id, prompt_sizes):
'''
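
A brief usage sketch of the new wandb_scalar_dict helper (prefixes and values are illustrative): with a dict it logs each entry under <prefix>/<key>, and with a bare number it logs a single metric named by prefix, mirroring how tensorboard_scalar_dict is already used.

run = get_wandb_writer()
if run:
    # dict input: logged as rewards_each/replica_id0/reward_mean and .../reward_std
    wandb_scalar_dict(run, prefix="rewards_each/replica_id0", global_step=10,
                      scalar_dict={"reward_mean": 0.42, "reward_std": 0.07})
    # scalar input: logged under the prefix itself
    wandb_scalar_dict(run, prefix="eval_reward_mean", global_step=10, scalar_dict=0.85)
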
9 changes: 7 additions & 2 deletions examples/megatron/models/value_trainer.py
@@ -25,14 +25,14 @@
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.training import print_rank_0
from megatron.training.global_vars import get_tensorboard_writer
from megatron.training.global_vars import get_tensorboard_writer, get_wandb_writer
from megatron.training.utils import average_losses_across_data_parallel_group
from megatron.training.utils import calc_params_l2_norm

from chatlearn.utils import to_device
from .value_model import ValueModel as LegacyValueModel
from .mcore_value_model import MCoreValueModel
from .utils import tensorboard_scalar_dict, training_log, get_eos_id
from .utils import tensorboard_scalar_dict, wandb_scalar_dict, training_log, get_eos_id
from .base_trainer import BaseTrainer
from .constants import get_ltor_masks_and_position_ids_rlhf, select_actions_from_right_padded, pad_to_max_len

@@ -184,12 +184,17 @@ def after_episode(self):

# actual log
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()

after_episode_dict = {
"value/explained_variance_dp": self.stats["value/explained_variance_dp"]
}
tensorboard_scalar_dict(writer, prefix="", global_step=self.args.consumed_train_samples,
scalar_dict=after_episode_dict)
if wandb_writer:
wandb_scalar_dict(wandb_writer, prefix="", global_step=self.args.consumed_train_samples,
scalar_dict=after_episode_dict)


def before_episode(self):
'''