|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

| 6 | +""" |
| 7 | +Multi-head agent and PPO loss |
| 8 | +============================= |
| 9 | +
|
| 10 | +This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions |
| 11 | +(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses. |
| 12 | +
|
| 13 | +The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict. |
| 14 | +It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution |
| 15 | +object containing the three distributions. |
| 16 | +
|
| 17 | +The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters, |
| 18 | +creates a distribution from these parameters, and samples from the distribution to output multiple actions. |
| 19 | +
|
| 20 | +The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss. |
| 21 | +
|
| 22 | +Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a |
| 23 | +fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities` |
| 24 | +argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False` |
| 25 | +if not specified. |
| 26 | +
|
| 27 | +In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in |
| 28 | +the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used. |
| 29 | +
|
| 30 | +""" |

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

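# `make_params` produces the parameters of the three distribution heads. It reads no input keys and simply
# writes constant dummy tensors under the nested "params" entry; in a real agent these values would typically
# come from a neural network conditioned on the observation.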
make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


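# Helper for the third head: a mixture of Gaussians, i.e. a Categorical over mixture components combined with
# a Normal through `MixtureSameFamily`.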
def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


# =============================================================================
# Example 0: aggregate_probabilities=None (default) ===========================

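# The composite distribution maps each parameter group to a distribution class (`distribution_map`) and each
# head to the nested key its sample is written to (`name_map`). With `aggregate_probabilities=None`, the choice
# between a single aggregated log-prob and per-head log-probs is left to the call site (see the explicit
# arguments passed to `log_prob` below) or to the library default.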
dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map={
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    },
    aggregate_probabilities=None,
)


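# The policy chains the parameter module with a probabilistic module that builds the composite distribution
# and samples one action per agent. `return_log_prob=True` also writes log-probability entries into the
# tensordict; these are the keys the PPO losses are pointed to via `set_keys` below.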
policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("0. result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob", log_prob)

# We can also get the log-prob from the policy directly
log_prob = policy.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob (from policy)", log_prob)

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("0. ", loss_cls, loss_vals)


# ===================================================================
# Example 1: aggregate_probabilities=True ===========================

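# `dist_constructor` is a `functools.partial`, so its keyword arguments can be updated in place. From now on
# the composite distribution aggregates the log-probs of the three heads into a single tensor.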
dist_constructor.keywords["aggregate_probabilities"] = True

td = policy(TensorDict(batch_size=[4]))
print("1. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob
    # key, since there is only one.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("1. ", loss_cls, loss_vals)


# ===================================================================
# Example 2: aggregate_probabilities=False ===========================

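# Switch back to fine-grained, per-head log-probs: `log_prob` now returns a tensordict with one entry per
# distribution, so the loss is again given the full list of `sample_log_prob` keys.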
dist_constructor.keywords["aggregate_probabilities"] = False

td = policy(TensorDict(batch_size=[4]))
print("2. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("2. ", loss_cls, loss_vals)