Added adaptive learning rate feature. #6180

Status: Draft · wants to merge 2 commits into base: develop
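Summary of the change: the PPO trainer gains an adaptive learning-rate schedule. The policy's Gaussian means (mus) and standard deviations (sigmas) are stored in the trajectory buffer next to the action log probs, and at update time ModelUtils.adaptive_decay compares the current action distribution against the stored one. The per-sample quantity computed in the code (with a small epsilon inside the log for numerical stability) is the closed-form KL divergence between diagonal Gaussians, summed over action dimensions:

\mathrm{KL}\big(\pi_{\text{old}} \,\|\, \pi\big) = \sum_i \left[ \log\frac{\sigma_i}{\sigma_i^{\text{old}}} + \frac{(\sigma_i^{\text{old}})^2 + (\mu_i^{\text{old}} - \mu_i)^2}{2\,\sigma_i^2} - \frac{1}{2} \right]

When the batch-mean KL exceeds 2 * desired_lr_kl, the learning rate is divided by 1.5 (floored at lr_min); when it falls below desired_lr_kl / 2, it is multiplied by 1.5 (capped at lr_max). The schedule is selected via the new ScheduleType.ADAPTIVE value and tuned with the new desired_lr_kl, lr_min and lr_max hyperparameters.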
26 changes: 25 additions & 1 deletion ml-agents/mlagents/trainers/agent_processor.py
@@ -27,7 +27,7 @@
GlobalAgentId,
GlobalGroupId,
)
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple
from mlagents.trainers.torch_entities.utils import ModelUtils

T = TypeVar("T")
@@ -251,6 +251,28 @@ def _process_step(
except KeyError:
log_probs_tuple = LogProbsTuple.empty_log_probs()

try:
stored_action_mus = stored_take_action_outputs["mus"]
if not isinstance(stored_action_mus, MusTuple):
stored_action_mus = stored_action_mus.to_mus_tuple()
mus_tuple = MusTuple(
continuous=stored_action_mus.continuous[idx],
discrete=stored_action_mus.discrete[idx],
)
except KeyError:
mus_tuple = MusTuple.empty_mus()

try:
stored_action_sigmas = stored_take_action_outputs["sigmas"]
if not isinstance(stored_action_sigmas, SigmasTuple):
stored_action_sigmas = stored_action_sigmas.to_sigmas_tuple()
sigmas_tuple = SigmasTuple(
continuous=stored_action_sigmas.continuous[idx],
discrete=stored_action_sigmas.discrete[idx],
)
except KeyError:
sigmas_tuple = SigmasTuple.empty_sigmas()

action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]

@@ -266,6 +288,8 @@ def _process_step(
done=done,
action=action_tuple,
action_probs=log_probs_tuple,
action_mus=mus_tuple,
action_sigmas=sigmas_tuple,
action_mask=action_mask,
prev_action=prev_action,
interrupted=interrupted,
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/buffer.py
@@ -27,6 +27,10 @@ class BufferKey(enum.Enum):
CONTINUOUS_ACTION = "continuous_action"
NEXT_CONT_ACTION = "next_continuous_action"
CONTINUOUS_LOG_PROBS = "continuous_log_probs"
CONTINUOUS_MUS = "continuous_mus"
DISCRETE_MUS = "discrete_mus"
CONTINUOUS_SIGMAS = "continuous_sigmas"
DISCRETE_SIGMAS = "discrete_sigmas"
DISCRETE_ACTION = "discrete_action"
NEXT_DISC_ACTION = "next_discrete_action"
DISCRETE_LOG_PROBS = "discrete_log_probs"
26 changes: 22 additions & 4 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
@@ -15,7 +15,7 @@
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil

@@ -66,8 +66,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.hyperparameters.lr_min,
self.trainer_settings.max_steps,
self.hyperparameters.desired_lr_kl,
self.hyperparameters.lr_max,
)
self.decay_epsilon = ModelUtils.DecayedValue(
self.hyperparameters.epsilon_schedule,
@@ -92,6 +94,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):

self.stream_names = list(self.reward_signals.keys())

self.loss = torch.zeros(1, device=default_device())

self.last_actions = None

@property
def critic(self):
return self._critic
@@ -153,13 +159,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

log_probs = run_out["log_probs"]
entropy = run_out["entropy"]
mus = run_out["mus"]
sigmas = run_out["sigmas"]

values, _ = self.critic.critic_pass(
current_obs,
memories=value_memories,
sequence_length=self.policy.sequence_length,
)
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
old_mus = ActionMus.from_buffer(batch).flatten()
old_sigmas = ActionSigmas.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
value_loss = ModelUtils.trust_region_value_loss(
@@ -172,16 +182,22 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
loss_masks,
decay_eps,
)
loss = (
self.loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)

# adaptive learning rate
if self.hyperparameters.learning_rate_schedule == ScheduleType.ADAPTIVE:
decay_lr = self.decay_learning_rate.get_value(
self.policy.get_current_step(), mus, old_mus, sigmas, old_sigmas
)

# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.loss.backward()

self.optimizer.step()
update_stats = {
@@ -194,6 +210,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"Policy/Beta": decay_bet,
}

self.loss = torch.zeros(1, device=default_device())

return update_stats

# TODO move module update into TorchOptimizer for reward_provider
4 changes: 4 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -105,6 +105,7 @@ class EncoderType(Enum):
class ScheduleType(Enum):
CONSTANT = "constant"
LINEAR = "linear"
ADAPTIVE = "adaptive"
# TODO add support for lesson based scheduling
# LESSON = "lesson"

@@ -158,6 +159,9 @@ class HyperparamSettings:
batch_size: int = 1024
buffer_size: int = 10240
learning_rate: float = 3.0e-4
desired_lr_kl: float = 0.008
lr_min: float = 1.0e-10
lr_max: float = 1.0e-2
learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT
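
For reference, a minimal sketch of setting the new fields programmatically; it assumes the attrs-based HyperparamSettings accepts them as keyword arguments (the YAML trainer config would normally carry these values, but that path is not part of this diff):

from mlagents.trainers.settings import HyperparamSettings, ScheduleType

# Illustrative values only; the defaults below are the ones introduced in settings.py.
hp = HyperparamSettings(
    learning_rate=3.0e-4,
    learning_rate_schedule=ScheduleType.ADAPTIVE,
    desired_lr_kl=0.008,
    lr_min=1.0e-10,
    lr_max=1.0e-2,
)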


170 changes: 170 additions & 0 deletions ml-agents/mlagents/trainers/torch_entities/action_log_probs.py
@@ -7,6 +7,38 @@
from mlagents_envs.base_env import _ActionTupleBase


class MusTuple(_ActionTupleBase):
"""
An object whose fields correspond to the mean of action data of continuous and discrete
spaces. Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively. Note, this also holds when continuous or discrete size is
zero.
"""
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_mus() -> "MusTuple":
return MusTuple()


class SigmasTuple(_ActionTupleBase):
"""
An object whose fields correspond to the std of action data of continuous and discrete
spaces. Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively. Note, this also holds when continuous or discrete size is
zero.
"""
@property
def discrete_dtype(self) -> np.dtype:
return np.float32

@staticmethod
def empty_sigmas() -> "SigmasTuple":
return SigmasTuple()


class LogProbsTuple(_ActionTupleBase):
"""
An object whose fields correspond to the log probs of actions of different types.
@@ -116,3 +148,141 @@ def from_buffer(buff: AgentBuffer) -> "ActionLogProbs":
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionLogProbs(continuous, discrete, None)


class ActionMus(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete mus list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete mus of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_mus_tuple(self) -> MusTuple:
mus_tuple = MusTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
mus_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
mus_tuple.add_discrete(discrete)
return mus_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionMus as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all mus in ActionMus as a flattened tensor.
This is useful for algorithms like PPO which can treat all mus in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionMus":
"""
A static method that accesses continuous and discrete mus fields in an AgentBuffer
and constructs the corresponding ActionMus from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_MUS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_MUS])
if BufferKey.DISCRETE_MUS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_MUS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionMus(continuous, discrete, None)


class ActionSigmas(NamedTuple):
continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
all_discrete_list: Optional[List[torch.Tensor]]

@property
def discrete_tensor(self):
"""
Returns the discrete sigmas list as a stacked tensor
"""
return torch.stack(self.discrete_list, dim=-1)

@property
def all_discrete_tensor(self):
"""
Returns the discrete sigmas of each branch as a tensor
"""
return torch.cat(self.all_discrete_list, dim=1)

def to_sigmas_tuple(self) -> SigmasTuple:
sigmas_tuple = SigmasTuple()
if self.continuous_tensor is not None:
continuous = ModelUtils.to_numpy(self.continuous_tensor)
sigmas_tuple.add_continuous(continuous)
if self.discrete_list is not None:
discrete = ModelUtils.to_numpy(self.discrete_tensor)
sigmas_tuple.add_discrete(discrete)
return sigmas_tuple

def _to_tensor_list(self) -> List[torch.Tensor]:
"""
Returns the tensors in the ActionSigmas as a flat List of torch Tensors. This
is private and serves as a utility for self.flatten()
"""
tensor_list: List[torch.Tensor] = []
if self.continuous_tensor is not None:
tensor_list.append(self.continuous_tensor)
if self.discrete_list is not None:
tensor_list.append(self.discrete_tensor)
return tensor_list

def flatten(self) -> torch.Tensor:
"""
A utility method that returns all sigmas in ActionSigmas as a flattened tensor.
This is useful for algorithms like PPO which can treat all sigmas in the same way.
"""
return torch.cat(self._to_tensor_list(), dim=1)

@staticmethod
def from_buffer(buff: AgentBuffer) -> "ActionSigmas":
"""
A static method that accesses continuous and discrete sigmas fields in an AgentBuffer
and constructs the corresponding ActionSigmas from the retrieved np arrays.
"""
continuous: torch.Tensor = None
discrete: List[torch.Tensor] = None # type: ignore

if BufferKey.CONTINUOUS_SIGMAS in buff:
continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_SIGMAS])
if BufferKey.DISCRETE_SIGMAS in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff[BufferKey.DISCRETE_SIGMAS])
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionSigmas(continuous, discrete, None)
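
A rough sanity-check sketch of the intended round trip for the continuous means, assuming AgentBuffer rows are appended one agent-step at a time the same way trajectory.py does further down; shapes are illustrative:

from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.action_log_probs import ActionMus

# Means for 4 agents with 2 continuous actions each (no discrete branches).
mus = ActionMus(torch.zeros(4, 2), None, None)
mus_tuple = mus.to_mus_tuple()                  # numpy-backed MusTuple

buff = AgentBuffer()
for row in mus_tuple.continuous:                # one row per agent-step
    buff[BufferKey.CONTINUOUS_MUS].append(row)

flat = ActionMus.from_buffer(buff).flatten()    # back to a (4, 2) torch tensor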
28 changes: 24 additions & 4 deletions ml-agents/mlagents/trainers/torch_entities/action_model.py
@@ -7,7 +7,7 @@
MultiCategoricalDistribution,
)
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs, ActionMus, ActionSigmas
from mlagents_envs.base_env import ActionSpec


@@ -146,9 +146,25 @@ def _get_probs_and_entropy(
entropies = torch.cat(entropies_list, dim=1)
return action_log_probs, entropies

def _get_mus_and_sigmas(self, actions, dists):
continuous_mus: Optional[torch.Tensor] = None
continuous_sigmas: Optional[torch.Tensor] = None
discrete_mus: Optional[torch.Tensor] = None
discrete_sigmas: Optional[torch.Tensor] = None
all_discrete_mus: Optional[List[torch.Tensor]] = None
all_discrete_sigmas: Optional[List[torch.Tensor]] = None
if dists.continuous is not None:
continuous_mus = dists.continuous.mu()
continuous_sigmas = dists.continuous.sigma()
action_mus = ActionMus(continuous_mus, discrete_mus, all_discrete_mus)
action_sigmas = ActionSigmas(
continuous_sigmas, discrete_sigmas, all_discrete_sigmas
)
return action_mus, action_sigmas

def evaluate(
self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
) -> Tuple[ActionLogProbs, torch.Tensor]:
) -> Tuple[ActionLogProbs, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given actions and encoding from the network body, gets the distributions and
computes the log probabilities and entropies.
@@ -159,9 +175,12 @@ def evaluate(
"""
dists = self._get_dists(inputs, masks)
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# mus = dists.continuous.deterministic_sample()
mus = dists.continuous.mu()
sigmas = dists.continuous.sigma()
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return log_probs, entropy_sum
return log_probs, entropy_sum, mus, sigmas

def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
@@ -228,4 +247,5 @@ def forward(
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
return (actions, log_probs, entropy_sum)
mus, sigmas = self._get_mus_and_sigmas(actions, dists)
return (actions, log_probs, entropy_sum, mus, sigmas)
14 changes: 14 additions & 0 deletions ml-agents/mlagents/trainers/torch_entities/distributions.py
@@ -23,6 +23,14 @@ def deterministic_sample(self) -> torch.Tensor:
"""
pass

@abc.abstractmethod
def mu(self):
pass

@abc.abstractmethod
def sigma(self):
pass

@abc.abstractmethod
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
@@ -69,6 +77,12 @@ def sample(self):
def deterministic_sample(self):
return self.mean

def mu(self):
return self.mean

def sigma(self):
return self.std

def log_prob(self, value):
var = self.std**2
log_scale = torch.log(self.std + EPSILON)
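
A tiny usage sketch of the new accessors, assuming the GaussianDistInstance constructor keeps its (mean, std) signature from this file:

from mlagents.torch_utils import torch
from mlagents.trainers.torch_entities.distributions import GaussianDistInstance

dist = GaussianDistInstance(torch.zeros(1, 2), torch.ones(1, 2))
assert torch.equal(dist.mu(), dist.deterministic_sample())  # both return the mean
assert torch.equal(dist.sigma(), dist.std)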
39 changes: 37 additions & 2 deletions ml-agents/mlagents/trainers/torch_entities/networks.py
@@ -531,6 +531,15 @@ def get_action_and_stats(
"""
pass

def get_mus(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Dict[str, Any]:
pass

def get_stats(
self,
inputs: List[torch.Tensor],
@@ -637,7 +646,7 @@ def get_action_and_stats(
encoding, memories = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
action, log_probs, entropies, mus, sigmas = self.action_model(encoding, masks)
run_out = {}
# This is the clipped action which is not saved to the buffer
# but is exclusively sent to the environment.
@@ -646,9 +655,32 @@ def get_action_and_stats(
)
run_out["log_probs"] = log_probs
run_out["entropy"] = entropies
run_out["mus"] = mus
run_out["sigmas"] = sigmas

return action, run_out, memories

def get_mus(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Dict[str, Any]:
encoding, actor_mem_outs = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)

(
continuous_out,
discrete_out,
action_out_deprecated,
deterministic_continuous_out,
deterministic_discrete_out,
) = self.action_model.get_action_out(encoding, masks)
run_out = {"mus": deterministic_continuous_out}
return run_out

def get_stats(
self,
inputs: List[torch.Tensor],
@@ -661,10 +693,13 @@ def get_stats(
inputs, memories=memories, sequence_length=sequence_length
)

log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
log_probs, entropies, mus, sigmas = self.action_model.evaluate(encoding, masks, actions)
run_out = {}
run_out["log_probs"] = log_probs
run_out["entropy"] = entropies
run_out["mus"] = mus
run_out["sigmas"] = sigmas

return run_out

def forward(
68 changes: 64 additions & 4 deletions ml-agents/mlagents/trainers/torch_entities/utils.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict, Any
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
import numpy as np
@@ -67,26 +67,43 @@ def __init__(
initial_value: float,
min_value: float,
max_step: int,
desired_kl: Optional[float] = None,
max_value: Optional[float] = None,
):
"""
Object that represnets value of a parameter that should be decayed, assuming it is a function of
Object that represents value of a parameter that should be decayed, assuming it is a function of
global_step.
:param schedule: Type of learning rate schedule.
:param initial_value: Initial value before decay.
:param min_value: Decay value to this value by max_step.
:param max_step: The final step count where the return value should equal min_value.
:param global_step: The current step count.
:param desired_kl: Target KL.
:param max_value: Max value.
:return: The value.
"""
self.schedule = schedule
self.initial_value = initial_value
self.current_value = initial_value
self.min_value = min_value
self.max_step = max_step
self.desired_kl = desired_kl
self.max_value = max_value

def get_value(self, global_step: int) -> float:
def get_value(
self,
global_step: int,
mus: Any = None,
old_mus: Any = None,
sigmas: Any = None,
old_sigmas: Any = None,
) -> float:
"""
Get the value at a given global step.
:param global_step: Step count.
:param mus: Current action means (continuous actions).
:param old_mus: Action means stored when the actions were taken.
:param sigmas: Current action standard deviations.
:param old_sigmas: Stored action standard deviations.
:returns: Decayed value at this global step.
"""
if self.schedule == ScheduleType.CONSTANT:
@@ -95,6 +112,18 @@ def get_value(self, global_step: int) -> float:
return ModelUtils.polynomial_decay(
self.initial_value, self.min_value, self.max_step, global_step
)
elif self.schedule == ScheduleType.ADAPTIVE:
self.current_value = ModelUtils.adaptive_decay(
self.current_value,
self.desired_kl,
self.max_value,
self.min_value,
mus,
old_mus,
sigmas,
old_sigmas,
)
return self.current_value
else:
raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

@@ -121,6 +150,37 @@ def polynomial_decay(
) ** (power) + min_value
return decayed_value

@staticmethod
def adaptive_decay(
current_value: float,
desired_kl: float,
max_value: float,
min_value: float,
mus: Any = None,
old_mus: Any = None,
sigmas: Any = None,
old_sigmas: Any = None,
) -> float:
if mus is None or old_mus is None or sigmas is None or old_sigmas is None:
return current_value
decayed_value = current_value
kl_star = desired_kl
with torch.no_grad():
kl = torch.sum(
torch.log(sigmas / old_sigmas + 1.0e-5)
+ (torch.square(old_sigmas) + torch.square(old_mus - mus))
/ (2.0 * torch.square(sigmas))
- 0.5,
dim=-1,
)
kl_mean = kl.mean()
# print(f"KL: {kl_mean}")
if kl_mean > kl_star * 2.0:
decayed_value = max(min_value, decayed_value / 1.5)
elif kl_star / 2.0 > kl_mean > 0.0:
decayed_value = min(max_value, 1.5 * decayed_value)
return decayed_value

@staticmethod
def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
ENCODER_FUNCTION_BY_TYPE = {
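
A worked example of the adaptive rule in isolation, using dummy tensors (numbers are illustrative; it assumes adaptive_decay lands on ModelUtils as shown above):

from mlagents.torch_utils import torch
from mlagents.trainers.torch_entities.utils import ModelUtils

old_mus = torch.zeros(64, 2)
old_sigmas = torch.ones(64, 2)
new_mus = torch.full((64, 2), 0.5)   # the policy mean moved noticeably
new_sigmas = torch.ones(64, 2)

lr = ModelUtils.adaptive_decay(
    current_value=3.0e-4,
    desired_kl=0.008,
    max_value=1.0e-2,
    min_value=1.0e-10,
    mus=new_mus,
    old_mus=old_mus,
    sigmas=new_sigmas,
    old_sigmas=old_sigmas,
)
# Per-sample KL is ~0.25 (0.125 per action dimension, summed over 2 dims),
# well above 2 * 0.008, so the returned rate is 3.0e-4 / 1.5 = 2.0e-4.
print(lr)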
11 changes: 10 additions & 1 deletion ml-agents/mlagents/trainers/trajectory.py
@@ -8,7 +8,7 @@
BufferKey,
)
from mlagents_envs.base_env import ActionTuple
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple
from mlagents.trainers.torch_entities.action_log_probs import LogProbsTuple, MusTuple, SigmasTuple


class AgentStatus(NamedTuple):
@@ -35,6 +35,8 @@ class AgentExperience(NamedTuple):
done: bool
action: ActionTuple
action_probs: LogProbsTuple
action_mus: MusTuple
action_sigmas: SigmasTuple
action_mask: np.ndarray
prev_action: np.ndarray
interrupted: bool
@@ -267,6 +269,13 @@ def to_agentbuffer(self) -> AgentBuffer:
agent_buffer_trajectory[BufferKey.DISCRETE_LOG_PROBS].append(
exp.action_probs.discrete
)
agent_buffer_trajectory[BufferKey.CONTINUOUS_MUS].append(
exp.action_mus.continuous
)

agent_buffer_trajectory[BufferKey.CONTINUOUS_SIGMAS].append(
exp.action_sigmas.continuous
)

# Store action masks if necessary. Note that 1 means active, while
# in AgentExperience False means active.