
Commit 216ff86

[WIP] Compute lp during loss execution
ghstack-source-id: afc425b
Pull Request resolved: #2688
1 parent: c291c61

File tree

2 files changed: +52 -22 lines

test/test_cost.py
torchrl/objectives/ppo.py


test/test_cost.py (+7, -9)
@@ -8177,18 +8177,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         if sample_log_prob_key is None:
@@ -8215,9 +8216,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
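
The test change above stops storing a random, unrelated sample_log_prob and instead draws the action from the same params_mean/params_scale that end up as loc/scale in the tensordict. Below is a minimal sketch of the consistency this sets up, using plain PyTorch with a diagonal Normal as a stand-in for the actor's distribution; the shapes, the epsilon on the scale, and the Normal head are assumptions for illustration, not TorchRL code:

import torch
from torch.distributions import Normal

# Hypothetical shapes mirroring the mock data above.
batch, T, action_dim = 2, 5, 4
device = "cpu"

# The action is now drawn from the same parameters the test stores as loc/scale ...
params_mean = torch.randn(batch, T, action_dim, device=device) / 10
params_scale = torch.rand(batch, T, action_dim, device=device) / 10 + 1e-4  # keep the scale strictly positive
action = (params_mean + params_scale * torch.randn(batch, T, action_dim, device=device)).clamp(-1, 1)

# ... so a distribution rebuilt from those parameters assigns it a meaningful
# log-probability, which is what a loss computing the log-prob at execution time needs.
prev_log_prob = Normal(params_mean, params_scale).log_prob(action).sum(-1)
print(prev_log_prob.shape)  # torch.Size([2, 5])

With the action generated from the stored parameters, a loss that rebuilds the distribution from loc/scale no longer depends on a pre-computed sample_log_prob entry, which is why that key is dropped from the mock TensorDict.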

torchrl/objectives/ppo.py (+45, -13)
@@ -525,10 +525,7 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-            if isinstance(dist, CompositeDistribution):
-                is_composite = True
-            else:
-                is_composite = False
+            is_composite = isinstance(dist, CompositeDistribution)
 
         # current log_prob of actions
         if is_composite:
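
The collapsed is_composite = isinstance(dist, CompositeDistribution) flag matters later because single-head and multi-head policies return differently shaped log-probs: a plain tensor in the first case, one entry per head in the second, which has to be reduced before it can be compared with a scalar sample_log_prob (that is what _sum_td_features does further down). A rough sketch of that distinction using plain torch.distributions and a dict as a stand-in for tensordict's CompositeDistribution; the keys and heads below are made up for illustration:

import torch
from torch.distributions import Normal, Categorical

# Two heads of a hypothetical composite policy.
heads = {
    "agent0_action": Normal(torch.zeros(8, 2), torch.ones(8, 2)),
    "agent1_action": Categorical(logits=torch.zeros(8, 5)),
}
sample = {k: d.sample() for k, d in heads.items()}

# A "composite" log-prob is a mapping of per-head log-probs, not a single tensor ...
per_head_lp = {
    "agent0_action": heads["agent0_action"].log_prob(sample["agent0_action"]).sum(-1),
    "agent1_action": heads["agent1_action"].log_prob(sample["agent1_action"]),
}

# ... so it must be reduced to one tensor before it can be subtracted from a
# scalar behaviour log-prob, which is the role _sum_td_features plays downstream.
total_lp = sum(per_head_lp.values())
print(total_lp.shape)  # torch.Size([8])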
@@ -545,6 +542,32 @@ def _log_weight(
         prev_log_prob = _maybe_get_or_select(
             tensordict, self.tensor_keys.sample_log_prob
         )
+        # TODO:
+        # # current log_prob of actions
+        # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+        #
+        # is_composite = None
+        # if all(key in tensordict for key in self.actor_network.dist_params_keys):
+        #     prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+        #     kwargs, is_composite = _get_composite_kwargs(prev_dist)
+        #     if is_composite:
+        #         prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+        #     else:
+        #         prev_log_prob = prev_dist.log_prob(action, **kwargs)
+        #     print('prev_log_prob', prev_log_prob)
+        # else:
+        #     try:
+        #         prev_log_prob = _maybe_get_or_select(
+        #             tensordict, self.tensor_keys.sample_log_prob
+        #         )
+        #     except KeyError as err:
+        #         raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            current_dist = self.actor_network.get_dist(tensordict)
+
 
         if prev_log_prob.requires_grad:
             raise RuntimeError(
@@ -566,20 +589,27 @@ def _log_weight(
                     "the beginning of your script to get a proper composite log-prob.",
                     category=UserWarning,
                 )
-                if (
-                    is_composite
-                    and not is_tensor_collection(prev_log_prob)
-                    and is_tensor_collection(log_prob)
-                ):
-                    log_prob = _sum_td_features(log_prob)
-                    log_prob.view_as(prev_log_prob)
+        # TODO:
+        # if isinstance(action, torch.Tensor):
+        #     log_prob = current_dist.log_prob(action)
+        # else:
+        #     if is_composite is None:
+        #         kwargs, is_composite = _get_composite_kwargs(current_dist)
+        #     log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
+        if (
+            is_composite
+            and not is_tensor_collection(prev_log_prob)
+            and is_tensor_collection(log_prob)
+        ):
+            log_prob = _sum_td_features(log_prob)
+            log_prob.view_as(prev_log_prob)
 
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)
 
-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -655,6 +685,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -675,7 +708,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
             log_weight = log_weight.view(advantage.shape)
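
The last two hunks move the self._log_weight(tensordict) call to the top of forward, before the advantage is fetched, computed by the value estimator, or standardized. A schematic, runnable sketch of the new ordering; the callables are hypothetical stand-ins for self._log_weight and the value estimator, and the real method does considerably more:

import torch

def forward_sketch(td: dict, log_weight_fn, advantage_fn) -> dict:
    # 1) Log-weight first: the log-prob is now computed during loss execution.
    log_weight, dist, kl_approx = log_weight_fn(td)
    # 2) Advantage second, only if the input does not already carry one.
    advantage = td.get("advantage")
    if advantage is None:
        advantage = advantage_fn(td)
    # 3) Objective last: weight the advantage by the exponentiated log-weight.
    td["loss_objective"] = -(log_weight.exp() * advantage).mean()
    td["kl_approx"] = kl_approx.mean()
    return td

# Dummy usage
out = forward_sketch(
    {"advantage": torch.randn(8, 1)},
    log_weight_fn=lambda td: (torch.randn(8, 1), None, torch.randn(8, 1)),
    advantage_fn=lambda td: torch.randn(8, 1),
)
print(out["loss_objective"], out["kl_approx"])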
