
Commit e05b160

[Feature] Make PPO compatible with composite actions and log-probs
ghstack-source-id: 7041ac4 Pull Request resolved: #2665
1 parent dc25a55 commit e05b160

25 files changed, +840 -253 lines changed

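In short, this change lets the PPO losses (PPOLoss, ClipPPOLoss, KLPENPPOLoss) consume composite actions and per-action log-probs produced by a CompositeDistribution: set_keys accepts lists of nested keys for both the action and the sample log-prob entries. A condensed sketch, assuming a `policy`, `value_operator` and `data` tensordict built as in examples/agents/composite_ppo.py below (the agent*/action key names come from that example):

# Condensed sketch only; see examples/agents/composite_ppo.py below for the full, runnable version.
from torchrl.objectives import ClipPPOLoss

loss = ClipPPOLoss(policy, value_operator)
loss.set_keys(
    action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
    sample_log_prob=[
        ("agent0", "action_log_prob"),
        ("agent1", "action_log_prob"),
        ("agent2", "action_log_prob"),
    ],
)
loss_vals = loss(data)  # works whether log-probs are aggregated or kept per-action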

.github/unittest/linux_sota/scripts/test_sota.py (+13 -13)
@@ -188,19 +188,6 @@
 ppo.collector.frames_per_batch=16 \
 logger.mode=offline \
 logger.backend=
-""",
-"dreamer": """python sota-implementations/dreamer/dreamer.py \
-collector.total_frames=600 \
-collector.init_random_frames=10 \
-collector.frames_per_batch=200 \
-env.n_parallel_envs=1 \
-optimization.optim_steps_per_batch=1 \
-logger.video=False \
-logger.backend=csv \
-replay_buffer.buffer_size=120 \
-replay_buffer.batch_size=24 \
-replay_buffer.batch_length=12 \
-networks.rssm_hidden_dim=17
 """,
 "ddpg-single": """python sota-implementations/ddpg/ddpg.py \
 collector.total_frames=48 \
@@ -289,6 +276,19 @@
 logger.backend=
 """,
 "bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
+""",
+"dreamer": """python sota-implementations/dreamer/dreamer.py \
+collector.total_frames=600 \
+collector.init_random_frames=10 \
+collector.frames_per_batch=200 \
+env.n_parallel_envs=1 \
+optimization.optim_steps_per_batch=1 \
+logger.video=False \
+logger.backend=csv \
+replay_buffer.buffer_size=120 \
+replay_buffer.batch_size=24 \
+replay_buffer.batch_length=12 \
+networks.rssm_hidden_dim=17
 """,
 }

examples/agents/composite_actor.py (+6)

@@ -50,3 +50,9 @@ def forward(self, x):
 data = TensorDict({"x": torch.rand(10)}, [])
 module(data)
 print(actor(data))
+
+
+# TODO:
+# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
+# 2. Must multi-head require an action_key to be a list of keys (I guess so)
+# 3. Using maps in the Actor
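TODO item 1 above contrasts two layouts for composite action keys. With CompositeDistribution's name_map (used in examples/agents/composite_ppo.py below), either layout is just a choice of nested output keys. A hypothetical sketch of the two options (the "head0"/"head1" names are illustrative, not part of this commit):

# Hypothetical name_map layouts for the TODO above; only the nested key names differ.
# Option A: all heads grouped under a shared "action" sub-tensordict.
name_map_option_a = {"head0": ("action", "action0"), "head1": ("action", "action1")}
# Option B: one sub-tensordict per agent, as used in composite_ppo.py below.
name_map_option_b = {"head0": ("agent0", "action"), "head1": ("agent1", "action")}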

examples/agents/composite_ppo.py (+203, new file)
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head agent and PPO loss
=============================

This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
object containing the three distributions.

The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
creates a distribution from these parameters, and samples from the distribution to output multiple actions.

The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.

Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
if not specified.

In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.

"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

make_params = Mod(
    lambda: (
        torch.ones(4),
        torch.ones(4),
        torch.ones(4, 2),
        torch.ones(4, 2),
        torch.ones(4, 10) / 10,
        torch.zeros(4, 10),
        torch.ones(4, 10),
    ),
    in_keys=[],
    out_keys=[
        ("params", "gamma", "concentration"),
        ("params", "gamma", "rate"),
        ("params", "Kumaraswamy", "concentration0"),
        ("params", "Kumaraswamy", "concentration1"),
        ("params", "mixture", "logits"),
        ("params", "mixture", "loc"),
        ("params", "mixture", "scale"),
    ],
)


def mixture_constructor(logits, loc, scale):
    return d.MixtureSameFamily(
        d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
    )


# =============================================================================
# Example 0: aggregate_probabilities=None (default) ===========================

dist_constructor = functools.partial(
    CompositeDistribution,
    distribution_map={
        "gamma": d.Gamma,
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
    name_map={
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    },
    aggregate_probabilities=None,
)


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("0. result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob")

# We can also get the log-prob from the policy directly
log_prob = policy.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob (from policy)")

# Build a dummy value operator
value_operator = Seq(
    Wrap(
        lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
        out_keys=["state_value"],
    )
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
    "next",
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("0. ", loss_cls, loss_vals)


# ===================================================================
# Example 1: aggregate_probabilities=True ===========================

dist_constructor.keywords["aggregate_probabilities"] = True

td = policy(TensorDict(batch_size=[4]))
print("1. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
    # there is only one.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("1. ", loss_cls, loss_vals)


# ===================================================================
# Example 2: aggregate_probabilities=False ===========================

dist_constructor.keywords["aggregate_probabilities"] = False

td = policy(TensorDict(batch_size=[4]))
print("2. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("2. ", loss_cls, loss_vals)
