[FEATURE] Discrete IQL #404

Open
wants to merge 3 commits into base: master
178 changes: 175 additions & 3 deletions d3rlpy/algos/qlearning/iql.py
@@ -3,18 +3,29 @@
from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
create_categorical_policy,
create_continuous_q_function,
create_discrete_q_function,
create_normal_policy,
create_value_function,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import MeanQFunctionFactory
from ...models.q_functions import (
MeanQFunctionFactory,
QFunctionFactory,
make_q_func_field,
)
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.iql_impl import IQLImpl, IQLModules
from .torch.iql_impl import (
DiscreteIQLImpl,
DiscreteIQLModules,
IQLImpl,
IQLModules,
)

__all__ = ["IQLConfig", "IQL"]
__all__ = ["IQLConfig", "IQL", "DiscreteIQLConfig", "DiscreteIQL"]


@dataclasses.dataclass()
@@ -176,4 +187,165 @@ def get_action_type(self) -> ActionSpace:
return ActionSpace.CONTINUOUS


@dataclasses.dataclass()
class DiscreteIQLConfig(LearnableConfig):
r"""Implicit Q-Learning algorithm.

IQL is an offline RL algorithm that avoids ever querying values of unseen
actions while still being able to perform multi-step dynamic programming
updates.

There are three functions to train in IQL. First the state-value function
is trained via expectile regression.

.. math::

L_V(\psi) = \mathbb{E}_{(s, a) \sim D}
[L_2^\tau (Q_\theta (s, a) - V_\psi (s))]

where :math:`L_2^\tau (u) = |\tau - \mathbb{1}(u < 0)|u^2`.

The Q-function is trained with the state-value function to avoid querying
the actions.

.. math::

L_Q(\theta) = \mathbb{E}_{(s, a, r, s') \sim D}
[(r + \gamma V_\psi(s') - Q_\theta(s, a))^2]

Finally, the policy function is trained by advantage-weighted regression.
In contrast to the continuous `IQL`, a categorical policy is used here.

.. math::

L_\pi (\phi) = \mathbb{E}_{(s, a) \sim D}
[\exp(\beta (Q_\theta(s, a) - V_\psi(s))) \log \pi_\phi(a|s)]

References:
* `Kostrikov et al., Offline Reinforcement Learning with Implicit
Q-Learning. <https://arxiv.org/abs/2110.06169>`_

Args:
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
actor_learning_rate (float): Learning rate for policy function.
critic_learning_rate (float): Learning rate for Q functions.
actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the actor.
critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
Optimizer factory for the critic.
actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the actor.
encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the critic.
value_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
Encoder factory for the value function.
batch_size (int): Mini-batch size.
gamma (float): Discount factor.
tau (float): Target network synchronization coefficient.
n_critics (int): Number of Q functions for ensemble.
expectile (float): Expectile value for value function training.
weight_temp (float): Inverse temperature value represented as
:math:`\beta`.
max_weight (float): Maximum advantage weight value to clip.
"""

actor_learning_rate: float = 3e-4
critic_learning_rate: float = 3e-4

q_func_factory: QFunctionFactory = make_q_func_field()
encoder_factory: EncoderFactory = make_encoder_field()
value_encoder_factory: EncoderFactory = make_encoder_field()
critic_optim_factory: OptimizerFactory = make_optimizer_field()

actor_encoder_factory: EncoderFactory = make_encoder_field()
actor_optim_factory: OptimizerFactory = make_optimizer_field()

batch_size: int = 256
gamma: float = 0.99
tau: float = 0.005
n_critics: int = 2
expectile: float = 0.7
weight_temp: float = 3.0
max_weight: float = 100.0

def create(self, device: DeviceArg = False) -> "DiscreteIQL":
return DiscreteIQL(self, device)

@staticmethod
def get_type() -> str:
return "discrete_iql"


class DiscreteIQL(QLearningAlgoBase[DiscreteIQLImpl, DiscreteIQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
policy = create_categorical_policy(
observation_shape,
action_size,
self._config.actor_encoder_factory,
self._device,
)
q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
Owner review comment:
Can you remove q_func_factory from config? Instead, please use MeanQFunctionFactory just like the continuous IQL?
n_ensembles=self._config.n_critics,
device=self._device,
)
targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
self._config.encoder_factory,
self._config.q_func_factory,
n_ensembles=self._config.n_critics,
device=self._device,
)
value_func = create_value_function(
observation_shape,
self._config.value_encoder_factory,
device=self._device,
)

q_func_params = list(q_funcs.named_modules())
v_func_params = list(value_func.named_modules())
critic_optim = self._config.critic_optim_factory.create(
q_func_params + v_func_params, lr=self._config.critic_learning_rate
)
actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
)

modules = DiscreteIQLModules(
policy=policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
value_func=value_func,
actor_optim=actor_optim,
critic_optim=critic_optim,
)

self._impl = DiscreteIQLImpl(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
q_func_forwarder=q_func_forwarder,
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
expectile=self._config.expectile,
weight_temp=self._config.weight_temp,
max_weight=self._config.max_weight,
device=self._device,
)

def get_action_type(self) -> ActionSpace:
return ActionSpace.DISCRETE


register_learnable(IQLConfig)
register_learnable(DiscreteIQLConfig)
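
For readers cross-checking the docstring math above against code, here is a standalone PyTorch sketch of the three losses (expectile regression for the value function, TD regression for the Q-function, and the clipped advantage-weighted categorical policy loss). It is not part of this PR's diff; tensor shapes and helper names are illustrative assumptions only.

# --- Illustrative sketch only, not part of this PR's diff ---
# Assumes q, v, reward, terminal are 1-D tensors of shape (batch,) and
# logits has shape (batch, action_size); actions are int64 indices.
import torch
import torch.nn.functional as F


def expectile_loss(q: torch.Tensor, v: torch.Tensor, expectile: float = 0.7) -> torch.Tensor:
    # L_2^tau(u) = |tau - 1(u < 0)| * u^2 with u = Q(s, a) - V(s)
    diff = q - v
    weight = torch.abs(expectile - (diff < 0).float())
    return (weight * diff.pow(2)).mean()


def q_td_loss(q: torch.Tensor, reward: torch.Tensor, v_next: torch.Tensor,
              terminal: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    # (r + gamma * V(s') - Q(s, a))^2, querying only the value function at s'
    target = reward + gamma * (1.0 - terminal) * v_next.detach()
    return F.mse_loss(q, target)


def awr_policy_loss(logits: torch.Tensor, actions: torch.Tensor, q: torch.Tensor,
                    v: torch.Tensor, weight_temp: float = 3.0,
                    max_weight: float = 100.0) -> torch.Tensor:
    # exp(beta * (Q(s, a) - V(s))) clipped at max_weight, times log pi(a|s) of a
    # categorical policy; the weight is detached so only the policy gets gradients
    adv_weight = (weight_temp * (q - v)).exp().clamp(max=max_weight).detach()
    log_probs = torch.distributions.Categorical(logits=logits).log_prob(actions)
    return -(adv_weight * log_probs).mean()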
114 changes: 113 additions & 1 deletion d3rlpy/algos/qlearning/torch/ddpg_impl.py
@@ -10,12 +10,13 @@
from ....models.torch import (
ActionOutput,
ContinuousEnsembleQFunctionForwarder,
DiscreteEnsembleQFunctionForwarder,
Policy,
)
from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync
from ....types import Shape, TorchObservation
from ..base import QLearningAlgoImplBase
from .utility import ContinuousQFunctionMixin
from .utility import ContinuousQFunctionMixin, DiscreteQFunctionMixin

__all__ = [
"DDPGImpl",
@@ -157,6 +158,117 @@ def q_function_optim(self) -> Optimizer:
return self._modules.critic_optim


class DiscreteDDPGBaseImpl(
DiscreteQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta
):
_modules: DDPGBaseModules
_gamma: float
_tau: float
_q_func_forwarder: DiscreteEnsembleQFunctionForwarder
_targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder

def __init__(
self,
observation_shape: Shape,
action_size: int,
modules: DDPGBaseModules,
q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
gamma: float,
tau: float,
device: str,
):
super().__init__(
observation_shape=observation_shape,
action_size=action_size,
modules=modules,
device=device,
)
self._gamma = gamma
self._tau = tau
self._q_func_forwarder = q_func_forwarder
self._targ_q_func_forwarder = targ_q_func_forwarder
hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)

def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]:
self._modules.critic_optim.zero_grad()
q_tpn = self.compute_target(batch)
loss = self.compute_critic_loss(batch, q_tpn)
loss.critic_loss.backward()
self._modules.critic_optim.step()
return asdict_as_float(loss)

def compute_critic_loss(
self, batch: TorchMiniBatch, q_tpn: torch.Tensor
) -> DDPGBaseCriticLoss:
loss = self._q_func_forwarder.compute_error(
observations=batch.observations,
actions=batch.actions,
rewards=batch.rewards,
target=q_tpn,
terminals=batch.terminals,
gamma=self._gamma**batch.intervals,
)
return DDPGBaseCriticLoss(loss)

def update_actor(
self, batch: TorchMiniBatch, action: ActionOutput
) -> Dict[str, float]:
# Q function should be inference mode for stability
self._modules.q_funcs.eval()
self._modules.actor_optim.zero_grad()
loss = self.compute_actor_loss(batch, None)
loss.actor_loss.backward()
self._modules.actor_optim.step()
return asdict_as_float(loss)

def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
metrics = {}
action = self._modules.policy(batch.observations)
metrics.update(self.update_critic(batch))
metrics.update(self.update_actor(batch, action))
self.update_critic_target()
return metrics

@abstractmethod
def compute_actor_loss(
self, batch: TorchMiniBatch, action: None
) -> DDPGBaseActorLoss:
pass

@abstractmethod
def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
pass

def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return torch.argmax(self._modules.policy(x).probs, dim=1)

@abstractmethod
def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
pass

def update_critic_target(self) -> None:
soft_sync(self._modules.targ_q_funcs, self._modules.q_funcs, self._tau)

@property
def policy(self) -> Policy:
return self._modules.policy

@property
def policy_optim(self) -> Optimizer:
return self._modules.actor_optim

@property
def q_function(self) -> nn.ModuleList:
return self._modules.q_funcs

@property
def q_function_optim(self) -> Optimizer:
return self._modules.critic_optim


@dataclasses.dataclass(frozen=True)
class DDPGModules(DDPGBaseModules):
targ_policy: Policy
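
If this PR lands, training would presumably follow the same pattern as the existing discrete algorithms, with fit() driving the update_critic/update_actor/target-sync cycle shown above. A minimal usage sketch, assuming DiscreteIQLConfig ends up exported under d3rlpy.algos and using the standard CartPole dataset helper (both assumptions, not shown in this diff):

# --- Minimal usage sketch, assuming this PR is merged; not part of the diff ---
import d3rlpy

# get_cartpole() returns a discrete-action dataset and the matching gym environment
dataset, env = d3rlpy.datasets.get_cartpole()

iql = d3rlpy.algos.DiscreteIQLConfig(
    expectile=0.7,     # expectile for value-function regression
    weight_temp=3.0,   # inverse temperature beta for the advantage weight
    batch_size=256,
).create(device=False)  # set device="cuda:0" to train on GPU

iql.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)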