From abeb00fe790c00137d20d850132ff536ecaae9fd Mon Sep 17 00:00:00 2001
From: rajat08
Date: Sat, 5 Oct 2024 02:10:16 -0700
Subject: [PATCH 1/2] boiler plate code for multi-turn reward for RLHF

---
 multi_turn_reward_for_RLHF/main.py | 132 +++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 multi_turn_reward_for_RLHF/main.py

diff --git a/multi_turn_reward_for_RLHF/main.py b/multi_turn_reward_for_RLHF/main.py
new file mode 100644
index 00000000000..9c9f736eee3
--- /dev/null
+++ b/multi_turn_reward_for_RLHF/main.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+
+
+class DialogueEnv:
+    """Multi-turn dialogue environment that simulates conversations."""
+
+    def __init__(self):
+        self.turns = 5  # Each dialogue lasts 5 turns
+        self.current_turn = 0
+        self.conversation = []
+
+    def reset(self):
+        """Resets the environment for a new dialogue."""
+        self.current_turn = 0
+        self.conversation = []
+        return "Hi, how can I help you today?"  # Starting dialogue
+
+    def step(self, action):
+        """Takes an action (a response) and advances the conversation."""
+        self.conversation.append(action)
+        self.current_turn += 1
+
+        if self.current_turn < self.turns:
+            # Generate the next response from the environment (placeholder)
+            next_state = f"Response {self.current_turn}: How about this?"
+            done = False
+            reward = self._human_feedback(action)
+        else:
+            next_state = "Conversation ended."
+            done = True
+            reward = self._human_feedback(action)
+
+        return next_state, reward, done
+
+    def _human_feedback(self, action):
+        """Simulates human feedback by returning a random reward."""
+        return np.random.choice([1, -1])  # 1 for positive feedback, -1 for negative
+
+
+class PolicyNetwork(nn.Module):
+    """Policy network that defines the agent's behavior."""
+
+    def __init__(self, input_size=100, hidden_size=128, output_size=10):
+        super(PolicyNetwork, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        """Forward pass through the network."""
+        x = torch.relu(self.fc1(x))
+        return self.fc2(x)
+
+
+def pad_or_truncate(state, size=100):
+    """Pads or truncates the input state to match the required input size."""
+    state_tensor = torch.tensor([ord(c) for c in state], dtype=torch.float32)
+    if state_tensor.size(0) < size:
+        padded_tensor = torch.cat([state_tensor, torch.zeros(size - state_tensor.size(0))])
+    else:
+        padded_tensor = state_tensor[:size]
+    return padded_tensor.unsqueeze(0)  # Add batch dimension
+
+
+def train_rlhf(env, model, optimizer, num_episodes=1000):
+    """Trains the policy network using reinforcement learning with human feedback."""
+    gamma = 0.99  # Discount factor for future rewards
+
+    for episode in range(num_episodes):
+        state = env.reset()
+        total_reward = 0
+        log_probs = []
+        rewards = []
+
+        done = False
+        while not done:
+            # Pad or truncate the input state to the required size
+            state_tensor = pad_or_truncate(state, size=100)
+            logits = model(state_tensor)
+            action_probs = torch.softmax(logits, dim=-1)
+            action_dist = torch.distributions.Categorical(action_probs)
+
+            action = action_dist.sample()
+            log_prob = action_dist.log_prob(action)
+            log_probs.append(log_prob)
+
+            # Take the action in the environment
+            action_text = f"Action {action.item()}"
+            next_state, reward, done = env.step(action_text)
+            rewards.append(reward)
+            total_reward += reward
+
+            state = next_state
+
+        # Calculate the discounted rewards
+        discounted_rewards = []
+        cumulative_reward = 0
+        for r in reversed(rewards):
+            cumulative_reward = r + gamma * cumulative_reward
+            discounted_rewards.insert(0, cumulative_reward)
+
+        # Normalize the rewards
+        discounted_rewards = torch.tensor(discounted_rewards)
+        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-6)
+
+        # Policy Gradient: Update the policy
+        policy_loss = []
+        for log_prob, reward in zip(log_probs, discounted_rewards):
+            policy_loss.append(-log_prob * reward)
+
+        optimizer.zero_grad()
+        policy_loss = torch.cat(policy_loss).sum()
+        policy_loss.backward()
+        optimizer.step()
+
+        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}")
+
+
+if __name__ == "__main__":
+    # Instantiate environment and model
+    env = DialogueEnv()
+    input_size = 100  # Placeholder for state size (e.g., fixed-length input of size 100)
+    hidden_size = 128
+    output_size = 10  # Placeholder for the number of possible actions (dialogue responses)
+
+    model = PolicyNetwork(input_size, hidden_size, output_size)
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+    # Train the policy using RL with Human Feedback (simulated)
+    train_rlhf(env, model, optimizer, num_episodes=1000)
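
A minimal smoke test for the pieces this patch introduces (not part of the commit; it assumes the definitions in main.py above are in scope, e.g. when run at the bottom of that module):

    # One environment step with an untrained policy, using only the classes
    # and helpers defined in main.py above.
    env = DialogueEnv()
    model = PolicyNetwork(input_size=100, hidden_size=128, output_size=10)
    state = env.reset()
    logits = model(pad_or_truncate(state, size=100))      # shape (1, 10)
    dist = torch.distributions.Categorical(torch.softmax(logits, dim=-1))
    action = dist.sample()                                # tensor of shape (1,)
    next_state, reward, done = env.step(f"Action {action.item()}")
    print(next_state, reward, done)
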
From 717c01eb2c832c0f43cb039e958c0705d262eb3a Mon Sep 17 00:00:00 2001
From: rghosh8
Date: Thu, 13 Feb 2025 01:38:38 -0800
Subject: [PATCH 2/2] update

---
 multi_turn_reward_for_RLHF/main.py | 185 ++++++++++++++++------------
 1 file changed, 105 insertions(+), 80 deletions(-)

diff --git a/multi_turn_reward_for_RLHF/main.py b/multi_turn_reward_for_RLHF/main.py
index 9c9f736eee3..e7fd0ec7846 100644
--- a/multi_turn_reward_for_RLHF/main.py
+++ b/multi_turn_reward_for_RLHF/main.py
@@ -2,44 +2,62 @@
 import torch.nn as nn
 import torch.optim as optim
 import numpy as np
-
-
-class DialogueEnv:
-    """Multi-turn dialogue environment that simulates conversations."""
+from tensordict.nn import TensorDictModule
+from torchrl.envs import EnvBase
+from torchrl.envs.libs.gym import GymWrapper
+from torchrl.modules import ProbabilisticActor, ValueOperator
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage
+from torchrl.objectives import ClipPPOLoss
+from torchrl.objectives.value import GAE
+
+# Define the Dialogue Environment in TorchRL Format
+class DialogueEnvTorchRL(EnvBase):
+    """TorchRL-compatible multi-turn dialogue environment that simulates conversations."""
 
     def __init__(self):
+        super().__init__()
         self.turns = 5  # Each dialogue lasts 5 turns
         self.current_turn = 0
         self.conversation = []
+        # Simplified placeholders: a fully spec-compliant EnvBase would define
+        # TensorSpec objects here and return TensorDicts from _reset/_step
+        self.action_spec = torch.arange(10)  # 10 discrete actions (dialogue responses)
+        self.observation_spec = torch.zeros(100)  # Fixed-size state representation
 
-    def reset(self):
+    def _reset(self):
         """Resets the environment for a new dialogue."""
         self.current_turn = 0
         self.conversation = []
-        return "Hi, how can I help you today?"  # Starting dialogue
+        return {"observation": self._encode_state("Hi, how can I help you today?")}
 
-    def step(self, action):
+    def _step(self, action):
         """Takes an action (a response) and advances the conversation."""
-        self.conversation.append(action)
+        action_text = f"Action {action.item()}"
+        self.conversation.append(action_text)
        self.current_turn += 1
 
         if self.current_turn < self.turns:
-            # Generate the next response from the environment (placeholder)
             next_state = f"Response {self.current_turn}: How about this?"
             done = False
-            reward = self._human_feedback(action)
         else:
             next_state = "Conversation ended."
             done = True
-            reward = self._human_feedback(action)
 
-        return next_state, reward, done
+        reward = self._human_feedback(action_text)
+        return {"observation": self._encode_state(next_state), "reward": reward, "done": done}
 
     def _human_feedback(self, action):
         """Simulates human feedback by returning a random reward."""
         return np.random.choice([1, -1])  # 1 for positive feedback, -1 for negative
 
+    def _encode_state(self, state, size=100):
+        """Encodes state into a tensor format (pads or truncates)."""
+        state_tensor = torch.tensor([ord(c) for c in state], dtype=torch.float32)
+        if state_tensor.size(0) < size:
+            padded_tensor = torch.cat([state_tensor, torch.zeros(size - state_tensor.size(0))])
+        else:
+            padded_tensor = state_tensor[:size]
+        return padded_tensor.unsqueeze(0)  # Add batch dimension
 
+# Define Policy Network
 class PolicyNetwork(nn.Module):
     """Policy network that defines the agent's behavior."""
 
@@ -53,80 +71,87 @@ def forward(self, x):
         x = torch.relu(self.fc1(x))
         return self.fc2(x)
 
+# Define Value Network for PPO
+class ValueNetwork(nn.Module):
+    """Value network for estimating the state value."""
 
-def pad_or_truncate(state, size=100):
-    """Pads or truncates the input state to match the required input size."""
-    state_tensor = torch.tensor([ord(c) for c in state], dtype=torch.float32)
-    if state_tensor.size(0) < size:
-        padded_tensor = torch.cat([state_tensor, torch.zeros(size - state_tensor.size(0))])
-    else:
-        padded_tensor = state_tensor[:size]
-    return padded_tensor.unsqueeze(0)  # Add batch dimension
+    def __init__(self, input_size=100, hidden_size=128):
+        super(ValueNetwork, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, 1)
 
+    def forward(self, x):
+        """Forward pass through the network."""
+        x = torch.relu(self.fc1(x))
+        return self.fc2(x)
 
-def train_rlhf(env, model, optimizer, num_episodes=1000):
-    """Trains the policy network using reinforcement learning with human feedback."""
-    gamma = 0.99  # Discount factor for future rewards
+# Training Setup using PPO in TorchRL
+def train_rlhf_torchrl(num_episodes=1000, batch_size=32):
+    """Trains the policy network using PPO for reinforcement learning with human feedback."""
+
+    # Instantiate environment
+    env = DialogueEnvTorchRL()
+
+    # Create policy and value networks
+    policy_model = PolicyNetwork(input_size=100, hidden_size=128, output_size=10)
+    value_model = ValueNetwork(input_size=100, hidden_size=128)
+
+    # Wrap the policy: the network maps "observation" -> "logits", and the actor
+    # samples a discrete "action" from those logits
+    policy_module = TensorDictModule(policy_model, in_keys=["observation"], out_keys=["logits"])
+    policy = ProbabilisticActor(
+        module=policy_module,
+        in_keys=["logits"],
+        out_keys=["action"],
+        distribution_class=torch.distributions.Categorical,
+        return_log_prob=True
+    )
+
+    # Value function
+    value_operator = ValueOperator(
+        module=value_model,
+        in_keys=["observation"]
+    )
+
+    # Optimizers
+    policy_optimizer = optim.Adam(policy.parameters(), lr=1e-3)
+    value_optimizer = optim.Adam(value_operator.parameters(), lr=1e-3)
+
+    # Setup collector
+    collector = SyncDataCollector(
+        env, policy, frames_per_batch=batch_size, total_frames=num_episodes * batch_size
+    )
+
+    # Replay buffer
+    buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(max_size=10000)
+    )
+
+    # Advantage estimation and loss function (PPO)
+    advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=value_operator)
+    loss_module = ClipPPOLoss(
+        actor_network=policy,
+        critic_network=value_operator,
+        clip_epsilon=0.2
+    )
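+
+    # Data flow note: advantage_module(...) augments a collected batch in place
+    # with an "advantage" entry derived from the critic's value estimates, and
+    # loss_module(...) returns the PPO loss terms under the keys "loss_objective",
+    # "loss_critic" and "loss_entropy", which are summed before the backward pass.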
 
-    for episode in range(num_episodes):
-        state = env.reset()
-        total_reward = 0
-        log_probs = []
-        rewards = []
-
-        done = False
-        while not done:
-            # Pad or truncate the input state to the required size
-            state_tensor = pad_or_truncate(state, size=100)
-            logits = model(state_tensor)
-            action_probs = torch.softmax(logits, dim=-1)
-            action_dist = torch.distributions.Categorical(action_probs)
-
-            action = action_dist.sample()
-            log_prob = action_dist.log_prob(action)
-            log_probs.append(log_prob)
-
-            # Take the action in the environment
-            action_text = f"Action {action.item()}"
-            next_state, reward, done = env.step(action_text)
-            rewards.append(reward)
-            total_reward += reward
-
-            state = next_state
-
-        # Calculate the discounted rewards
-        discounted_rewards = []
-        cumulative_reward = 0
-        for r in reversed(rewards):
-            cumulative_reward = r + gamma * cumulative_reward
-            discounted_rewards.insert(0, cumulative_reward)
-
-        # Normalize the rewards
-        discounted_rewards = torch.tensor(discounted_rewards)
-        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-6)
-
-        # Policy Gradient: Update the policy
-        policy_loss = []
-        for log_prob, reward in zip(log_probs, discounted_rewards):
-            policy_loss.append(-log_prob * reward)
-
-        optimizer.zero_grad()
-        policy_loss = torch.cat(policy_loss).sum()
-        policy_loss.backward()
-        optimizer.step()
-
-        print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}")
+    for i, batch in enumerate(collector):
+        # Compute advantages on the freshly collected (on-policy) batch,
+        # then store it for minibatch updates
+        advantage_module(batch)
+        buffer.extend(batch)
 
+        # Sample from buffer
+        sampled_batch = buffer.sample(batch_size)
 
-if __name__ == "__main__":
-    # Instantiate environment and model
-    env = DialogueEnv()
-    input_size = 100  # Placeholder for state size (e.g., fixed-length input of size 100)
-    hidden_size = 128
-    output_size = 10  # Placeholder for the number of possible actions (dialogue responses)
+        # Clipped PPO losses on the sampled minibatch
+        loss_vals = loss_module(sampled_batch)
+        loss = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
 
-    model = PolicyNetwork(input_size, hidden_size, output_size)
-    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+        # One backward pass through the combined loss, then update both networks
+        policy_optimizer.zero_grad()
+        value_optimizer.zero_grad()
+        loss.backward()
+        policy_optimizer.step()
+        value_optimizer.step()
 
-    # Train the policy using RL with Human Feedback (simulated)
-    train_rlhf(env, model, optimizer, num_episodes=1000)
+        print(f"Episode {i + 1}/{num_episodes}, Loss: {loss_vals['loss_objective'].item()}")
+
+if __name__ == "__main__":
+    train_rlhf_torchrl(num_episodes=1000, batch_size=32)
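
For reference, a minimal, self-contained sketch of the actor wiring used in this patch, assuming a recent TorchRL/tensordict release; the torch.nn.Linear here is only a stand-in for PolicyNetwork so the snippet stays short:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.envs.utils import ExplorationType, set_exploration_type
    from torchrl.modules import ProbabilisticActor

    net = torch.nn.Linear(100, 10)  # stand-in for PolicyNetwork(100, 128, 10)

    # The inner module maps "observation" -> "logits"; ProbabilisticActor builds a
    # Categorical distribution from the logits, samples "action", and (with
    # return_log_prob=True) records the log-probability that ClipPPOLoss needs.
    policy = ProbabilisticActor(
        TensorDictModule(net, in_keys=["observation"], out_keys=["logits"]),
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=torch.distributions.Categorical,
        return_log_prob=True,
    )

    with set_exploration_type(ExplorationType.RANDOM):  # sample rather than take the mode
        td = policy(TensorDict({"observation": torch.randn(4, 100)}, batch_size=[4]))
    print(td["action"])  # four sampled dialogue-action indices in [0, 10)

Collectors such as SyncDataCollector call the actor exactly this way, passing a tensordict of observations and reading the "action" entry back out.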