@@ -525,10 +525,7 @@ def _log_weight(
525
525
self .actor_network
526
526
) if self .functional else contextlib .nullcontext ():
527
527
dist = self .actor_network .get_dist (tensordict )
528
- if isinstance (dist , CompositeDistribution ):
529
- is_composite = True
530
- else :
531
- is_composite = False
528
+ is_composite = isinstance (dist , CompositeDistribution )
532
529
533
530
# current log_prob of actions
534
531
if is_composite :
@@ -545,6 +542,32 @@ def _log_weight(
545
542
prev_log_prob = _maybe_get_or_select (
546
543
tensordict , self .tensor_keys .sample_log_prob
547
544
)
545
+ # TODO:
546
+ # # current log_prob of actions
547
+ # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
548
+ #
549
+ # is_composite = None
550
+ # if all(key in tensordict for key in self.actor_network.dist_params_keys):
551
+ # prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
552
+ # kwargs, is_composite = _get_composite_kwargs(prev_dist)
553
+ # if is_composite:
554
+ # prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
555
+ # else:
556
+ # prev_log_prob = prev_dist.log_prob(action, **kwargs)
557
+ # print('prev_log_prob', prev_log_prob)
558
+ # else:
559
+ # try:
560
+ # prev_log_prob = _maybe_get_or_select(
561
+ # tensordict, self.tensor_keys.sample_log_prob
562
+ # )
563
+ # except KeyError as err:
564
+ # raise _make_lp_get_error(self.tensor_keys, tensordict, err)
565
+
566
+ with self .actor_network_params .to_module (
567
+ self .actor_network
568
+ ) if self .functional else contextlib .nullcontext ():
569
+ current_dist = self .actor_network .get_dist (tensordict )
570
+
548
571
549
572
if prev_log_prob .requires_grad :
550
573
raise RuntimeError (
@@ -566,20 +589,27 @@ def _log_weight(
566
589
"the beginning of your script to get a proper composite log-prob." ,
567
590
category = UserWarning ,
568
591
)
569
- if (
570
- is_composite
571
- and not is_tensor_collection (prev_log_prob )
572
- and is_tensor_collection (log_prob )
573
- ):
574
- log_prob = _sum_td_features (log_prob )
575
- log_prob .view_as (prev_log_prob )
592
+ # TODO:
593
+ # if isinstance(action, torch.Tensor):
594
+ # log_prob = current_dist.log_prob(action)
595
+ # else:
596
+ # if is_composite is None:
597
+ # kwargs, is_composite = _get_composite_kwargs(current_dist)
598
+ # log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
599
+ if (
600
+ is_composite
601
+ and not is_tensor_collection (prev_log_prob )
602
+ and is_tensor_collection (log_prob )
603
+ ):
604
+ log_prob = _sum_td_features (log_prob )
605
+ log_prob .view_as (prev_log_prob )
576
606
577
607
log_weight = (log_prob - prev_log_prob ).unsqueeze (- 1 )
578
608
kl_approx = (prev_log_prob - log_prob ).unsqueeze (- 1 )
579
609
if is_tensor_collection (kl_approx ):
580
610
kl_approx = _sum_td_features (kl_approx )
581
611
582
- return log_weight , dist , kl_approx
612
+ return log_weight , current_dist , kl_approx
583
613
584
614
def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
585
615
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -655,6 +685,9 @@ def _cached_critic_network_params_detached(self):
655
685
@dispatch
656
686
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
657
687
tensordict = tensordict .clone (False )
688
+
689
+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
690
+
658
691
advantage = tensordict .get (self .tensor_keys .advantage , None )
659
692
if advantage is None :
660
693
self .value_estimator (
@@ -675,7 +708,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
675
708
)
676
709
advantage = _standardize (advantage , self .normalize_advantage_exclude_dims )
677
710
678
- log_weight , dist , kl_approx = self ._log_weight (tensordict )
679
711
if is_tensor_collection (log_weight ):
680
712
log_weight = _sum_td_features (log_weight )
681
713
log_weight = log_weight .view (advantage .shape )
0 commit comments