diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 0a57416dad6..bddf5aeecce 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -2150,17 +2150,18 @@ def _step_proc_data(self, next_tensordict_out):
             else next_tensordict_out.shape
         )
         for reward_key in self.reward_keys:
-            reward = next_tensordict_out.get(reward_key)
             expected_reward_shape = torch.Size(
                 [
                     *leading_batch_size,
                     *self.output_spec["full_reward_spec"][reward_key].shape,
                 ]
             )
-            actual_reward_shape = reward.shape
-            if actual_reward_shape != expected_reward_shape:
-                reward = reward.view(expected_reward_shape)
-                next_tensordict_out.set(reward_key, reward)
+            if all(s > 0 for s in expected_reward_shape):
+                reward = next_tensordict_out.get(reward_key, as_nested=True)
+                actual_reward_shape = reward.shape
+                if actual_reward_shape != expected_reward_shape:
+                    reward = reward.view(expected_reward_shape)
+                    next_tensordict_out.set(reward_key, reward)
 
         self._complete_done(self.full_done_spec, next_tensordict_out)
 
diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index da0673ccb41..6b406b1a11f 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -12,11 +12,7 @@
 
 import torch
 from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import (
-    ProbabilisticTensorDictModule,
-    ProbabilisticTensorDictSequential,
-    TensorDictParams,
-)
+from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torch import nn
 
@@ -657,9 +653,11 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        # TODO: adapt this to new API
         log_prob_key: NestedKey = "sample_log_prob",
-        action_key: NestedKey = "action",
-        functional: bool = True,
+        action_key: NestedKey | None = None,
+        functional: bool | None = None,
+        device: torch.device | None = None,
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -676,15 +674,16 @@ def __init__(
         if len(self.in_keys) != 1 or len(self.out_keys) != 1:
             raise ValueError(
                 f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
             )
-        # for convenience, convert out_keys to tuples
-        self._out_keys = [
-            out_key if isinstance(out_key, tuple) else (out_key,)
-            for out_key in self._out_keys
-        ]
+        self._out_keys = [unravel_key(out_key) for out_key in self._out_keys]
 
         # update the in_keys for dispatch etc
         self.in_keys = self.in_keys + actor.in_keys
+        self.in_keys = [unravel_key(in_key) for in_key in self.in_keys]
+
+        if functional is None:
+            from torchrl.modules.llm import CategoricalSequential
+            functional = not isinstance(actor, CategoricalSequential)
         self.functional = functional
         # check that the model has parameters
         if functional:
@@ -721,6 +720,7 @@ def _make_detached_param(x):
 
         # self._buffers["actor_params"] = params.clone().detach()
 
+        self.device = device
         self.action_key = action_key
 
         # find the sample log-prob key
@@ -736,55 +736,102 @@ def find_sample_log_prob(module):
         coef = torch.as_tensor(coef)
         self.register_buffer("coef", coef)
 
+    def set_container(self, container: Transform | EnvBase) -> None:
+        result = super().set_container(container)
+        if self.action_key is None:
+            parent = getattr(self, "parent", None)
+            if parent is not None:
+                action_keys = parent.action_keys
+                if len(action_keys) != 1:
+                    raise ValueError(
+                        f"More than one action_key found. Please pass the `action_key` argument directly to {type(self).__name__}."
+                    )
+                action_key = action_keys[0]
+                self.action_key = action_key
+        return result
+
     def _reset(
         self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
     ) -> TensorDictBase:
         with _set_missing_tolerance(self, True):
-            tensordict_reset = self._call(tensordict_reset)
+            tensordict_reset = self._step(tensordict_reset, tensordict_reset)
         return tensordict_reset
 
-    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get(self.action_key, None)
+        action_key = self.action_key
+        if action_key is None:
+            raise ValueError(
+                f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
+                f"or pass the action_key argument directly to the {type(self).__name__} constructor."
+            )
+        action = tensordict.get(action_key, None)
         if action is None:
+            if not self.missing_tolerance:
+                raise RuntimeError(
+                    f"Action with key {action_key} not found in data {tensordict}"
+                )
             # being called after reset or without action, skipping
-            if self.out_keys[0] != ("reward",) and self.parent is not None:
+            if self.out_keys[0] != "reward" and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
 
+        if self.device is not None:
+            action = action.to(self.device)
+
         if self.functional:
             with self.frozen_params.to_module(self.functional_actor):
-                dist = self.functional_actor.get_dist(next_tensordict.clone(False))
+                dist = self.functional_actor.get_dist(tensordict.clone(False))
             # get the log_prob given the original model
             log_prob = dist.log_prob(action)
-        elif isinstance(
-            self.functional_actor,
-            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
-        ):
-            # with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.copy())
-            # get the log_prob given the original model
-            log_prob = dist.log_prob(action)
-        else:
-            log_prob = self.functional_actor(next_tensordict.copy()).get(
-                self.sample_log_prob_key
+        elif hasattr(self.functional_actor, "log_prob"):
+            if self.device is not None:
+                td_device = tensordict.to(self.device)
+            else:
+                td_device = tensordict
+            log_prob = self.functional_actor.log_prob(
+                td_device, as_nested_tensor=True, layout=torch.strided
             )
+        else:
+            log_prob = self.functional_actor(tensordict).get(self.sample_log_prob_key)
 
         reward_key = self.in_keys[0]
-        reward = next_tensordict.get("next").get(reward_key)
-        curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
+        reward = next_tensordict.get(reward_key)
+        curr_log_prob = tensordict.get(
+            self.sample_log_prob_key, as_nested_tensor=True, layout=torch.strided
+        )
+        log_prob = log_prob.to(curr_log_prob.device)
+        curr_log_prob = curr_log_prob.unsqueeze(-1)
+        # log_prob = log_prob.unsqueeze(-1)
+
         # we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
-        kl = (curr_log_prob - log_prob).view_as(reward)
-        next_tensordict.set(("next", *self.out_keys[0]), reward + self.coef * kl)
+        if not reward.is_nested and log_prob.is_nested:
+            reward = torch.nested.nested_tensor(
+                [rew.expand(lp.shape) for rew, lp in zip(reward, log_prob)],
+                layout=torch.strided,
+            )
+        # compare per-sample shapes, since nested tensors expose no global shape
+        if log_prob[0].shape != curr_log_prob[0].shape:
+            raise ValueError(
+                f"the log-probability tensor shapes must match, got curr_log_prob.shape={curr_log_prob[0].shape} and "
+                f"log_prob.shape={log_prob[0].shape}."
+            )
+        if reward is not None and reward.ndim != curr_log_prob.ndim:
+            raise ValueError(
                "The number of dimensions of reward must be the same as the number of dimensions of the KL "
+                f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
+            )
+        kl = curr_log_prob - log_prob
+        if reward is None:
+            reward = 0
+        next_tensordict.set(self.out_keys[0], reward + self.coef * kl)
         return next_tensordict
 
-    def _step(
-        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
-    ) -> TensorDictBase:
-        with tensordict.unlock_():
-            return self._call(tensordict.set("next", next_tensordict)).pop("next")
-
-    forward = _call
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        next_td = tensordict.pop("next")
+        next_td = self._step(tensordict, next_td)
+        return tensordict.set("next", next_td)
 
     def transform_output_spec(self, output_spec: Composite) -> Composite:
         in_key = unravel_key(self.in_keys[0])
@@ -800,24 +847,38 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
                 reward_key = "reward"
             else:
                 raise KeyError("Couldn't find the reward key.")
-
+            shape = output_spec["full_reward_spec"][reward_key].shape
+            shape = (*shape[:-2], -1, 1)
             reward_spec = Unbounded(
                 device=output_spec.device,
-                shape=output_spec["full_reward_spec"][reward_key].shape,
+                shape=shape,
             )
             output_spec["full_reward_spec"] = Composite(
                 {reward_key: reward_spec},
                 shape=output_spec["full_reward_spec"].shape,
             )
         elif in_key == "reward":
+            # TODO: we should at least allow to make this a component of the reward specs, to avoid a call during reset
             parent = self.parent
-            reward_spec = output_spec["full_reward_spec"][parent.reward_key].clone()
+            reward_spec = output_spec["full_reward_spec"][parent.reward_key]
+
+            shape = reward_spec.shape
+            shape = (*shape[:-2], -1, 1)
+            reward_spec = reward_spec.clone()
+            reward_spec.shape = torch.Size(shape)
+
             # then we need to populate the output keys
             observation_spec = output_spec["full_observation_spec"]
             observation_spec[out_key] = reward_spec
         else:
             observation_spec = output_spec["full_observation_spec"]
-            reward_spec = observation_spec[in_key].clone()
+            reward_spec = observation_spec[in_key]
+
+            shape = reward_spec.shape
+            shape = (*shape[:-2], -1, 1)
+            reward_spec = reward_spec.clone()
+            reward_spec.shape = torch.Size(shape)
+
             # then we need to populate the output keys
             observation_spec[out_key] = reward_spec
         return output_spec
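
Note on the nested-tensor KL path: for readers less familiar with it, below is a minimal, self-contained sketch of the per-token reward shaping that the new _step performs. All names and shapes are illustrative, not part of the patch; only the estimator itself (log_p(x) - log_q(x) for x ~ p) and the strided nested layout mirror the patched code.

    import torch

    # Two sampled responses of different lengths (5 and 3 tokens), so the
    # per-token log-probs live in strided nested tensors, matching the
    # layout the patch requests via layout=torch.strided.
    curr_log_prob = torch.nested.nested_tensor(
        [torch.randn(5), torch.randn(3)], layout=torch.strided
    )  # log p(x): sampling (current) model
    ref_log_prob = torch.nested.nested_tensor(
        [torch.randn(5), torch.randn(3)], layout=torch.strided
    )  # log q(x): frozen reference model

    # A per-trajectory scalar reward, expanded to one value per token the
    # same way _step expands a non-nested reward against a nested log_prob.
    reward = torch.nested.nested_tensor(
        [torch.full((5,), 1.0), torch.zeros(3)], layout=torch.strided
    )

    coef = 0.1
    # unbiased single-sample KL estimator: log p(x) - log q(x), x ~ p(x)
    kl = curr_log_prob - ref_log_prob
    shaped_reward = reward + coef * kl  # one shaped reward per token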
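
Usage sketch (hedged): assuming the transform being patched here is KLRewardTransform (consistent with the coef/action_key/functional attributes in the hunks), the snippet below shows how the reworked constructor might be used. ref_actor and base_env are placeholders, not defined by this patch; the point is that action_key may now be omitted, in which case set_container recovers it from the parent env's single action key, and that the new device argument optionally stages the log-prob computation.

    import torch
    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms.llm import KLRewardTransform

    # Placeholders: ref_actor is assumed to be a frozen reference policy
    # (e.g. a CategoricalSequential wrapping a language model), base_env
    # any TorchRL env that writes an action and a reward.
    transform = KLRewardTransform(
        ref_actor,
        coef=0.1,
        # action_key=None (default): resolved from the parent env by set_container
        # functional=None (default): inferred from the actor's type
        device=torch.device("cuda") if torch.cuda.is_available() else None,
    )
    env = TransformedEnv(base_env, transform)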