diff --git a/.gitignore b/.gitignore index d0f4500..0e05091 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ marltoolbox/examples/Tutorial_*.py marltoolbox/experiments/Tutorial_*.py api_key_wandb wandb/ +requirements.txt diff --git a/marltoolbox/algos/amTFT/__init__.py b/marltoolbox/algos/amTFT/__init__.py index a8d5f52..d3d651e 100644 --- a/marltoolbox/algos/amTFT/__init__.py +++ b/marltoolbox/algos/amTFT/__init__.py @@ -19,3 +19,21 @@ AmTFTRolloutsTorchPolicy, ) from marltoolbox.algos.amTFT.train_helper import train_amtft + + +__all__ = [ + "train_amtft", + "AmTFTRolloutsTorchPolicy", + "amTFTQValuesTorchPolicy", + "Level1amTFTExploiterTorchPolicy", + "observation_fn", + "AmTFTCallbacks", + "WORKING_STATES", + "WORKING_STATES_IN_EVALUATION", + "AmTFTReferenceClass", + "DEFAULT_NESTED_POLICY_COOP", + "DEFAULT_NESTED_POLICY_SELFISH", + "PLOT_ASSEMBLAGE_TAGS", + "PLOT_KEYS", + "DEFAULT_CONFIG", +] diff --git a/marltoolbox/algos/amTFT/base.py b/marltoolbox/algos/amTFT/base.py index 6c383a3..a45697f 100644 --- a/marltoolbox/algos/amTFT/base.py +++ b/marltoolbox/algos/amTFT/base.py @@ -4,20 +4,23 @@ from ray.rllib.utils import merge_dicts from marltoolbox.algos import hierarchical, augmented_dqn -from marltoolbox.utils import \ - postprocessing, miscellaneous +from marltoolbox.utils import postprocessing, miscellaneous logger = logging.getLogger(__name__) APPROXIMATION_METHOD_Q_VALUE = "amTFT_use_Q_net" APPROXIMATION_METHOD_ROLLOUTS = "amTFT_use_rollout" -APPROXIMATION_METHODS = (APPROXIMATION_METHOD_Q_VALUE, - APPROXIMATION_METHOD_ROLLOUTS) -WORKING_STATES = ("train_coop", - "train_selfish", - "eval_amtft", - "eval_naive_selfish", - "eval_naive_coop") +APPROXIMATION_METHODS = ( + APPROXIMATION_METHOD_Q_VALUE, + APPROXIMATION_METHOD_ROLLOUTS, +) +WORKING_STATES = ( + "train_coop", + "train_selfish", + "eval_amtft", + "eval_naive_selfish", + "eval_naive_coop", +) WORKING_STATES_IN_EVALUATION = WORKING_STATES[2:] OWN_COOP_POLICY_IDX = 0 @@ -29,8 +32,9 @@ DEFAULT_NESTED_POLICY_COOP = DEFAULT_NESTED_POLICY_SELFISH.with_updates( postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( postprocessing.welfares_postprocessing_fn( - add_utilitarian_welfare=True, ), - postprocess_nstep_and_prio + add_utilitarian_welfare=True, + ), + postprocess_nstep_and_prio, ) ) @@ -44,31 +48,35 @@ "rollout_length": 40, "n_rollout_replicas": 20, "last_k": 1, + "punish_instead_of_selfish": False, # TODO use log level of RLLib instead of mine "verbose": 1, "auto_load_checkpoint": True, - # To configure "own_policy_id": None, "opp_policy_id": None, "callbacks": None, # One from marltoolbox.utils.postprocessing.WELFARES "welfare_key": postprocessing.WELFARE_UTILITARIAN, - - 'nested_policies': [ + "nested_policies": [ # Here the trainer need to be a DQNTrainer # to provide the config for the 3 DQNTorchPolicy - {"Policy_class": DEFAULT_NESTED_POLICY_COOP, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_COOP, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH, - "config_update": {}}, + {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}}, + { + "Policy_class": DEFAULT_NESTED_POLICY_SELFISH, + "config_update": {}, + }, + {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}}, + { + "Policy_class": DEFAULT_NESTED_POLICY_SELFISH, + "config_update": {}, + }, ], - "optimizer": {"sgd_momentum": 0.0, }, - } + "optimizer": { + "sgd_momentum": 0.0, + }, + "batch_mode": 
"complete_episodes", + }, ) PLOT_KEYS = [ @@ -76,7 +84,8 @@ "debit", "debit_threshold", "summed_debit", - "summed_n_steps_to_punish" + "summed_n_steps_to_punish", + "reset_rnn_state", ] PLOT_ASSEMBLAGE_TAGS = [ @@ -85,6 +94,7 @@ ("debit_threshold",), ("summed_debit",), ("summed_n_steps_to_punish",), + ("reset_rnn_state",), ] diff --git a/marltoolbox/algos/amTFT/base_policy.py b/marltoolbox/algos/amTFT/base_policy.py index 33c63d0..9b8978e 100644 --- a/marltoolbox/algos/amTFT/base_policy.py +++ b/marltoolbox/algos/amTFT/base_policy.py @@ -1,14 +1,19 @@ import copy import logging -from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING +from typing import Dict, TYPE_CHECKING import numpy as np +import torch from ray.rllib.env import BaseEnv from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import override -from ray.rllib.utils.typing import TensorType, PolicyID +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.torch_ops import ( + convert_to_torch_tensor, +) +from ray.rllib.utils.typing import PolicyID if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -37,21 +42,32 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.total_debit = 0 self.n_steps_to_punish = 0 self.observed_n_step_in_current_epi = 0 - self._first_fake_step_played = False + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX self.opp_previous_obs = None self.opp_new_obs = None self.own_previous_obs = None self.own_new_obs = None self.both_previous_raw_obs = None self.both_new_raw_obs = None + self.coop_own_rnn_state_before_last_act = None + self.coop_opp_rnn_state_before_last_act = None # notation T in the paper self.debit_threshold = config["debit_threshold"] # notation alpha in the paper self.punishment_multiplier = config["punishment_multiplier"] self.working_state = config["working_state"] + assert ( + self.working_state in WORKING_STATES + ), f"self.working_state {self.working_state}" self.verbose = config["verbose"] self.welfare_key = config["welfare_key"] self.auto_load_checkpoint = config.get("auto_load_checkpoint", True) + self.punish_instead_of_selfish = config.get( + "punish_instead_of_selfish", False + ) + self.punish_instead_of_selfish_key = ( + postprocessing.OPPONENT_NEGATIVE_REWARD + ) if self.working_state in WORKING_STATES_IN_EVALUATION: self._set_models_for_evaluation() @@ -60,45 +76,32 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.auto_load_checkpoint and restore.LOAD_FROM_CONFIG_KEY in config.keys() ): - restore.after_init_load_policy_checkpoint(self) + restore.before_loss_init_load_policy_checkpoint(self) def _set_models_for_evaluation(self): for algo in self.algorithms: algo.model.eval() + @with_lock @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - - self._select_witch_algo_to_use() - actions, state_out, extra_fetches = self.algorithms[ - self.active_algo_idx - 
].compute_actions( - obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): + state_batches = self._select_witch_algo_to_use(state_batches) + + self._track_last_coop_rnn_state(state_batches) + actions, state_out, extra_fetches = super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep ) - if self.verbose > 2: - print(f"self.active_algo_idx {self.active_algo_idx}") + if self.verbose > 1: + print("algo idx", self.active_algo_idx, "action", actions) + print("extra_fetches", extra_fetches) + print("state_batches (in)", state_batches) + print("state_out", state_out) + return actions, state_out, extra_fetches - def _select_witch_algo_to_use(self): + def _select_witch_algo_to_use(self, state_batches): if ( self.working_state == WORKING_STATES[0] or self.working_state == WORKING_STATES[4] @@ -109,14 +112,17 @@ def _select_witch_algo_to_use(self): or self.working_state == WORKING_STATES[3] ): self.active_algo_idx = OWN_SELFISH_POLICY_IDX + elif self.working_state == WORKING_STATES[2]: - self._select_algo_to_use_in_eval() + state_batches = self._select_algo_to_use_in_eval(state_batches) else: raise ValueError( f'config["working_state"] ' f"must be one of {WORKING_STATES}" ) - def _select_algo_to_use_in_eval(self): + return state_batches + + def _select_algo_to_use_in_eval(self, state_batches): if self.n_steps_to_punish == 0: self.active_algo_idx = OWN_COOP_POLICY_IDX elif self.n_steps_to_punish > 0: @@ -125,23 +131,58 @@ def _select_algo_to_use_in_eval(self): else: raise ValueError("self.n_steps_to_punish can't be below zero") + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_eval" + ) + + return state_batches + + def _check_for_rnn_state_reset(self, state_batches, last_algo_idx: str): + if getattr(self, last_algo_idx) != self.active_algo_idx: + state_batches = self._get_initial_rnn_state(state_batches) + self._to_log["reset_rnn_state"] = self.active_algo_idx + setattr(self, last_algo_idx, self.active_algo_idx) + if self.verbose > 0: + print("reset_rnn_state") + else: + if "reset_rnn_state" in self._to_log.keys(): + self._to_log.pop("reset_rnn_state") + return state_batches + + def _get_initial_rnn_state(self, state_batches): + if "model" in self.config.keys() and self.config["model"]["use_lstm"]: + initial_state = self.algorithms[ + self.active_algo_idx + ].get_initial_state() + initial_state = [ + convert_to_torch_tensor(s, self.device) for s in initial_state + ] + initial_state = [s.unsqueeze(0) for s in initial_state] + msg = ( + f"self.active_algo_idx {self.active_algo_idx} " + f"state_batches {state_batches} reset to initial rnn state" + ) + # print(msg) + logger.info(msg) + return initial_state + else: + return state_batches + + def _track_last_coop_rnn_state(self, state_batches): + if self.active_algo_idx == OWN_COOP_POLICY_IDX: + self.coop_own_rnn_state_before_last_act = state_batches + @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): - # working_state_idx = WORKING_STATES.index(self.working_state) - # assert working_state_idx == OWN_COOP_POLICY_IDX \ - # or working_state_idx == OWN_SELFISH_POLICY_IDX, \ - # f"current working_state is {self.working_state} " \ - # f"but must be one of " \ - # f"[{WORKING_STATES[OWN_COOP_POLICY_IDX]}, " \ - # f"{WORKING_STATES[OWN_SELFISH_POLICY_IDX]}]" - if 
self.working_state == WORKING_STATES[0]: algo_idx_to_train = OWN_COOP_POLICY_IDX elif self.working_state == WORKING_STATES[1]: algo_idx_to_train = OWN_SELFISH_POLICY_IDX else: - raise ValueError() + raise ValueError( + f"self.working_state must be one of " f"{WORKING_STATES[0:2]}" + ) samples = self._modify_batch_for_policy(algo_idx_to_train, samples) algo_to_train = self.algorithms[algo_idx_to_train] @@ -158,14 +199,25 @@ def _learn_on_batch(self, samples: SampleBatch): def _modify_batch_for_policy(self, algo_idx_to_train, samples): if algo_idx_to_train == OWN_COOP_POLICY_IDX: samples = samples.copy() - samples = self._overwrite_reward_for_policy_in_use(samples) + samples = self._overwrite_reward_for_policy_in_use( + samples, self.welfare_key + ) + elif ( + self.punish_instead_of_selfish + and algo_idx_to_train == OWN_SELFISH_POLICY_IDX + ): + samples = samples.copy() + samples = self._overwrite_reward_for_policy_in_use( + samples, self.punish_instead_of_selfish_key + ) + return samples - def _overwrite_reward_for_policy_in_use(self, samples_copy): + def _overwrite_reward_for_policy_in_use(self, samples_copy, welfare_key): samples_copy[samples_copy.REWARDS] = np.array( - samples_copy.data[self.welfare_key] + samples_copy.data[welfare_key] ) - logger.debug(f"overwrite reward with {self.welfare_key}") + logger.debug(f"overwrite reward with {welfare_key}") return samples_copy def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): @@ -173,7 +225,7 @@ def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): # observation produced by this action. But we need the # observation that cause the agent to play this action # thus the observation n-1 - if self._first_fake_step_played: + if self.own_new_obs is not None: self.own_previous_obs = self.own_new_obs self.opp_previous_obs = self.opp_new_obs self.both_previous_raw_obs = self.both_new_raw_obs @@ -193,20 +245,17 @@ def on_episode_step( *args, **kwargs, ): - if self._first_fake_step_played: - opp_obs, raw_obs, opp_a = self._get_information_from_opponent( - policy_id, policy_ids, episode - ) + opp_obs, raw_obs, opp_a = self._get_information_from_opponent( + policy_id, policy_ids, episode + ) - # Ignored the first step in epi because the - # actions provided are fake (they were not played) - self._on_episode_step( - opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index - ) + # Ignored the first step in epi because the + # actions provided are fake (they were not played) + self._on_episode_step( + opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index + ) - self.observed_n_step_in_current_epi += 1 - else: - self._first_fake_step_played = True + self.observed_n_step_in_current_epi += 1 def _get_information_from_opponent(self, agent_id, agent_ids, episode): opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][ @@ -243,11 +292,27 @@ def _on_episode_step( coop_opp_simulated_action = ( self._simulate_action_from_cooperative_opponent(opp_obs) ) + assert ( + len(worker.async_env.env_states) == 1 + ), "amTFT in eval only works with one env not vector of envs" + assert ( + worker.env.step_count_in_current_episode + == worker.async_env.env_states[ + 0 + ].env.step_count_in_current_episode + ) + assert ( + worker.env.step_count_in_current_episode + == self._base_env_at_last_step.get_unwrapped()[ + 0 + ].step_count_in_current_episode + + 1 + ) self._update_total_debit( last_obs, opp_action, worker, - base_env, + self._base_env_at_last_step, episode, env_index, coop_opp_simulated_action, @@ -257,20 
+322,53 @@ def _on_episode_step( opp_action, coop_opp_simulated_action, worker, last_obs ) + self._base_env_at_last_step = copy.deepcopy(base_env) + self._to_log["n_steps_to_punish"] = self.n_steps_to_punish self._to_log["debit_threshold"] = self.debit_threshold def _is_punishment_planned(self): return self.n_steps_to_punish > 0 + def on_episode_start(self, *args, **kwargs): + if self.working_state in WORKING_STATES_IN_EVALUATION: + self._base_env_at_last_step = copy.deepcopy(kwargs["base_env"]) + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX + self.coop_opp_rnn_state_after_last_act = self.algorithms[ + OPP_COOP_POLICY_IDX + ].get_initial_state() + def _simulate_action_from_cooperative_opponent(self, opp_obs): + if self.verbose > 1: + print("opp_obs for opp coop simu nonzero obs", np.nonzero(opp_obs)) + for i, algo in enumerate(self.algorithms): + print("algo", i, algo) + self.coop_opp_rnn_state_before_last_act = ( + self.coop_opp_rnn_state_after_last_act + ) ( coop_opp_simulated_action, - _, - self.coop_opp_extra_fetches, - ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_actions([opp_obs]) - # Returned a list - coop_opp_simulated_action = coop_opp_simulated_action[0] + self.coop_opp_rnn_state_after_last_act, + coop_opp_extra_fetches, + ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_single_action( + obs=opp_obs, + state=self.coop_opp_rnn_state_after_last_act, + ) + if self.verbose > 1: + print( + coop_opp_simulated_action, + "coop_opp_extra_fetches", + coop_opp_extra_fetches, + ) + print( + "state before simu coop opp", + self.coop_opp_rnn_state_before_last_act, + ) + print( + "state after simu coop opp", + self.coop_opp_rnn_state_after_last_act, + ) + return coop_opp_simulated_action def _update_total_debit( @@ -283,8 +381,19 @@ def _update_total_debit( env_index, coop_opp_simulated_action, ): - + if self.verbose > 1: + print( + self.own_policy_id, + self.config[restore.LOAD_FROM_CONFIG_KEY][0].split("/")[-5], + ) if coop_opp_simulated_action != opp_action: + if self.verbose > 0: + print( + self.own_policy_id, + "coop_opp_simulated_action != opp_action:", + coop_opp_simulated_action, + opp_action, + ) if ( worker.env.step_count_in_current_episode >= worker.env.max_steps @@ -319,11 +428,6 @@ def _update_total_debit( ) if coop_opp_simulated_action != opp_action: if self.verbose > 0: - print( - "coop_opp_simulated_action != opp_action:", - coop_opp_simulated_action, - opp_action, - ) print(f"debit {debit}") print( f"self.total_debit {self.total_debit}, previous was {tmp}" @@ -331,8 +435,8 @@ def _update_total_debit( def _is_starting_new_punishment_required(self, manual_threshold=None): if manual_threshold is not None: - return self.total_debit > manual_threshold - return self.total_debit > self.debit_threshold + return self.total_debit >= manual_threshold + return self.total_debit >= self.debit_threshold def _plan_punishment( self, opp_action, coop_opp_simulated_action, worker, last_obs @@ -355,7 +459,11 @@ def _plan_punishment( print(f"reset self.total_debit to 0 since planned punishement") def on_episode_end(self, *args, **kwargs): - if self.working_state not in WORKING_STATES_IN_EVALUATION: + self._defensive_check_observed_n_opp_moves(*args, **kwargs) + self._if_in_eval_reset_debit_and_punish() + + def _defensive_check_observed_n_opp_moves(self, *args, **kwargs): + if self.working_state in WORKING_STATES_IN_EVALUATION: assert ( self.observed_n_step_in_current_epi == kwargs["base_env"].get_unwrapped()[0].max_steps @@ -365,13 +473,16 @@ def on_episode_end(self, *args, **kwargs): 
f"{kwargs['base_env'].get_unwrapped()[0].max_steps} " "steps per episodes." ) - self.observed_n_step_in_current_epi = 0 + self.observed_n_step_in_current_epi = 0 - if self.working_state == WORKING_STATES[2]: + def _if_in_eval_reset_debit_and_punish(self): + if self.working_state in WORKING_STATES_IN_EVALUATION: self.total_debit = 0 self.n_steps_to_punish = 0 if self.verbose > 0: - print(f"reset self.total_debit to 0 since end of episode") + logger.info( + "reset self.total_debit to 0 since end of " "episode" + ) def _compute_debit( self, @@ -390,6 +501,27 @@ def _compute_punishment_duration( ): raise NotImplementedError() + def _get_last_rnn_states_before_rollouts(self): + if self.config["model"]["use_lstm"]: + return { + self.own_policy_id: self._squeezes_rnn_state( + self.coop_own_rnn_state_before_last_act + ), + self.opp_policy_id: self.coop_opp_rnn_state_before_last_act, + } + + else: + return None + + @staticmethod + def _squeezes_rnn_state(state): + return [ + s.squeeze(0) + if torch and isinstance(s, torch.Tensor) + else np.squeeze(s, 0) + for s in state + ] + class AmTFTCallbacks(callbacks.PolicyCallbacks): def on_train_result(self, trainer, *args, **kwargs): diff --git a/marltoolbox/algos/amTFT/inversed_policy.py b/marltoolbox/algos/amTFT/inversed_policy.py index fef207e..578dcd8 100644 --- a/marltoolbox/algos/amTFT/inversed_policy.py +++ b/marltoolbox/algos/amTFT/inversed_policy.py @@ -4,14 +4,15 @@ class InversedAmTFTRolloutsTorchPolicy( - policy_using_rollouts.AmTFTRolloutsTorchPolicy): + policy_using_rollouts.AmTFTRolloutsTorchPolicy +): """ Instead of simulating the opponent, simulate our own policy and act as if it was the opponent. """ - def _init_for_rollout(self, config): - super()._init_for_rollout(config) + def _init(self, config): + super()._init(config) self.ag_id_rollout_reward_to_read = self.own_policy_id @override(base_policy.AmTFTPolicyBase) @@ -26,86 +27,3 @@ def _switch_own_and_opp(self, agent_id): output = super()._switch_own_and_opp(agent_id) self.use_opponent_policies = not self.use_opponent_policies return output - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _select_algo_to_use_in_eval(self): - # assert self.performing_rollouts - # - # if not self.use_opponent_policies: - # if self.n_steps_to_punish == 0: - # self.active_algo_idx = base.OPP_COOP_POLICY_IDX - # elif self.n_steps_to_punish > 0: - # self.active_algo_idx = base.OPP_SELFISH_POLICY_IDX - # self.n_steps_to_punish -= 1 - # else: - # raise ValueError("self.n_steps_to_punish can't be below zero") - # else: - # # assert self.performing_rollouts - # if self.n_steps_to_punish_opponent == 0: - # self.active_algo_idx = base.OWN_COOP_POLICY_IDX - # elif self.n_steps_to_punish_opponent > 0: - # self.active_algo_idx = base.OWN_SELFISH_POLICY_IDX - # self.n_steps_to_punish_opponent -= 1 - # else: - # raise ValueError("self.n_steps_to_punish_opp " - # "can't be below zero") - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _init_for_rollout(self, config): - # super()._init_for_rollout(config) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # self.use_opponent_policies = False - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _prepare_to_perform_virtual_rollouts_in_env(self, worker): - # outputs = super()._prepare_to_perform_virtual_rollouts_in_env( - # worker) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # 
self.use_opponent_policies = True - # return outputs - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _stop_performing_virtual_rollouts_in_env(self, n_steps_to_punish): - # super()._stop_performing_virtual_rollouts_in_env(n_steps_to_punish) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # self.use_opponent_policies = True - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def compute_actions( - # self, - # obs_batch: Union[List[TensorType], TensorType], - # state_batches: Optional[List[TensorType]] = None, - # prev_action_batch: Union[List[TensorType], TensorType] = None, - # prev_reward_batch: Union[List[TensorType], TensorType] = None, - # info_batch: Optional[Dict[str, list]] = None, - # episodes: Optional[List["MultiAgentEpisode"]] = None, - # explore: Optional[bool] = None, - # timestep: Optional[int] = None, - # **kwargs) -> \ - # Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - # - # # Option to overwrite action during internal rollouts - # if self.use_opponent_policies: - # if len(self.overwrite_action) > 0: - # actions, state_out, extra_fetches = \ - # self.overwrite_action.pop(0) - # if self.verbose > 1: - # print("overwrite actions", actions, type(actions)) - # return actions, state_out, extra_fetches - # - # return super().compute_actions( - # obs_batch, state_batches, prev_action_batch, prev_reward_batch, - # info_batch, episodes, explore, timestep, **kwargs) - -# debit = self._compute_debit( -# last_obs, opp_action, worker, base_env, -# episode, env_index, coop_opp_simulated_action) - -# self.n_steps_to_punish = self._compute_punishment_duration( -# opp_action, -# coop_opp_simulated_action, -# worker, -# last_obs) diff --git a/marltoolbox/algos/amTFT/level_1_exploiter.py b/marltoolbox/algos/amTFT/level_1_exploiter.py index c395ded..7f5c52d 100644 --- a/marltoolbox/algos/amTFT/level_1_exploiter.py +++ b/marltoolbox/algos/amTFT/level_1_exploiter.py @@ -3,6 +3,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos import hierarchical, amTFT from marltoolbox.algos.amTFT import base @@ -32,7 +33,7 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.exploiter_activated = config.get("exploiter_activated") if restore.LOAD_FROM_CONFIG_KEY in config.keys(): - restore.after_init_load_policy_checkpoint(self) + restore.before_loss_init_load_policy_checkpoint(self) if self.working_state in base.WORKING_STATES_IN_EVALUATION: if self.exploiter_activated: @@ -48,12 +49,16 @@ def working_state(self, value): def on_episode_start(self, worker, *args, **kwargs): self.worker = worker + self.simulated_opponent.on_episode_start(*args, **kwargs) def on_episode_end(self, *args, **kwargs): self.simulated_opponent.on_episode_end(*args, **kwargs) - def on_observation_fn(self, *args, **kwargs): - self.simulated_opponent.on_observation_fn(*args, **kwargs) + def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): + self.simulated_opponent.on_observation_fn( + own_new_obs, opp_new_obs, both_new_raw_obs + ) + self.both_new_raw_obs = both_new_raw_obs def on_episode_step(self, *args, **kwargs): self.simulated_opponent.on_episode_step(*args, **kwargs) @@ -78,12 +83,14 @@ def _train_dual_policies(self, samples: SampleBatch): ) return learner_stats + @with_lock @override(level_1.Level1ExploiterTorchPolicy) - def 
compute_actions(self, *args, **kwargs): - + def _compute_action_helper(self, *args, **kwargs): # use the simulated amTFT opponent when performing rollouts if self.simulated_opponent.performing_rollouts: - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) self._to_log["use_simulated_opponent"] = 0.0 self._to_log["use_exploiter_cooperating"] = 0.0 @@ -93,16 +100,19 @@ def compute_actions(self, *args, **kwargs): if self.working_state not in base.WORKING_STATES_IN_EVALUATION: self._to_log["use_simulated_opponent"] = 1.0 self.active_algo_idx = self.LEVEL_1_POLICY_IDX - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) # Use the simulated amTFT opponent when not using the exploiter if not self.exploiter_activated: self._to_log["use_simulated_opponent"] = 1.0 self.active_algo_idx = self.LEVEL_1_POLICY_IDX - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) - outputs = super().compute_actions(*args, **kwargs) - return outputs + return super()._compute_action_helper(*args, **kwargs) def _is_cooperation_needed_to_prevent_punishement( self, selfish_action_tuple, *args, **kwargs @@ -117,9 +127,14 @@ def _lookahead_for_opponent_punishing( self, selfish_action, *args, **kwargs ) -> bool: selfish_action = selfish_action[0] - last_obs = kwargs["episodes"][0]._agent_to_last_raw_obs selfish_step_debit = self.simulated_opponent._compute_debit( - last_obs, selfish_action, self.worker, None, None, None, None + self.both_new_raw_obs, + selfish_action, + self.worker, + None, + None, + None, + None, ) if ( "expl_min_anticipated_gain" not in self._to_log.keys() diff --git a/marltoolbox/algos/amTFT/policy_using_rollouts.py b/marltoolbox/algos/amTFT/policy_using_rollouts.py index d6ba59b..0cc5cb6 100644 --- a/marltoolbox/algos/amTFT/policy_using_rollouts.py +++ b/marltoolbox/algos/amTFT/policy_using_rollouts.py @@ -1,15 +1,13 @@ -from typing import List, Union, Optional, Dict, Tuple - import numpy as np -from ray.rllib.evaluation import MultiAgentEpisode -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos.amTFT.base import ( OWN_COOP_POLICY_IDX, OWN_SELFISH_POLICY_IDX, OPP_SELFISH_POLICY_IDX, OPP_COOP_POLICY_IDX, - WORKING_STATES, + WORKING_STATES_IN_EVALUATION, ) from marltoolbox.algos.amTFT.base_policy import AmTFTPolicyBase from marltoolbox.utils import rollout @@ -18,9 +16,9 @@ class AmTFTRolloutsTorchPolicy(AmTFTPolicyBase): def __init__(self, observation_space, action_space, config, **kwargs): super().__init__(observation_space, action_space, config, **kwargs) - self._init_for_rollout(self.config) + self._init(self.config) - def _init_for_rollout(self, config): + def _init(self, config): self.last_k = config["last_k"] self.use_opponent_policies = False self.rollout_length = config["rollout_length"] @@ -31,49 +29,36 @@ def _init_for_rollout(self, config): self.opp_policy_id = config["opp_policy_id"] self.n_steps_to_punish_opponent = 0 self.ag_id_rollout_reward_to_read = self.opp_policy_id + self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX + self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX + self.use_short_debit_rollout = config.get( + "use_short_debit_rollout", False + ) - # Don't support LSTM (at least because of action 
- # overwriting needed in the rollouts) - if "model" in config.keys(): - if "use_lstm" in config["model"].keys(): - assert not config["model"]["use_lstm"] - - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - + @with_lock + @override(AmTFTPolicyBase) + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): # Option to overwrite action during internal rollouts if self.use_opponent_policies: if len(self.overwrite_action) > 0: actions, state_out, extra_fetches = self.overwrite_action.pop( 0 ) + self.last_opp_algo_idx_in_rollout = OPP_SELFISH_POLICY_IDX if self.verbose > 1: - print("overwritten actions", actions, type(actions)) + print( + "overwritten actions", actions, "state_out", state_out + ) return actions, state_out, extra_fetches - return super().compute_actions( - obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + return super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep ) - def _select_algo_to_use_in_eval(self): + @override(AmTFTPolicyBase) + def _select_algo_to_use_in_eval(self, state_batches): if not self.use_opponent_policies: if self.n_steps_to_punish == 0: self.active_algo_idx = OWN_COOP_POLICY_IDX @@ -82,6 +67,16 @@ def _select_algo_to_use_in_eval(self): self.n_steps_to_punish -= 1 else: raise ValueError("self.n_steps_to_punish can't be below zero") + + if self.performing_rollouts: + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_rollout" + ) + else: + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_eval" + ) + else: assert self.performing_rollouts if self.n_steps_to_punish_opponent == 0: @@ -94,6 +89,17 @@ def _select_algo_to_use_in_eval(self): "self.n_steps_to_punish_opp " "can't be below zero" ) + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_opp_algo_idx_in_rollout" + ) + + return state_batches + + @override(AmTFTPolicyBase) + def _track_last_coop_rnn_state(self, state_batches): + if not self.performing_rollouts and not self.use_opponent_policies: + super()._track_last_coop_rnn_state(state_batches) + def _on_episode_step( self, opp_obs, @@ -104,16 +110,16 @@ def _on_episode_step( episode, env_index, ): - if not self.performing_rollouts: - super()._on_episode_step( - opp_obs, - last_obs, - opp_action, - worker, - base_env, - episode, - env_index, - ) + assert not self.performing_rollouts + super()._on_episode_step( + opp_obs, + last_obs, + opp_action, + worker, + base_env, + episode, + env_index, + ) def _compute_debit( self, @@ -126,11 +132,13 @@ def _compute_debit( coop_opp_simulated_action, ): approximated_debit = self._compute_debit_using_rollouts( - last_obs, opp_action, worker + last_obs, opp_action, worker, base_env ) return approximated_debit - def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): + def _compute_debit_using_rollouts( + self, last_obs, opp_action, worker, base_env + ): ( n_steps_to_punish, policy_map, @@ -138,6 +146,9 
@@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): ) = self._prepare_to_perform_virtual_rollouts_in_env(worker) # Cooperative rollouts + if self.verbose > 1: + print("Compute debit") + print("Cooperative rollouts") ( mean_total_reward_for_totally_coop_opp, _, @@ -148,8 +159,12 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): partially_coop=False, opp_action=None, last_obs=last_obs, + rollout_length=1 if self.use_short_debit_rollout else None, + base_env=base_env, ) # Cooperative rollouts with first action as the real one + if self.verbose > 1: + print("Parital cooperative rollouts") ( mean_total_reward_for_partially_coop_opp, _, @@ -160,6 +175,8 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): partially_coop=True, opp_action=opp_action, last_obs=last_obs, + rollout_length=1 if self.use_short_debit_rollout else None, + base_env=base_env, ) if self.verbose > 0: @@ -489,24 +506,30 @@ def _compute_opp_mean_total_reward( opp_action, last_obs, k_to_explore=0, + rollout_length=None, + base_env=None, ): opp_total_rewards = [] for i in range(self.n_rollout_replicas // 2): + self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX + self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX self.n_steps_to_punish = k_to_explore self.n_steps_to_punish_opponent = k_to_explore if partially_coop: - assert len(self.overwrite_action) == 0 - self.overwrite_action = [ - (np.array([opp_action]), [], {}), - ] + self._set_overwrite_action(opp_action) + last_rnn_states = self._get_last_rnn_states_before_rollouts() coop_rollout = rollout.internal_rollout( worker, - num_steps=self.rollout_length, + num_steps=self.rollout_length + if rollout_length is None + else rollout_length, policy_map=policy_map, last_obs=last_obs, policy_agent_mapping=policy_agent_mapping, reset_env_before=False, num_episodes=1, + last_rnn_states=last_rnn_states, + base_env=base_env, ) assert ( coop_rollout._num_episodes == 1 @@ -536,8 +559,33 @@ def _compute_opp_mean_total_reward( opp_mean_total_reward = sum(opp_total_rewards) / len(opp_total_rewards) return opp_mean_total_reward, n_steps_played + def _set_overwrite_action(self, opp_action): + assert len(self.overwrite_action) == 0 + if self.config["model"]["use_lstm"]: + # When we play the real opponent action, don't break the sequence + # of rnn state for the cooperative opponent + # rnn_state = [[el] for el in self.coop_opp_rnn_state_after_last_act] + rnn_state = [ + self._get_initial_rnn_state(None) + for el in self.coop_opp_rnn_state_after_last_act + ] + + else: + rnn_state = [] + self.overwrite_action = [ + (np.array([opp_action]), rnn_state, {}), + ] + + @override(AmTFTPolicyBase) def on_episode_end(self, *args, **kwargs): - if self.working_state == WORKING_STATES[2]: - self.total_debit = 0 - self.n_steps_to_punish = 0 + assert not self.performing_rollouts + if self.working_state in WORKING_STATES_IN_EVALUATION: self.n_steps_to_punish_opponent = 0 + super().on_episode_end(*args, **kwargs) + + @override(AmTFTPolicyBase) + def on_episode_start(self, *args, **kwargs): + assert not self.performing_rollouts + if self.working_state in WORKING_STATES_IN_EVALUATION: + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX + super().on_episode_start(*args, **kwargs) diff --git a/marltoolbox/algos/amTFT/train_helper.py b/marltoolbox/algos/amTFT/train_helper.py index 9d6109b..34dafd8 100644 --- a/marltoolbox/algos/amTFT/train_helper.py +++ b/marltoolbox/algos/amTFT/train_helper.py @@ -5,11 +5,12 @@ from ray import tune from 
ray.rllib.agents.dqn import DQNTrainer +from marltoolbox import utils from marltoolbox.algos import amTFT from marltoolbox.scripts.aggregate_and_plot_tensorboard_data import ( add_summary_plots, ) -from marltoolbox.utils import miscellaneous, restore +from marltoolbox.utils import miscellaneous, restore, tune_analysis def train_amtft( @@ -117,10 +118,10 @@ def _get_plot_keys(plot_keys, plot_assemblage_tags): def _extract_selfish_policies_checkpoints(tune_analysis_selfish_policies): - checkpoints = miscellaneous.extract_checkpoints( + checkpoints = restore.extract_checkpoints_from_tune_analysis( tune_analysis_selfish_policies ) - seeds = miscellaneous.extract_config_values_from_tune_analysis( + seeds = utils.tune_analysis.extract_config_values_from_tune_analysis( tune_analysis_selfish_policies, "seed" ) seed_to_checkpoint = {} diff --git a/marltoolbox/algos/amTFT/weights_exchanger.py b/marltoolbox/algos/amTFT/weights_exchanger.py index c296a05..380e812 100644 --- a/marltoolbox/algos/amTFT/weights_exchanger.py +++ b/marltoolbox/algos/amTFT/weights_exchanger.py @@ -44,6 +44,7 @@ def _share_weights_during_training(trainer): local_policy_map ) if in_training: + print("amTFT _share_weights_during_training") WeightsExchanger._check_only_amTFT_policies(local_policy_map) policies_weights = trainer.get_weights() policies_weights = ( diff --git a/marltoolbox/algos/augmented_dqn.py b/marltoolbox/algos/augmented_dqn.py index 0cbf986..e5bbe7c 100644 --- a/marltoolbox/algos/augmented_dqn.py +++ b/marltoolbox/algos/augmented_dqn.py @@ -1,7 +1,6 @@ -from typing import Dict - import torch import torch.nn.functional as F +from ray.rllib.agents.dqn import dqn_torch_policy from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy @@ -11,6 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.torch_ops import FLOAT_MIN from ray.rllib.utils.typing import TensorType +from typing import Dict from marltoolbox.utils import log, optimizers, policy @@ -20,36 +20,34 @@ def build_q_losses_wt_additional_logs( ) -> TensorType: """ Copy of build_q_losses with additional values saved into the policy - Made only 2 changes, see in comments. + Made only 1 change, see in comments. """ config = policy.config # Q-network evaluation. - q_t, q_logits_t, q_probs_t = compute_q_values( + q_t, q_logits_t, q_probs_t, _ = compute_q_values( policy, - policy.q_model, - train_batch[SampleBatch.CUR_OBS], + model, + {"obs": train_batch[SampleBatch.CUR_OBS]}, explore=False, is_training=True, ) - # Addition 1 out of 2 - policy.last_q_t = q_t.clone() - # Target Q-network evaluation. - q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values( + q_tp1, _, q_probs_tp1, _ = compute_q_values( policy, policy.target_q_model, - train_batch[SampleBatch.NEXT_OBS], + {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True, ) - # Addition 2 out of 2 - policy.last_target_q_t = q_tp1.clone() + # Only additions + policy.last_q_t = q_t.clone().detach() + policy.last_target_q_t = q_tp1.clone().detach() # Q scores for actions which we know were selected in the given state. 
one_hot_selection = F.one_hot( - train_batch[SampleBatch.ACTIONS], policy.action_space.n + train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n ) q_t_selected = torch.sum( torch.where( @@ -68,10 +66,11 @@ def build_q_losses_wt_additional_logs( q_tp1_using_online_net, q_logits_tp1_using_online_net, q_dist_tp1_using_online_net, + _, ) = compute_q_values( policy, - policy.q_model, - train_batch[SampleBatch.NEXT_OBS], + model, + {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True, ) @@ -108,20 +107,12 @@ def build_q_losses_wt_additional_logs( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 ) - if PRIO_WEIGHTS not in train_batch.keys(): - assert config["prioritized_replay"] is False - prio_weights = torch.tensor( - [1.0] * len(train_batch[SampleBatch.REWARDS]) - ).to(policy.device) - else: - prio_weights = train_batch[PRIO_WEIGHTS] - policy.q_loss = QLoss( q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, - prio_weights, + train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS], train_batch[SampleBatch.DONES].float(), config["gamma"], @@ -134,28 +125,25 @@ def build_q_losses_wt_additional_logs( return policy.q_loss.loss -def build_q_stats_wt_addtional_log( - policy: Policy, batch -) -> Dict[str, TensorType]: - entropy_avg, entropy_single = log._compute_entropy_from_raw_q_values( - policy, policy.last_q_t.clone() - ) +def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + q_stats = dqn_torch_policy.build_q_stats(policy, batch) - return dict( + entropy_avg, _ = log.compute_entropy_from_raw_q_values( + policy, policy.last_q_t + ) + q_stats.update( { "entropy_avg": entropy_avg, - "cur_lr": policy.cur_lr, - }, - **policy.q_loss.stats, + } ) + return q_stats + MyDQNTorchPolicy = DQNTorchPolicy.with_updates( optimizer_fn=optimizers.sgd_optimizer_dqn, loss_fn=build_q_losses_wt_additional_logs, - stats_fn=log.augment_stats_fn_wt_additionnal_logs( - build_q_stats_wt_addtional_log - ), + stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats), before_init=policy.my_setup_early_mixins, mixins=[ TargetNetworkMixin, diff --git a/marltoolbox/algos/augmented_r2d2.py b/marltoolbox/algos/augmented_r2d2.py new file mode 100644 index 0000000..0a46d65 --- /dev/null +++ b/marltoolbox/algos/augmented_r2d2.py @@ -0,0 +1,224 @@ +"""PyTorch policy class used for R2D2.""" + +import torch +import torch.nn.functional as F +from ray.rllib.agents.dqn import r2d2_torch_policy +from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS +from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin +from ray.rllib.agents.dqn.dqn_torch_policy import compute_q_values +from ray.rllib.agents.dqn.r2d2_torch_policy import ( + R2D2TorchPolicy, + h_function, + h_inverse, +) +from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.torch_ops import FLOAT_MIN +from ray.rllib.utils.torch_ops import huber_loss, sequence_mask, l2_loss +from ray.rllib.utils.typing import TensorType +from typing import Dict + +from marltoolbox.utils import log, optimizers, policy + + +def my_r2d2_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + """Constructs the loss for R2D2TorchPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. 
+ train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. + """ + config = policy.config + + # Construct internal state inputs. + i = 0 + state_batches = [] + while "state_in_{}".format(i) in train_batch: + state_batches.append(train_batch["state_in_{}".format(i)]) + i += 1 + assert state_batches + + # Q-network evaluation (at t). + q, _, _, _ = compute_q_values( + policy, + model, + train_batch, + state_batches=state_batches, + seq_lens=train_batch.get("seq_lens"), + explore=False, + is_training=True, + ) + + # Target Q-network evaluation (at t+1). + q_target, _, _, _ = compute_q_values( + policy, + policy.target_q_model, + train_batch, + state_batches=state_batches, + seq_lens=train_batch.get("seq_lens"), + explore=False, + is_training=True, + ) + + # Only additions + policy.last_q = q.clone().detach() + policy.last_q_target = q_target.clone().detach() + + actions = train_batch[SampleBatch.ACTIONS].long() + dones = train_batch[SampleBatch.DONES].float() + rewards = train_batch[SampleBatch.REWARDS] + weights = train_batch[PRIO_WEIGHTS] + + B = state_batches[0].shape[0] + T = q.shape[0] // B + + # Q scores for actions which we know were selected in the given state. + one_hot_selection = F.one_hot(actions, policy.action_space.n) + q_selected = torch.sum( + torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device)) + * one_hot_selection, + 1, + ) + + if config["double_q"]: + best_actions = torch.argmax(q, dim=1) + else: + best_actions = torch.argmax(q_target, dim=1) + + best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n) + q_target_best = torch.sum( + torch.where( + q_target > FLOAT_MIN, + q_target, + torch.tensor(0.0, device=policy.device), + ) + * best_actions_one_hot, + dim=1, + ) + + if config["num_atoms"] > 1: + raise ValueError("Distributional R2D2 not supported yet!") + else: + q_target_best_masked_tp1 = (1.0 - dones) * torch.cat( + [q_target_best[1:], torch.tensor([0.0], device=policy.device)] + ) + + if config["use_h_function"]: + h_inv = h_inverse( + q_target_best_masked_tp1, config["h_function_epsilon"] + ) + target = h_function( + rewards + config["gamma"] ** config["n_step"] * h_inv, + config["h_function_epsilon"], + ) + else: + target = ( + rewards + + config["gamma"] ** config["n_step"] + * q_target_best_masked_tp1 + ) + + # Seq-mask all loss-related terms. + seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1] + # Mask away also the burn-in sequence at the beginning. 
+ burn_in = policy.config["burn_in"] + if burn_in > 0 and burn_in < T: + seq_mask[:, :burn_in] = False + + num_valid = torch.sum(seq_mask) + + def reduce_mean_valid(t): + return torch.sum(t[seq_mask]) / num_valid + + # Make sure use the correct time indices: + # Q(t) - [gamma * r + Q^(t+1)] + q_selected = q_selected.reshape([B, T])[:, :-1] + td_error = q_selected - target.reshape([B, T])[:, :-1].detach() + td_error = td_error * seq_mask + weights = weights.reshape([B, T])[:, :-1] + + return weights, td_error, q_selected, reduce_mean_valid + + +def my_r2d2_loss_wt_huber_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + + weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss( + policy, model, _, train_batch + ) + + policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) + policy._td_error = td_error.reshape([-1]) + policy._loss_stats = { + "mean_q": reduce_mean_valid(q_selected), + "min_q": torch.min(q_selected), + "max_q": torch.max(q_selected), + "mean_td_error": reduce_mean_valid(td_error), + } + + return policy._total_loss + + +def my_r2d2_loss_wtout_huber_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + + weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss( + policy, model, _, train_batch + ) + + policy._total_loss = reduce_mean_valid(weights * l2_loss(td_error)) + policy._td_error = td_error.reshape([-1]) + policy._loss_stats = { + "mean_q": reduce_mean_valid(q_selected), + "min_q": torch.min(q_selected), + "max_q": torch.max(q_selected), + "mean_td_error": reduce_mean_valid(td_error), + } + + return policy._total_loss + + +def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + q_stats = r2d2_torch_policy.build_q_stats(policy, batch) + + entropy_avg, _ = log.compute_entropy_from_raw_q_values( + policy, policy.last_q.clone() + ) + q_stats.update( + { + "entropy_avg": entropy_avg, + } + ) + + return q_stats + + +MyR2D2TorchPolicy = R2D2TorchPolicy.with_updates( + optimizer_fn=optimizers.sgd_optimizer_dqn, + loss_fn=my_r2d2_loss_wt_huber_loss, + stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats), + before_init=policy.my_setup_early_mixins, + mixins=[ + TargetNetworkMixin, + ComputeTDErrorMixin, + policy.MyLearningRateSchedule, + ], +) + +MyR2D2TorchPolicyWtMSELoss = MyR2D2TorchPolicy.with_updates( + loss_fn=my_r2d2_loss_wtout_huber_loss, +) + +MyAdamDQNTorchPolicy = MyR2D2TorchPolicy.with_updates( + optimizer_fn=optimizers.adam_optimizer_dqn, +) diff --git a/marltoolbox/algos/exploiters/dual_behavior.py b/marltoolbox/algos/exploiters/dual_behavior.py index 31c8cfd..13e1527 100644 --- a/marltoolbox/algos/exploiters/dual_behavior.py +++ b/marltoolbox/algos/exploiters/dual_behavior.py @@ -4,6 +4,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos import hierarchical from marltoolbox.utils import postprocessing @@ -34,6 +35,7 @@ def __init__(self, observation_space, action_space, config, **kwargs): super().__init__(observation_space, action_space, config, **kwargs) self.active_algo_idx = self.SELFISH_POLICY_IDX + @with_lock @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): return self._train_dual_policies(samples) diff --git a/marltoolbox/algos/exploiters/level_1.py b/marltoolbox/algos/exploiters/level_1.py index cb2807a..b5bd8d8 100644 --- 
a/marltoolbox/algos/exploiters/level_1.py +++ b/marltoolbox/algos/exploiters/level_1.py @@ -1,10 +1,11 @@ import logging -from typing import TYPE_CHECKING from ray.rllib import SampleBatch -from ray.rllib.policy import Policy +from ray.rllib.policy import TorchPolicy from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock +from typing import TYPE_CHECKING from marltoolbox.algos.exploiters import dual_behavior @@ -17,7 +18,7 @@ dual_behavior.DEFAULT_CONFIG, { "lookahead_n_times": int(1), - } + }, ) @@ -34,6 +35,7 @@ class Level1ExploiterTorchPolicy(dual_behavior.DualBehaviorTorchPolicy): """ + LEVEL_1_POLICY_IDX = 2 def __init__(self, observation_space, action_space, config, **kwargs): @@ -50,13 +52,14 @@ def _learn_on_batch(self, samples: SampleBatch): def _train_level_1_policy(self, learner_stats, samples): raise NotImplementedError() - @override(Policy) - def compute_actions(self, *args, **kwargs): - + @with_lock + @override(TorchPolicy) + def _compute_action_helper(self, *args, **kwargs): selfish_action_tuple = self._compute_selfish_action(*args, **kwargs) coop_needed = self._is_cooperation_needed_to_prevent_punishement( - selfish_action_tuple, *args, **kwargs) + selfish_action_tuple, *args, **kwargs + ) self._to_log["coop_needed"] = coop_needed @@ -64,7 +67,7 @@ def compute_actions(self, *args, **kwargs): self._to_log["use_exploiter_cooperating"] = 1.0 self._to_log["use_exploiter_selfish"] = 0.0 self.active_algo_idx = self.COOP_POLICY_IDX - return super().compute_actions(*args, **kwargs) + return super()._compute_action_helper(*args, **kwargs) else: self._to_log["use_exploiter_cooperating"] = 0.0 self._to_log["use_exploiter_selfish"] = 1.0 @@ -74,12 +77,13 @@ def compute_actions(self, *args, **kwargs): def _compute_selfish_action(self, *args, **kwargs): tmp_active_algo_idx = self.active_algo_idx self.active_algo_idx = self.SELFISH_POLICY_IDX - selfish_action_tuple = super().compute_actions(*args, **kwargs) + selfish_action_tuple = super()._compute_action_helper(*args, **kwargs) self.active_algo_idx = tmp_active_algo_idx return selfish_action_tuple def _is_cooperation_needed_to_prevent_punishement( - self, selfish_action_tuple, *args, **kwargs): + self, selfish_action_tuple, *args, **kwargs + ): selfish_action = selfish_action_tuple[0] @@ -87,23 +91,26 @@ def _is_cooperation_needed_to_prevent_punishement( is_going_to_punish_at_next_step = False for _ in range(self.lookahead_n_times): - is_going_to_punish_at_next_step = \ + is_going_to_punish_at_next_step = ( self._lookahead_for_opponent_punishing( - selfish_action, - *args, - **kwargs) + selfish_action, *args, **kwargs + ) + ) if is_going_to_punish_at_next_step: break self._to_log["simu_opp_is_punishing"] = is_punishing - self._to_log["simu_opp_is_going_to_punish_at_next_step"] = \ - is_going_to_punish_at_next_step - return (is_punishing or is_going_to_punish_at_next_step) + self._to_log[ + "simu_opp_is_going_to_punish_at_next_step" + ] = is_going_to_punish_at_next_step + return is_punishing or is_going_to_punish_at_next_step def _lookahead_for_opponent_punishing( - self, selfish_action, *args, **kwargs) -> bool: + self, selfish_action, *args, **kwargs + ) -> bool: raise NotImplementedError() + # class ObserveOpponentCallbacks(DefaultCallbacks): # # def on_postprocess_trajectory( diff --git a/marltoolbox/algos/hierarchical.py b/marltoolbox/algos/hierarchical.py index aa6ab30..c7c3880 100644 --- a/marltoolbox/algos/hierarchical.py +++ 
b/marltoolbox/algos/hierarchical.py @@ -1,13 +1,14 @@ import copy import logging -from typing import List, Union, Optional, Dict, Tuple, Iterable + +from ray.rllib.utils.threading import with_lock +from typing import List, Union, Iterable import torch from ray import rllib -from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.annotations import override from marltoolbox.utils import log @@ -41,9 +42,9 @@ def __init__( updated_config.update(nested_config["config_update"]) if nested_config["Policy_class"] is None: raise ValueError( - f"You must specify the classes for the nested Policies " - f'in config["nested_config"]["Policy_class"] ' - f'current value is {nested_config["Policy_class"]}.' + "You must specify the classes for the nested Policies " + 'in config["nested_config"]["Policy_class"]. ' + f"nested_config: {nested_config}." ) Policy = nested_config["Policy_class"] policy = Policy( @@ -70,6 +71,21 @@ def __init__( self._to_log = {} self._already_printed_warnings = [] + self._merge_all_view_requirements() + + def _merge_all_view_requirements(self): + self.view_requirements = {} + for algo in self.algorithms: + for k, v in algo.view_requirements.items(): + self._add_in_view_requirement_or_check_are_equals(k, v) + + def _add_in_view_requirement_or_check_are_equals(self, k, v): + if k not in self.view_requirements.keys(): + self.view_requirements[k] = v + else: + assert vars(self.view_requirements[k]) == vars( + v + ), f"{vars(self.view_requirements[k])} must equal {vars(v)}" def __getattribute__(self, attr): """ @@ -87,10 +103,8 @@ def __getattribute__(self, attr): f"printing this message." 
) logger.info(msg) - # from warnings import warn - # warn(msg) - return object.__getattribute__( - self.algorithms[self.active_algo_idx], attr + return self.algorithms[self.active_algo_idx].__getattribute__( + attr ) except AttributeError as secondary: raise type(initial)(f"{initial.args} and {secondary.args}") @@ -113,7 +127,6 @@ def to_log(self, value): for algo in self.algorithms: if hasattr(algo, "to_log"): algo.to_log = {} - # setattr(algo, "to_log", {}) self._to_log = value @@ -125,6 +138,14 @@ def model(self): def model(self, value): self.algorithms[self.active_algo_idx].model = value + @property + def exploration(self): + return self.algorithms[self.active_algo_idx].exploration + + @exploration.setter + def exploration(self, value): + self.algorithms[self.active_algo_idx].exploration = value + @property def dist_class(self): return self.algorithms[self.active_algo_idx].dist_class @@ -142,9 +163,11 @@ def global_timestep(self, value): for algo in self.algorithms: algo.global_timestep = value + @override(rllib.policy.Policy) def on_global_var_update(self, global_vars): for algo in self.algorithms: algo.on_global_var_update(global_vars) + super().on_global_var_update(global_vars) @property def update_target(self): @@ -155,12 +178,14 @@ def nested_update_target(): return nested_update_target + @override(rllib.policy.TorchPolicy) def get_weights(self): return { self.nested_key(i): algo.get_weights() for i, algo in enumerate(self.algorithms) } + @override(rllib.policy.TorchPolicy) def set_weights(self, weights): for i, algo in enumerate(self.algorithms): algo.set_weights(weights[self.nested_key(i)]) @@ -168,27 +193,22 @@ def set_weights(self, weights): def nested_key(self, i): return f"nested_{i}" - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - + @with_lock + @override(rllib.policy.TorchPolicy) + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): actions, state_out, extra_fetches = self.algorithms[ self.active_algo_idx - ].compute_actions(obs_batch) + ]._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep + ) return actions, state_out, extra_fetches + @with_lock def learn_on_batch(self, samples: SampleBatch): - self._update_lr_in_all_optimizers() + stats = self._learn_on_batch(samples) self._log_learning_rates() return stats @@ -217,22 +237,12 @@ def optimizer( The local PyTorch optimizer(s) to use for this Policy. 
""" - # TODO find a clean solution to update the LR when using a LearningRateSchedule - # TODO this will probably no more be needed when moving to RLLib>1.0.0 - for algo in self.algorithms: - if hasattr(algo, "cur_lr"): - for opt in algo._optimizers: - for p in opt.param_groups: - p["lr"] = algo.cur_lr - all_optimizers = [] for algo_n, algo in enumerate(self.algorithms): opt = algo.optimizer() all_optimizers.extend(opt) return all_optimizers - # TODO Move this in helper functions - def postprocess_trajectory( self, sample_batch, other_agent_batches=None, episode=None ): @@ -245,14 +255,13 @@ def _log_learning_rates(self): Use to log LR to check that they are really updated as configured. """ for algo_idx, algo in enumerate(self.algorithms): - self.to_log[f"algo{algo_idx}"] = log._log_learning_rate(algo) + self.to_log[f"algo{algo_idx}"] = log.log_learning_rate(algo) def set_state(self, state: object) -> None: state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.pop("_optimizer_variables", None) if optimizer_vars: - print("self", self) assert len(optimizer_vars) == len( self._optimizers ), f"{len(optimizer_vars)} {len(self._optimizers)}" @@ -260,3 +269,12 @@ def set_state(self, state: object) -> None: o.load_state_dict(s) # Then the Policy's (NN) weights. super().set_state(state) + + def _get_dummy_batch_from_view_requirements( + self, batch_size: int = 1 + ) -> SampleBatch: + dummy_sample_batch = super()._get_dummy_batch_from_view_requirements( + batch_size + ) + dummy_sample_batch[dummy_sample_batch.DONES][-1] = True + return dummy_sample_batch diff --git a/marltoolbox/algos/lola/__init__.py b/marltoolbox/algos/lola/__init__.py index 7e97b24..553ba18 100644 --- a/marltoolbox/algos/lola/__init__.py +++ b/marltoolbox/algos/lola/__init__.py @@ -1,3 +1,3 @@ from marltoolbox.algos.lola.train_cg_tune_class_API import LOLAPGCG -from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExact +from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice diff --git a/marltoolbox/algos/lola/networks.py b/marltoolbox/algos/lola/networks.py index dd01eaf..d17c8f8 100644 --- a/marltoolbox/algos/lola/networks.py +++ b/marltoolbox/algos/lola/networks.py @@ -100,7 +100,7 @@ def __init__( self.sample_reward_bis = tf.placeholder( shape=[None, trace_length], dtype=tf.float32, - name="sample_reward", + name="sample_reward_bis", ) self.j = tf.placeholder(shape=[None], dtype=tf.float32, name="j") @@ -239,6 +239,7 @@ def __init__( shape=[None, 1], dtype=tf.float32, name="next_value" ) self.next_v = tf.matmul(self.next_value, self.gamma_array_inverse) + if use_critic: self.target = self.sample_reward_bis + self.next_v else: @@ -325,6 +326,7 @@ def __init__( entropy_coeff * self.entropy + weigth_decay * self.weigths_norm ) * self.loss_multiplier + self.updateModel = self.trainer.minimize( total_loss, var_list=self.value_params ) diff --git a/marltoolbox/algos/lola/train_cg_tune_class_API.py b/marltoolbox/algos/lola/train_cg_tune_class_API.py index 7a5f903..ac93645 100644 --- a/marltoolbox/algos/lola/train_cg_tune_class_API.py +++ b/marltoolbox/algos/lola/train_cg_tune_class_API.py @@ -21,7 +21,7 @@ from marltoolbox.algos.lola.networks import Pnetwork from marltoolbox.algos.lola.utils import get_monte_carlo, make_cube from marltoolbox.envs.vectorized_coin_game import AsymVectorizedCoinGame -from marltoolbox.utils.full_epi_logger import FullEpisodeLogger +from 
marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger PLOT_KEYS = [ "player_1_loss", @@ -141,7 +141,7 @@ def _init_lola( grid_size, gamma, hidden, - bs_mul, + global_lr_divider, lr, env_config, mem_efficient=True, @@ -168,6 +168,14 @@ def _init_lola( use_destabilizer=False, adding_scaled_weights=False, always_train_PG=False, + use_normalized_rewards=False, + use_centered_reward=False, + use_rolling_avg_actor_grad=False, + process_reward_after_rolling=False, + only_process_reward=False, + use_rolling_avg_reward=False, + reward_processing_bais=False, + center_and_normalize_with_rolling_avg=False, **kwargs, ): @@ -194,7 +202,6 @@ def _init_lola( self.grid_size = grid_size self.gamma = gamma self.hidden = hidden - self.bs_mul = bs_mul self.lr = lr self.mem_efficient = mem_efficient self.asymmetry = env_class == AsymVectorizedCoinGame @@ -232,10 +239,47 @@ def _init_lola( assert self.adding_scaled_weights > 0.0 self.always_train_PG = always_train_PG self.last_term_to_use = 0.0 + self.use_normalized_rewards = use_normalized_rewards + self.use_centered_reward = use_centered_reward + self.use_rolling_avg_actor_grad = use_rolling_avg_actor_grad + self.process_reward_after_rolling = process_reward_after_rolling + self.only_process_reward = only_process_reward + self.use_rolling_avg_reward = use_rolling_avg_reward + self.reward_processing_bais = reward_processing_bais + self.center_and_normalize_with_rolling_avg = ( + center_and_normalize_with_rolling_avg + ) + self._reward_rolling_avg = {} + self._reward_rolling_avg_length = 100 + + if self.use_rolling_avg_reward and self.use_rolling_avg_actor_grad: + self.use_rolling_avg_actor_grad = False + print("self.use_normalized_rewards", self.use_normalized_rewards) + print("self.use_centered_reward", self.use_centered_reward) + print( + "self.use_rolling_avg_actor_grad", self.use_rolling_avg_actor_grad + ) + + global_lr_divider_mul = 1.0 + if self.use_normalized_rewards: + global_lr_divider_mul *= 2.0 + if self.use_centered_reward: + global_lr_divider_mul *= 1.0 + if self.use_rolling_avg_actor_grad: + global_lr_divider_mul *= 1 / 20.0 + if self.use_rolling_avg_reward: + global_lr_divider_mul *= 1 / 10000.0 + + self.global_lr_divider = global_lr_divider * global_lr_divider_mul + self.global_lr_divider_original = ( + global_lr_divider * global_lr_divider_mul + ) self.obs_batch = deque(maxlen=self.batch_size) self.full_episode_logger = FullEpisodeLogger( - logdir=self._logdir, log_interval=100, log_ful_epi_one_hot_obs=True + logdir=self._logdir, + log_interval=100, + convert_one_hot_obs_to_idx=True, ) # Setting the training parameters @@ -249,6 +293,10 @@ def _init_lola( self.max_epLength = ( trace_length + 1 ) # The max allowed length of our episode. 
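+        # Per-player rolling average of the absolute actor gradient (or of
+        # the absolute rewards, depending on use_rolling_avg_actor_grad /
+        # use_rolling_avg_reward). It is updated in
+        # _update_actor_rolling_avg() and used by
+        # _update_bs_mul_wt_rolling_avg() to rescale global_lr_divider;
+        # the 10.0 below is only an initial seed value.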
+ self.actor_rolling_avg = [ + 10.0 / global_lr_divider_mul + for agent in range(self.total_n_agents) + ] graph = tf.Graph() @@ -301,69 +349,23 @@ def _init_lola( use_critic=use_critic, ) ) - # Clones of the opponents - if opp_model: - self.mainPN_clone = [] - for agent in range(self.total_n_agents): - self.mainPN_clone.append( - Pnetwork( - f"clone_{agent}", - self.h_size[agent], - agent, - self.env, - trace_length=trace_length, - batch_size=batch_size, - changed_config=changed_config, - ac_lr=ac_lr, - use_MAE=use_MAE, - clip_loss_norm=clip_loss_norm, - sess=self.sess, - entropy_coeff=entropy_coeff, - use_critic=use_critic, - ) - ) if not mem_efficient: self.cube, self.cube_ops = make_cube(trace_length) else: self.cube, self.cube_ops = None, None - if not opp_model: - corrections_func( - self.mainPN, - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - else: - corrections_func( - [self.mainPN[0], self.mainPN_clone[1]], - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - corrections_func( - [self.mainPN[1], self.mainPN_clone[0]], - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - clone_update(self.mainPN_clone) + corrections_func( + self.mainPN, + batch_size, + trace_length, + corrections, + self.cube, + clip_lola_update_norm=clip_lola_update_norm, + lola_correction_multiplier=self.lola_correction_multiplier, + clip_lola_correction_norm=clip_lola_correction_norm, + clip_lola_actor_norm=clip_lola_actor_norm, + ) if self.use_PG_exploiter: simple_actor_training_func( @@ -553,19 +555,24 @@ def step(self): pow_series = np.arange(self.trace_length) discount = np.array([pow(self.gamma, item) for item in pow_series]) + reward_mean_p0 = np.array(trainBatch0[2]).mean() + reward_mean_p1 = np.array(trainBatch1[2]).mean() + reward_std_p0 = np.array(trainBatch0[2]).std() + reward_std_p1 = np.array(trainBatch1[2]).std() + ( sample_return0, sample_reward0, sample_reward0_bis, ) = self.compute_centered_discounted_r( - rewards=trainBatch0[2], discount=discount + rewards=trainBatch0[2], discount=discount, player_key="player_0" ) ( sample_return1, sample_reward1, sample_reward1_bis, ) = self.compute_centered_discounted_r( - rewards=trainBatch1[2], discount=discount + rewards=trainBatch1[2], discount=discount, player_key="player_1" ) state_input0 = np.concatenate(trainBatch0[0], axis=0) @@ -607,34 +614,6 @@ def step(self): }, ) - if self.opp_model: - ## update local clones - update_clone = [ - self.mainPN_clone[0].update, - self.mainPN_clone[1].update, - ] - feed_dict = { - self.mainPN_clone[0].state_input: state_input1, - self.mainPN_clone[0].actions: actions1, - self.mainPN_clone[0].sample_return: sample_return1, - self.mainPN_clone[0].sample_reward: sample_reward1, - self.mainPN_clone[1].state_input: state_input0, - self.mainPN_clone[1].actions: actions0, - self.mainPN_clone[1].sample_return: sample_return0, - self.mainPN_clone[1].sample_reward: 
sample_reward0, - self.mainPN_clone[0].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[1].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[0].is_training: True, - self.mainPN_clone[1].is_training: True, - } - num_loops = 50 if self.timestep == 0 else 1 - for _ in range(num_loops): - self.sess.run(update_clone, feed_dict=feed_dict) - if self.lr_decay: lr_decay = (self.num_episodes - self.timestep) / self.num_episodes else: @@ -667,25 +646,6 @@ def step(self): self.mainPN[0].is_training: True, self.mainPN[1].is_training: True, } - if self.opp_model: - feed_dict.update( - { - self.mainPN_clone[0].state_input: state_input1, - self.mainPN_clone[0].actions: actions1, - self.mainPN_clone[0].sample_return: sample_return1, - self.mainPN_clone[0].sample_reward: sample_reward1, - self.mainPN_clone[1].state_input: state_input0, - self.mainPN_clone[1].actions: actions0, - self.mainPN_clone[1].sample_return: sample_return0, - self.mainPN_clone[1].sample_reward: sample_reward0, - self.mainPN_clone[0].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[1].gamma_array: np.reshape( - discount, [1, -1] - ), - } - ) lola_training_list = [ self.mainPN[0].value, @@ -728,103 +688,46 @@ def step(self): ) ( # Player_red - values, + values_p1, updateModel_1, update1, player_1_value, player_1_target, player_1_loss, - entropy_p_0, - v_0_log, - actor_target_error_0, - actor_loss_0, - parameters_norm_0, - second_order0, - v_0_grad_theta_0, - second_order0_sum, - actor_grad_sum_0, + entropy_p1, + v_0_log_p1, + actor_target_error_p1, + actor_loss_p1, + parameters_norm_p1, + second_order_p1, + v_0_grad_theta_p1, + second_order_sum_p1, + actor_grad_sum_p1, v_0_grad_01, - multiply0, + multiply_p1, # Player_blue - values_1, - updateModel_2, - update2, + values_p2, + updateModel_p2, + update_p2, player_2_value, player_2_target, player_2_loss, - entropy_p_1, - v_1_log, - actor_target_error_1, - actor_loss_1, - parameters_norm_1, - second_order1, - v_1_grad_theta_1, - second_order1_sum, - actor_grad_sum_1, + entropy_p2, + v_1_log_p2, + actor_target_error_p2, + actor_loss_p2, + parameters_norm_p2, + second_order_p2, + v_1_grad_theta_p2, + second_order_sum_p2, + actor_grad_sum_p2, ) = self.sess.run(lola_training_list, feed_dict=feed_dict) - if self.warmup: - update1 = update1 * self.warmup_step_n / self.warmup - update2 = update2 * self.warmup_step_n / self.warmup - if self.lr_decay: - update1 = update1 * lr_decay - update2 = update2 * lr_decay - - update1_sum = sum(update1) / self.bs_mul - update2_sum = sum(update2) / self.bs_mul - - update( - self.mainPN, - self.lr, - update1 / self.bs_mul, - update2 / self.bs_mul, - use_actions_from_exploiter, + update1_sum, update2_sum = self._update_players( + update1, update_p2, lr_decay, use_actions_from_exploiter ) if self.use_PG_exploiter: - # Update policy networks - feed_dict = { - self.mainPN[0].state_input: state_input0, - self.mainPN[0].sample_return: sample_return0, - self.mainPN[0].actions: actions0, - self.mainPN[2].state_input: state_input1, - self.mainPN[2].sample_return: sample_return1, - self.mainPN[2].actions: actions1, - self.mainPN[0].sample_reward: sample_reward0, - self.mainPN[2].sample_reward: sample_reward1, - self.mainPN[0].sample_reward_bis: sample_reward0_bis, - self.mainPN[2].sample_reward_bis: sample_reward1_bis, - self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]), - self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]), - self.mainPN[0].next_value: value_0_next, - self.mainPN[2].next_value: 
expl_value_next, - self.mainPN[0].gamma_array_inverse: np.reshape( - self.discount_array, [1, -1] - ), - self.mainPN[2].gamma_array_inverse: np.reshape( - self.discount_array, [1, -1] - ), - self.mainPN[0].loss_multiplier: [lr_decay], - self.mainPN[2].loss_multiplier: [lr_decay], - self.mainPN[0].is_training: True, - self.mainPN[2].is_training: True, - } - - lola_training_list = [ - self.mainPN[2].value, - self.mainPN[2].updateModel, - self.mainPN[2].delta, - self.mainPN[2].value, - self.mainPN[2].target, - self.mainPN[2].loss, - self.mainPN[2].entropy, - self.mainPN[2].v_0_log, - self.mainPN[2].actor_target_error, - self.mainPN[2].actor_loss, - self.mainPN[2].weigths_norm, - self.mainPN[2].grad, - self.mainPN[2].grad_sum, - ] ( pg_expl_values, pg_expl_updateModel, @@ -839,32 +742,405 @@ def step(self): pg_expl_parameters_norm, pg_expl_v_grad_theta, pg_expl_actor_grad_sum, - ) = self.sess.run(lola_training_list, feed_dict=feed_dict) + pg_expl_update_sum, + ) = self._update_pg_exploiter( + state_input0, + sample_return0, + actions0, + state_input1, + sample_return1, + actions1, + sample_reward0, + sample_reward1, + sample_reward0_bis, + sample_reward1_bis, + discount, + value_0_next, + expl_value_next, + lr_decay, + ) + else: + pg_expl_player_loss = None + pg_expl_v_log = None + pg_expl_entropy = None + pg_expl_actor_loss = None + pg_expl_parameters_norm = None + pg_expl_update_sum = None + pg_expl_actor_grad_sum = None + + self._update_actor_rolling_avg( + actor_grad_sum_p1, + actor_grad_sum_p2, + trainBatch0[2], + trainBatch1[2], + ) - if self.warmup: - pg_expl_update = ( - pg_expl_update * self.warmup_step_n / self.warmup + print("update params") + to_report = self._prepare_logs( + to_report, + last_info, + player_1_loss, + player_2_loss, + v_0_log_p1, + v_1_log_p2, + entropy_p1, + entropy_p2, + actor_loss_p1, + actor_loss_p2, + parameters_norm_p1, + parameters_norm_p2, + second_order_sum_p1, + second_order_sum_p2, + update1_sum, + update2_sum, + actor_grad_sum_p1, + actor_grad_sum_p2, + pg_expl_player_loss, + pg_expl_v_log, + pg_expl_entropy, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_update_sum, + pg_expl_actor_grad_sum, + reward_mean_p0, + reward_mean_p1, + reward_std_p0, + reward_std_p1, + sample_return0, + sample_return1, + sample_reward0, + sample_reward1, + values_p1, + values_p2, + player_1_target, + player_2_target, + ) + return to_report + + def compute_centered_discounted_r(self, rewards, discount, player_key=""): + if not self.only_process_reward: + rewards = self._process_rewards( + rewards, preprocess=True, key=player_key + "_rewards" + ) + + sample_return = np.reshape( + get_monte_carlo( + rewards, self.y, self.trace_length, self.batch_size + ), + [self.batch_size, -1], + ) + + if self.only_process_reward: + rewards = self._process_rewards( + rewards, preprocess=True, key=player_key + "_rewards" + ) + + if self.correction_reward_baseline_per_step: + sample_reward = discount * np.reshape( + rewards - np.mean(np.array(rewards), axis=0), + [-1, self.trace_length], + ) + else: + sample_reward = discount * np.reshape( + rewards - np.mean(rewards), [-1, self.trace_length] + ) + sample_reward_bis = discount * np.reshape( + rewards, [-1, self.trace_length] + ) + + if not self.only_process_reward: + sample_return = self._process_rewards( + sample_return, + preprocess=False, + key=player_key + "_sample_return", + ) + sample_reward = self._process_rewards( + sample_reward, preprocess=False, key=player_key + "_sample_reward" + ) + sample_reward_bis = 
self._process_rewards( + sample_reward_bis, + preprocess=False, + key=player_key + "_sample_reward_bis", + ) + return sample_return, sample_reward, sample_reward_bis + + def _process_rewards(self, rewards, preprocess, key: str): + postprocess = not preprocess + preprocessing = preprocess and not self.process_reward_after_rolling + postprocessing = postprocess and self.process_reward_after_rolling + if preprocessing or postprocessing: + if self.use_normalized_rewards: + return self._normalize_rewards(rewards, key) + elif self.use_centered_reward: + return self._center_reward(rewards, key) + return rewards + + def _normalize_rewards(self, rewards, key: str): + r_array = np.array(rewards) + if self.center_and_normalize_with_rolling_avg: + mean_key = key + "_mean" + std_key = key + "_std" + self._update_reward_rolling_avg(r_array.mean(), mean_key) + self._update_reward_rolling_avg(r_array.std(), std_key) + rewards = ( + r_array - self._get_reward_rolling_avg(mean_key) + ) / self._get_reward_rolling_avg(std_key) + else: + rewards = (r_array - r_array.mean()) / r_array.std() + if self.reward_processing_bais: + rewards += self.reward_processing_bais + return rewards.tolist() + + def _center_reward(self, rewards, key: str): + r_array = np.array(rewards) + if self.center_and_normalize_with_rolling_avg: + self._update_reward_rolling_avg(r_array.mean(), key) + rewards = r_array - self._get_reward_rolling_avg(key) + else: + rewards = r_array - r_array.mean() + if self.reward_processing_bais: + rewards += self.reward_processing_bais + return rewards.tolist() + + def _update_reward_rolling_avg(self, value, key): + if key in self._reward_rolling_avg: + self._reward_rolling_avg[key] = value + self._reward_rolling_avg[ + key + ] * (1 - (1 / self._reward_rolling_avg_length)) + else: + self._reward_rolling_avg[key] = ( + value * self._reward_rolling_avg_length + ) + + def _get_reward_rolling_avg(self, key): + return self._reward_rolling_avg[key] / self._reward_rolling_avg_length + + def _update_bs_mul_wt_rolling_avg(self): + rolling_avg = None + if self.use_rolling_avg_actor_grad: + rolling_avg = self.use_rolling_avg_actor_grad + if self.use_rolling_avg_reward: + rolling_avg = self.use_rolling_avg_reward + + if rolling_avg is not None: + self.global_lr_divider = [ + self.global_lr_divider_original * v / rolling_avg + for v in self.actor_rolling_avg + ] + + def _update_actor_rolling_avg( + self, + actor_grad_sum_0, + actor_grad_sum_1, + reward_player_0, + reward_player_1, + ): + if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward: + roll_avg_p0 = self.actor_rolling_avg[0] + roll_avg_p1 = self.actor_rolling_avg[1] + if self.use_rolling_avg_actor_grad: + roll_avg_p0 = np.abs(actor_grad_sum_0) + roll_avg_p0 * ( + 1 - (1 / self.use_rolling_avg_actor_grad) ) - if self.lr_decay: - pg_expl_update = pg_expl_update * lr_decay + roll_avg_p1 = np.abs(actor_grad_sum_1) + roll_avg_p1 * ( + 1 - (1 / self.use_rolling_avg_actor_grad) + ) + if self.use_rolling_avg_reward: + roll_avg_p0 = np.sum(np.abs(reward_player_0)) + roll_avg_p0 * ( + 1 - (1 / self.use_rolling_avg_reward) + ) + roll_avg_p1 = np.sum(np.abs(reward_player_1)) + roll_avg_p1 * ( + 1 - (1 / self.use_rolling_avg_reward) + ) + self.actor_rolling_avg[0] = roll_avg_p0 + self.actor_rolling_avg[1] = roll_avg_p1 - # pg_expl_update_to_log = pg_expl_update - pg_expl_update_sum = sum(pg_expl_update) / self.bs_mul + def _update_players( + self, update1, update2, lr_decay, use_actions_from_exploiter + ): + if self.warmup: + update1 = update1 * 
self.warmup_step_n / self.warmup + update2 = update2 * self.warmup_step_n / self.warmup + if self.lr_decay: + update1 = update1 * lr_decay + update2 = update2 * lr_decay - update_single( - self.mainPN[2], self.lr, pg_expl_update / self.bs_mul + if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward: + self._update_bs_mul_wt_rolling_avg() + update1_sum = sum(update1) / self.global_lr_divider[0] + update2_sum = sum(update2) / self.global_lr_divider[1] + update( + self.mainPN, + self.lr, + update1 / self.global_lr_divider[0], + update2 / self.global_lr_divider[1], + use_actions_from_exploiter, ) + else: + update1_sum = sum(update1) / self.global_lr_divider + update2_sum = sum(update2) / self.global_lr_divider + update( + self.mainPN, + self.lr, + update1 / self.global_lr_divider, + update2 / self.global_lr_divider, + use_actions_from_exploiter, + ) + return update1_sum, update2_sum - if self.timestep >= self.start_using_exploiter_at_update_n: - if self.timestep % self.every_n_updates_copy_weights == 0: - copy_weigths( - from_policy=self.mainPN[2], - to_policy=self.mainPN[1], - adding_scaled_weights=self.adding_scaled_weights, - ) + def _update_pg_exploiter( + self, + state_input0, + sample_return0, + actions0, + state_input1, + sample_return1, + actions1, + sample_reward0, + sample_reward1, + sample_reward0_bis, + sample_reward1_bis, + discount, + value_0_next, + expl_value_next, + lr_decay, + ): + # Update policy networks + feed_dict = { + self.mainPN[0].state_input: state_input0, + self.mainPN[0].sample_return: sample_return0, + self.mainPN[0].actions: actions0, + self.mainPN[2].state_input: state_input1, + self.mainPN[2].sample_return: sample_return1, + self.mainPN[2].actions: actions1, + self.mainPN[0].sample_reward: sample_reward0, + self.mainPN[2].sample_reward: sample_reward1, + self.mainPN[0].sample_reward_bis: sample_reward0_bis, + self.mainPN[2].sample_reward_bis: sample_reward1_bis, + self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]), + self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]), + self.mainPN[0].next_value: value_0_next, + self.mainPN[2].next_value: expl_value_next, + self.mainPN[0].gamma_array_inverse: np.reshape( + self.discount_array, [1, -1] + ), + self.mainPN[2].gamma_array_inverse: np.reshape( + self.discount_array, [1, -1] + ), + self.mainPN[0].loss_multiplier: [lr_decay], + self.mainPN[2].loss_multiplier: [lr_decay], + self.mainPN[0].is_training: True, + self.mainPN[2].is_training: True, + } - print("update params") + lola_training_list = [ + self.mainPN[2].value, + self.mainPN[2].updateModel, + self.mainPN[2].delta, + self.mainPN[2].value, + self.mainPN[2].target, + self.mainPN[2].loss, + self.mainPN[2].entropy, + self.mainPN[2].v_0_log, + self.mainPN[2].actor_target_error, + self.mainPN[2].actor_loss, + self.mainPN[2].weigths_norm, + self.mainPN[2].grad, + self.mainPN[2].grad_sum, + ] + ( + pg_expl_values, + pg_expl_updateModel, + pg_expl_update, + pg_expl_player_value, + pg_expl_player_target, + pg_expl_player_loss, + pg_expl_entropy, + pg_expl_v_log, + pg_expl_actor_target_error, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_v_grad_theta, + pg_expl_actor_grad_sum, + ) = self.sess.run(lola_training_list, feed_dict=feed_dict) + + if self.warmup: + pg_expl_update = pg_expl_update * self.warmup_step_n / self.warmup + if self.lr_decay: + pg_expl_update = pg_expl_update * lr_decay + # pg_expl_update_to_log = pg_expl_update + pg_expl_update_sum = sum(pg_expl_update) / self.global_lr_divider + + update_single( + self.mainPN[2], 
self.lr, pg_expl_update / self.global_lr_divider + ) + + if self.timestep >= self.start_using_exploiter_at_update_n: + if self.timestep % self.every_n_updates_copy_weights == 0: + copy_weigths( + from_policy=self.mainPN[2], + to_policy=self.mainPN[1], + adding_scaled_weights=self.adding_scaled_weights, + ) + + return ( + pg_expl_values, + pg_expl_updateModel, + pg_expl_update, + pg_expl_player_value, + pg_expl_player_target, + pg_expl_player_loss, + pg_expl_entropy, + pg_expl_v_log, + pg_expl_actor_target_error, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_v_grad_theta, + pg_expl_actor_grad_sum, + pg_expl_update_sum, + ) + + def _prepare_logs( + self, + to_report, + last_info, + player_1_loss, + player_2_loss, + v_0_log, + v_1_log, + entropy_p_0, + entropy_p_1, + actor_loss_0, + actor_loss_1, + parameters_norm_0, + parameters_norm_1, + second_order0_sum, + second_order1_sum, + update1_sum, + update2_sum, + actor_grad_sum_0, + actor_grad_sum_1, + pg_expl_player_loss, + pg_expl_v_log, + pg_expl_entropy, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_update_sum, + pg_expl_actor_grad_sum, + reward_mean_p0, + reward_mean_p1, + reward_std_p0, + reward_std_p1, + sample_return0, + sample_return1, + sample_reward0, + sample_reward1, + values, + values_1, + player_1_target, + player_2_target, + ): rlog = np.sum(self.rList[-self.summary_len :], 0) to_plot = {} @@ -891,6 +1167,10 @@ def step(self): last_info.pop("available_actions", None) training_info = { + "reward_mean_p0": reward_mean_p0, + "reward_mean_p1": reward_mean_p1, + "reward_std_p0": reward_std_p0, + "reward_std_p1": reward_std_p1, "player_1_loss": player_1_loss, "player_2_loss": player_2_loss, "v_0_log": v_0_log, @@ -911,16 +1191,18 @@ def step(self): / self.num_episodes, } # Logging distribution (can be a speed bottleneck) - # training_info.update({ - # "sample_return0": sample_return0, - # "sample_return1": sample_return1, - # "sample_reward0": sample_reward0, - # "sample_reward1": sample_reward1, - # "player_1_values": values, - # "player_2_values": values_1, - # "player_1_target": player_1_target, - # "player_2_target": player_2_target, - # }) + # training_info.update( + # { + # "sample_return0": sample_return0, + # "sample_return1": sample_return1, + # "sample_reward0": sample_reward0, + # "sample_reward1": sample_reward1, + # "player_1_values": values, + # "player_2_values": values_1, + # "player_1_target": player_1_target, + # "player_2_target": player_2_target, + # } + # ) # self.update1_list.clear() # self.update2_list.clear() @@ -952,30 +1234,14 @@ def step(self): + to_report.get("player_red_pick_own_color", 0.0) ) + if self.use_rolling_avg_actor_grad: + to_report["rolling_avg_actor_grad_p0"] = self.actor_rolling_avg[0] + to_report["rolling_avg_actor_grad_p1"] = self.actor_rolling_avg[1] + if self.use_rolling_avg_reward: + to_report["rolling_avg_reward_p0"] = self.actor_rolling_avg[0] + to_report["rolling_avg_reward_p1"] = self.actor_rolling_avg[1] return to_report - def compute_centered_discounted_r(self, rewards, discount): - sample_return = np.reshape( - get_monte_carlo( - rewards, self.y, self.trace_length, self.batch_size - ), - [self.batch_size, -1], - ) - - if self.correction_reward_baseline_per_step: - sample_reward = discount * np.reshape( - rewards - np.mean(np.array(rewards), axis=0), - [-1, self.trace_length], - ) - else: - sample_reward = discount * np.reshape( - rewards - np.mean(rewards), [-1, self.trace_length] - ) - sample_reward_bis = discount * np.reshape( - rewards, [-1, 
self.trace_length] - ) - return sample_return, sample_reward, sample_reward_bis - def _log_one_step_in_full_episode(self, s, r, actions, obs, info): self.full_episode_logger.on_episode_step( step_data={ diff --git a/marltoolbox/algos/lola/train_exact_tune_class_API.py b/marltoolbox/algos/lola/train_exact_tune_class_API.py index ef844ab..472c00a 100644 --- a/marltoolbox/algos/lola/train_exact_tune_class_API.py +++ b/marltoolbox/algos/lola/train_exact_tune_class_API.py @@ -11,7 +11,7 @@ import json import os import random - +import torch import numpy as np import tensorflow as tf import tensorflow.contrib.layers as layers @@ -29,22 +29,35 @@ class Qnetwork: Q-network that is either a look-up table or an MLP with 1 hidden layer. """ - def __init__(self, myScope, num_hidden, sess, simple_net=True): + def __init__( + self, myScope, num_hidden, sess, simple_net=True, n_actions=2, std=1.0 + ): with tf.variable_scope(myScope): - self.input_place = tf.placeholder(shape=[5], dtype=tf.int32) + # self.input_place = tf.placeholder(shape=[5], dtype=tf.int32) + # if simple_net: + # self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=3.0)) + # else: + # act = tf.nn.tanh( + # layers.fully_connected( + # tf.one_hot(self.input_place, 5, dtype=tf.float32), + # num_outputs=num_hidden, + # activation_fn=None, + # ) + # ) + # self.p_act = layers.fully_connected( + # act, num_outputs=1, activation_fn=None + # ) + self.input_place = tf.placeholder( + shape=[n_actions ** 2 + 1], dtype=tf.int32 + ) if simple_net: - self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=1.0)) - else: - act = tf.nn.tanh( - layers.fully_connected( - tf.one_hot(self.input_place, 5, dtype=tf.float32), - num_outputs=num_hidden, - activation_fn=None, + self.p_act = tf.Variable( + tf.random_normal( + [n_actions ** 2 + 1, n_actions], stddev=std ) ) - self.p_act = layers.fully_connected( - act, num_outputs=1, activation_fn=None - ) + else: + raise ValueError() self.parameters = [] for i in tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=myScope @@ -63,51 +76,135 @@ def update(mainQN, lr, final_delta_1_v, final_delta_2_v): ) -def corrections_func(mainQN, corrections, gamma, pseudo, reg): +def corrections_func(mainQN, corrections, gamma, pseudo, reg, n_actions=2): + print_opts = [] + mainQN[0].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32) mainQN[1].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32) theta_1_all = mainQN[0].p_act theta_2_all = mainQN[1].p_act - theta_1 = tf.slice(theta_1_all, [0, 0], [4, 1]) - theta_2 = tf.slice(theta_2_all, [0, 0], [4, 1]) - - theta_1_0 = tf.slice(theta_1_all, [4, 0], [1, 1]) - theta_2_0 = tf.slice(theta_2_all, [4, 0], [1, 1]) - p_1 = tf.nn.sigmoid(theta_1) - p_2 = tf.nn.sigmoid(theta_2) - mainQN[0].policy = tf.nn.sigmoid(theta_1_all) - mainQN[1].policy = tf.nn.sigmoid(theta_2_all) - - p_1_0 = tf.nn.sigmoid(theta_1_0) - p_2_0 = tf.nn.sigmoid(theta_2_0) + # Using sigmoid + normalize to keep values similar to the official + # implementation + pi_player_1_for_all_states = tf.nn.sigmoid(theta_1_all) + pi_player_2_for_all_states = tf.nn.sigmoid(theta_2_all) + print_opts.append( + tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states) + ) + print_opts.append( + tf.print("pi_player_2_for_all_states", pi_player_2_for_all_states) + ) + sum_1 = tf.reduce_sum(pi_player_1_for_all_states, axis=1) + sum_1 = tf.stack([sum_1 for _ in range(n_actions)], axis=1) + sum_2 = tf.reduce_sum(pi_player_2_for_all_states, axis=1) + sum_2 = tf.stack([sum_2 for _ in range(n_actions)], 
axis=1) + pi_player_1_for_all_states = pi_player_1_for_all_states / sum_1 + pi_player_2_for_all_states = pi_player_2_for_all_states / sum_2 + print_opts.append( + tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states) + ) + mainQN[0].policy = pi_player_1_for_all_states + mainQN[1].policy = pi_player_2_for_all_states + n_states = int(pi_player_1_for_all_states.shape[0]) + print_opts.append(tf.print("mainQN[0].policy", mainQN[0].policy)) + print_opts.append(tf.print("mainQN[1].policy", mainQN[1].policy)) + pi_player_1_for_states_in_game = tf.slice( + pi_player_1_for_all_states, + [0, 0], + [n_states - 1, n_actions], + ) + pi_player_2_for_states_in_game = tf.slice( + pi_player_2_for_all_states, + [0, 0], + [n_states - 1, n_actions], + ) + pi_player_1_for_initial_state = tf.slice( + pi_player_1_for_all_states, + [n_states - 1, 0], + [1, n_actions], + ) + pi_player_2_for_initial_state = tf.slice( + pi_player_2_for_all_states, + [n_states - 1, 0], + [1, n_actions], + ) + pi_player_1_for_initial_state = tf.transpose(pi_player_1_for_initial_state) + pi_player_2_for_initial_state = tf.transpose(pi_player_2_for_initial_state) + print_opts.append( + tf.print( + "pi_player_1_for_initial_state", pi_player_1_for_initial_state + ) + ) - p_1_0_v = tf.concat([p_1_0, (1 - p_1_0)], 0) - p_2_0_v = tf.concat([p_2_0, (1 - p_2_0)], 0) + s_0 = tf.reshape( + tf.matmul( + pi_player_1_for_initial_state, + tf.transpose(pi_player_2_for_initial_state), + ), + [-1, 1], + ) - s_0 = tf.reshape(tf.matmul(p_1_0_v, tf.transpose(p_2_0_v)), [-1, 1]) + pi_p1 = tf.reshape(pi_player_1_for_states_in_game, [n_states - 1, -1, 1]) + pi_p2 = tf.reshape(pi_player_2_for_states_in_game, [n_states - 1, -1, 1]) - # CC, CD, DC, DD + all_actions_proba_pairs = [] + for action_p1 in range(n_actions): + for action_p2 in range(n_actions): + all_actions_proba_pairs.append( + tf.multiply(pi_p1[:, action_p1], pi_p2[:, action_p2]) + ) P = tf.concat( - [ - tf.multiply(p_1, p_2), - tf.multiply(p_1, 1 - p_2), - tf.multiply(1 - p_1, p_2), - tf.multiply(1 - p_1, 1 - p_2), - ], + all_actions_proba_pairs, 1, ) - R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32) - R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32) - - I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma + # if n_actions == 2: + # # CC, CD, DC, DD + # + # P = tf.concat( + # [ + # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]), + # ], + # 1, + # ) + # elif n_actions == 3: + # # CC, CD, CN, DC, DD, DN, NC, ND, NN + # P = tf.concat( + # [ + # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 2]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 2]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 2]), + # ], + # 1, + # ) + # else: + # raise ValueError(f"n_actions {n_actions}") + # R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32) + # R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32) + R_1 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32) + R_2 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32) + + # I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma + I_m_P = tf.diag([1.0] * (n_actions ** 2)) - P * gamma v_0 = tf.matmul( tf.matmul(tf.matrix_inverse(I_m_P), R_1), s_0, transpose_a=True ) v_1 = tf.matmul( 
tf.matmul(tf.matrix_inverse(I_m_P), R_2), s_0, transpose_a=True ) + print_opts.append(tf.print("s_0", s_0)) + print_opts.append(tf.print("I_m_P", I_m_P)) + print_opts.append(tf.print("R_1", R_1)) + print_opts.append(tf.print("v_0", v_0)) if reg > 0: for indx, _ in enumerate(mainQN[0].parameters): v_0 -= reg * tf.reduce_sum( @@ -145,6 +242,7 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg): tf.reshape(v_0_grad_theta_0_wrong, [param_len, 1]), ) + # with tf.control_dependencies(print_opts): second_order0 = flatgrad(multiply0, mainQN[0].parameters) second_order1 = flatgrad(multiply1, mainQN[1].parameters) @@ -158,11 +256,11 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg): mainQN[1].delta += tf.multiply(second_order1, mainQN[1].lr_correction) -class LOLAExact(tune.Trainable): +class LOLAExactTrainer(tune.Trainable): def _init_lola( self, - env_name, *, + env_name="IteratedAsymBoS", num_episodes=50, trace_length=200, simple_net=True, @@ -175,10 +273,12 @@ def _init_lola( gamma=0.96, with_linear_LR_decay_to_zero=False, clip_update=None, + re_init_every_n_epi=1, + Q_net_std=1.0, **kwargs, ): - print("args not used:", kwargs) + # print("args not used:", kwargs) self.num_episodes = num_episodes self.trace_length = trace_length @@ -192,6 +292,8 @@ def _init_lola( self.gamma = gamma self.with_linear_LR_decay_to_zero = with_linear_LR_decay_to_zero self.clip_update = clip_update + self.re_init_every_n_epi = re_init_every_n_epi + self.Q_net_std = Q_net_std graph = tf.Graph() @@ -202,11 +304,19 @@ def _init_lola( if env_name == "IPD": self.payout_mat_1 = np.array([[-1.0, 0.0], [-3.0, -2.0]]) self.payout_mat_2 = self.payout_mat_1.T - elif env_name == "AsymmetricIteratedBoS": + elif env_name == "IteratedAsymBoS": self.payout_mat_1 = np.array([[+4.0, 0.0], [0.0, +2.0]]) self.payout_mat_2 = np.array([[+1.0, 0.0], [0.0, +2.0]]) + elif "custom_payoff_matrix" in kwargs.keys(): + custom_matrix = kwargs["custom_payoff_matrix"] + self.payout_mat_1 = custom_matrix[:, :, 0] + self.payout_mat_2 = custom_matrix[:, :, 1] else: raise ValueError(f"exp_name: {env_name}") + self.n_actions = int(self.payout_mat_1.shape[0]) + assert self.n_actions == self.payout_mat_1.shape[1] + self.policy1 = [0.0] * self.n_actions + self.policy2 = [0.0] * self.n_actions # Sanity @@ -219,6 +329,8 @@ def _init_lola( self.num_hidden, self.sess, self.simple_net, + self.n_actions, + self.Q_net_std, ) ) @@ -229,6 +341,7 @@ def _init_lola( self.gamma, self.pseudo, self.reg, + self.n_actions, ) self.results = [] @@ -241,21 +354,22 @@ def _init_lola( # TODO add something to not load and create everything when only evaluating with RLLib def setup(self, config): - print("_init_lola", config) self._init_lola(**config) def step(self): - self.sess.run(self.init) - lr_coor = np.ones(1) * self.lr_correction - log_items = {} log_items["episode"] = self.training_iteration + if self.training_iteration % self.re_init_every_n_epi == 0: + self.sess.run(self.init) + + lr_coor = np.ones(1) * self.lr_correction + res = [] params_time = [] delta_time = [] - input_vals = np.reshape(np.array(range(5)) + 1, [-1]) + input_vals = self._get_input_vals() for i in range(self.trace_length): params0 = self.mainQN[0].getparams() params1 = self.mainQN[1].getparams() @@ -267,14 +381,13 @@ def step(self): self.mainQN[0].policy, self.mainQN[1].policy, ] - # print("input_vals", input_vals) - update1, update2, v1, v2, policy1, policy2 = self.sess.run( + (update1, update2, v1, v2, policy1, policy2) = self.sess.run( outputs, feed_dict={ 
self.mainQN[0].input_place: input_vals, self.mainQN[1].input_place: input_vals, - self.mainQN[0].R1: np.reshape(self.payout_mat_2, [-1, 1]), - self.mainQN[1].R1: np.reshape(self.payout_mat_1, [-1, 1]), + self.mainQN[0].R1: np.reshape(self.payout_mat_1, [-1, 1]), + self.mainQN[1].R1: np.reshape(self.payout_mat_2, [-1, 1]), self.mainQN[0].lr_correction: lr_coor, self.mainQN[1].lr_correction: lr_coor, }, @@ -289,10 +402,21 @@ def step(self): res.append([v1[0][0] / self.norm, v2[0][0] / self.norm]) self.results.append(res) + if self.training_iteration % self.re_init_every_n_epi == ( + self.re_init_every_n_epi - 1 + ): + self.policy1 = policy1 + self.policy2 = policy2 + log_items["episodes_total"] = self.training_iteration + log_items["policy1"] = self.policy1 + log_items["policy2"] = self.policy2 return log_items + def _get_input_vals(self): + return np.reshape(np.array(range(self.n_actions ** 2 + 1)) + 1, [-1]) + def _clip_update(self, update1, update2): if self.clip_update is not None: assert self.clip_update > 0.0 @@ -382,24 +506,28 @@ def _post_process_action(self, action): return action[None, ...] # add batch dim def compute_actions(self, policy_id: str, obs_batch: list): - # because of the LSTM - assert len(obs_batch) == 1 + assert ( + len(obs_batch) == 1 + ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}" for single_obs in obs_batch: agent_to_use = self._get_agent_to_use(policy_id) obs = self._preprocess_obs(single_obs, agent_to_use) - input_vals = np.reshape(np.array(range(5)) + 1, [-1]) + input_vals = self._get_input_vals() policy = self.sess.run( - [self.mainQN[agent_to_use].policy], + [ + self.mainQN[agent_to_use].policy, + ], feed_dict={ self.mainQN[agent_to_use].input_place: input_vals, }, ) - coop_proba = policy[0][obs][0] - if coop_proba > random.random(): - action = np.array(0) - else: - action = np.array(1) + probabilities = policy[0][obs] + probabilities = torch.tensor(probabilities) + policy_for_this_state = torch.distributions.Categorical( + probs=probabilities + ) + action = policy_for_this_state.sample() action = self._post_process_action(action) diff --git a/marltoolbox/algos/ltft/ltft.py b/marltoolbox/algos/ltft/ltft.py index 5a91e15..203f7ad 100644 --- a/marltoolbox/algos/ltft/ltft.py +++ b/marltoolbox/algos/ltft/ltft.py @@ -12,13 +12,16 @@ from ray.rllib.execution.replay_buffer import LocalReplayBuffer from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches -from ray.rllib.execution.train_ops import TrainOneStep -from ray.rllib.execution.train_ops import UpdateTargetNetwork from ray.rllib.utils import merge_dicts from ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.typing import PolicyID, SampleBatchType from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator +from ray.rllib.execution.train_ops import ( + TrainOneStep, + UpdateTargetNetwork, + TrainTFMultiGPU, +) from torch.nn import CrossEntropyLoss from marltoolbox.algos import augmented_dqn, supervised_learning, hierarchical @@ -116,7 +119,7 @@ framework="torch", ), }, - "log_level": "DEBUG", + "batch_mode": "complete_episodes", } ) @@ -133,7 +136,7 @@ PLOT_ASSEMBLAGE_TAGS = [ ("defection",), - ("defection", "chosen_percentile"), + ("defection_", "chosen_percentile"), ("punish",), ("delta_log_likelihood_coop_std", "delta_log_likelihood_coop_mean"), ("likelihood_opponent", "likelihood_approximated_opponent"), @@ -158,7 +161,15 @@ def execution_plan( workers: 
WorkerSet, config: TrainerConfigDict ) -> LocalIterator[dict]: """ - Modified from the execution plan of the DQNTrainer + Execution plan of the DQN algorithm. Defines the distributed dataflow. + + Args: + workers (WorkerSet): The WorkerSet for training the Polic(y/ies) + of the Trainer. + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + LocalIterator[dict]: A local iterator over training metrics. """ if config.get("prioritized_replay"): prio_args = { @@ -175,7 +186,9 @@ def execution_plan( buffer_size=config["buffer_size"], replay_batch_size=config["train_batch_size"], replay_mode=config["multiagent"]["replay_mode"], - replay_sequence_length=config["replay_sequence_length"], + replay_sequence_length=config.get("replay_sequence_length", 1), + replay_burn_in=config.get("burn_in", 0), + replay_zero_init_states=config.get("zero_init_states", True), **prio_args, ) @@ -204,10 +217,9 @@ def update_prio(item): td_error = info.get( "td_error", info[LEARNER_STATS_KEY].get("td_error") ) + samples.policy_batches[policy_id].set_get_interceptor(None) prio_dict[policy_id] = ( - samples.policy_batches[policy_id].data.get( - "batch_indexes" - ), + samples.policy_batches[policy_id].get("batch_indexes"), td_error, ) local_replay_buffer.update_priorities(prio_dict) @@ -217,11 +229,25 @@ def update_prio(item): # returned from the LocalReplay() iterator is passed to TrainOneStep to # take a SGD step, and then we decide whether to update the target network. post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) + + if config["simple_optimizer"]: + train_step_op = TrainOneStep(workers) + else: + train_step_op = TrainTFMultiGPU( + workers=workers, + sgd_minibatch_size=config["train_batch_size"], + num_sgd_iter=1, + num_gpus=config["num_gpus"], + shuffle_sequences=True, + _fake_gpus=config["_fake_gpus"], + framework=config.get("framework"), + ) + replay_op_dqn = ( Replay(local_buffer=local_replay_buffer) .for_each(lambda x: post_fn(x, workers, config)) .for_each(LocalTrainablePolicyModifier(workers, train_dqn_only)) - .for_each(TrainOneStep(workers)) + .for_each(train_step_op) .for_each(update_prio) .for_each( UpdateTargetNetwork(workers, config["target_network_update_freq"]) diff --git a/marltoolbox/algos/ltft/ltft_torch_policy.py b/marltoolbox/algos/ltft/ltft_torch_policy.py index 05ce962..6101e5d 100644 --- a/marltoolbox/algos/ltft/ltft_torch_policy.py +++ b/marltoolbox/algos/ltft/ltft_torch_policy.py @@ -8,7 +8,7 @@ import copy import logging from collections import deque -from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING +from typing import Optional, Dict, Tuple, TYPE_CHECKING, Union, List import numpy as np import torch @@ -16,9 +16,11 @@ from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import AgentID, PolicyID, TensorType +from ray.rllib.utils.torch_ops import convert_to_torch_tensor if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -45,6 +47,9 @@ }, ) +BEING_PUNISHED = "being_punished" +PUNISHING = "punishing" + class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy): """ @@ -63,12 +68,10 @@ class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy): INITIALLY_ACTIVE_ALGO = COOP_POLICY_IDX def __init__(self, observation_space, action_space, config, 
**kwargs): - super().__init__( observation_space, action_space, config, - # after_init_nested=_init_weights_fn, **kwargs, ) @@ -100,10 +103,8 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.n_steps_since_start = 0 self.last_computed_w = None - # self._episode_starting = True self.opp_previous_obs = None self.opp_new_obs = None - self._first_fake_step_played = False self.observed_n_step_in_current_epi = 0 self.data_queue = deque(maxlen=self.length_of_history) @@ -140,47 +141,39 @@ def __init__(self, observation_space, action_space, config, **kwargs): add_opponent_neg_reward=True, ) + self.view_requirements.update( + { + PUNISHING: ViewRequirement(), + } + ) + def _reset_learner_stats(self): self.learner_stats = {"learner_stats": {}} @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - assert ( - len(obs_batch) == 1 - ), "LE only works with sampling one step at a time" - - actions, state_out, extra_fetches = super().compute_actions( - obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): + obs_batch = input_dict["obs"] + assert len(obs_batch) == 1, ( + "The LTFT policy only works when sampling/infering one step " + "at a time." ) - extra_fetches["punishing"] = [self._is_punishing() for _ in obs_batch] - return actions, [], extra_fetches + actions, state_out, extra_fetches = super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep + ) + extra_fetches[PUNISHING] = [ + self._is_punishing() for _ in obs_batch.tolist() + ] + + return actions, state_out, extra_fetches def _is_punishing(self): return self.active_algo_idx == self.PUNITIVE_POLICY_IDX @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): - ( policies_idx_to_train, policies_to_train, @@ -237,14 +230,14 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): ) if policy_n == self.COOP_POLICY_IDX: samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="punishing" + samples_copy, remove=True, filter_key=PUNISHING ) elif policy_n == self.COOP_OPP_POLICY_IDX: samples_copy[samples_copy.ACTIONS] = np.array( samples_copy.data[postprocessing.OPPONENT_ACTIONS] ) samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="being_punished" + samples_copy, remove=True, filter_key=BEING_PUNISHED ) else: raise ValueError() @@ -260,7 +253,7 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): # this policy. This is because the opponent is not using its # cooperative policy at that time. 
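            # filter_sample_batch(..., remove=True, filter_key=BEING_PUNISHED)
            # is expected to drop every step flagged as BEING_PUNISHED, so the
            # batch only keeps transitions where the opponent actually played
            # its cooperative policy.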
samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="being_punished" + samples_copy, remove=True, filter_key=BEING_PUNISHED ) else: raise ValueError() @@ -268,30 +261,25 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): return samples_copy def on_episode_step(self, episode, policy_id, policy_ids, *args, **kwargs): - if self._first_fake_step_played: - ( - opp_previous_obs, - opp_a, - being_punished_by_opp, - ) = self._get_information_from_opponent( - episode, policy_id, policy_ids + ( + opp_previous_obs, + opp_a, + being_punished_by_opp, + ) = self._get_information_from_opponent(episode, policy_id, policy_ids) + + self.being_punished_by_opp = being_punished_by_opp + if not self.being_punished_by_opp: + self.n_steps_since_start += 1 + self._put_log_likelihood_in_data_buffer( + opp_previous_obs, opp_a, self.data_queue ) - self.being_punished_by_opp = being_punished_by_opp - if not self.being_punished_by_opp: - self.n_steps_since_start += 1 - self._put_log_likelihood_in_data_buffer( - opp_previous_obs, opp_a, self.data_queue - ) - - if self.remaining_punishing_time > 0: - self.n_punishment_steps_in_current_epi += 1 - else: - self.n_cooperation_steps_in_current_epi += 1 - - self.observed_n_step_in_current_epi += 1 + if self.remaining_punishing_time > 0: + self.n_punishment_steps_in_current_epi += 1 else: - self._first_fake_step_played = True + self.n_cooperation_steps_in_current_epi += 1 + + self.observed_n_step_in_current_epi += 1 def _get_information_from_opponent(self, episode, agent_id, agent_ids): opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][ @@ -299,7 +287,7 @@ def _get_information_from_opponent(self, episode, agent_id, agent_ids): ] opp_a = episode.last_action_for(opp_agent_id) being_punished_by_opp = episode.last_pi_info_for(opp_agent_id).get( - "punishing", False + PUNISHING, False ) return self.opp_previous_obs, opp_a, being_punished_by_opp @@ -308,8 +296,7 @@ def on_observation_fn(self, opp_new_obs): # observation produced by this action. 
But we need the # observation that cause the agent to play this action # thus the observation n-1 - if self._first_fake_step_played: - self.opp_previous_obs = self.opp_new_obs + self.opp_previous_obs = self.opp_new_obs self.opp_new_obs = opp_new_obs def on_episode_end(self, base_env, *args, **kwargs): @@ -386,15 +373,11 @@ def postprocess_trajectory( def _put_log_likelihood_in_data_buffer(self, s, a, data_queue, log=True): s = torch.from_numpy(s).unsqueeze(dim=0) - log_likelihood_opponent_cooperating = ( - compute_log_likelihoods_wt_exploration( - self.algorithms[self.COOP_OPP_POLICY_IDX], a, s - ) + log_likelihood_opponent_cooperating = compute_log_likelihoods( + self.algorithms[self.COOP_OPP_POLICY_IDX], [a], s ) - log_likelihood_approximated_opponent = ( - compute_log_likelihoods_wt_exploration( - self.algorithms[self.SPL_OPP_POLICY_IDX], a, s - ) + log_likelihood_approximated_opponent = compute_log_likelihoods( + self.algorithms[self.SPL_OPP_POLICY_IDX], [a], s ) log_likelihood_opponent_cooperating = float( log_likelihood_opponent_cooperating @@ -500,20 +483,20 @@ def on_postprocess_trajectory( opp_policy_id = all_agent_keys[0] opp_is_broadcast_punishment_state = ( - "punishing" in original_batches[opp_policy_id][1].data.keys() + PUNISHING in original_batches[opp_policy_id][1].keys() ) if opp_is_broadcast_punishment_state: - postprocessed_batch.data["being_punished"] = copy.deepcopy( - original_batches[all_agent_keys[0]][1].data["punishing"] + postprocessed_batch.data[BEING_PUNISHED] = copy.deepcopy( + original_batches[all_agent_keys[0]][1][PUNISHING] ) else: - postprocessed_batch.data["being_punished"] = [False] * len( + postprocessed_batch.data[BEING_PUNISHED] = [False] * len( postprocessed_batch[postprocessed_batch.OBS] ) -# Modified from torch_policy_template -def compute_log_likelihoods_wt_exploration( +# Modified from torch_policy +def compute_log_likelihoods( policy, actions: Union[List[TensorType], TensorType], obs_batch: Union[List[TensorType], TensorType], @@ -521,6 +504,7 @@ def compute_log_likelihoods_wt_exploration( prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None, prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None, ) -> TensorType: + if policy.action_sampler_fn and policy.action_distribution_fn is None: raise ValueError( "Cannot compute log-prob/likelihood w/o an " @@ -537,34 +521,63 @@ def compute_log_likelihoods_wt_exploration( if prev_reward_batch is not None: input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch seq_lens = torch.ones(len(obs_batch), dtype=torch.int32) + state_batches = [ + convert_to_torch_tensor(s, policy.device) + for s in (state_batches or []) + ] # Exploration hook before each forward pass. policy.exploration.before_compute_actions(explore=False) # Action dist class and inputs are generated via custom function. if policy.action_distribution_fn: - dist_inputs, dist_class, _ = policy.action_distribution_fn( - policy=policy, - model=policy.model, - obs_batch=input_dict[SampleBatch.CUR_OBS], - explore=False, - is_training=False, - ) + + # Try new action_distribution_fn signature, supporting + # state_batches and seq_lens. + try: + (dist_inputs, dist_class, _,) = policy.action_distribution_fn( + policy, + policy.model, + input_dict=input_dict, + state_batches=state_batches, + seq_lens=seq_lens, + explore=False, + is_training=False, + ) + # Trying the old way (to stay backward compatible). + # TODO: Remove in future. 
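+        # A TypeError whose message mentions a positional or unexpected
+        # keyword argument is treated as a signature mismatch and triggers
+        # the old-style call below; any other TypeError is re-raised
+        # unchanged.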
+ except TypeError as e: + if ( + "positional argument" in e.args[0] + or "unexpected keyword argument" in e.args[0] + ): + dist_inputs, dist_class, _ = policy.action_distribution_fn( + policy=policy, + model=policy.model, + obs_batch=input_dict[SampleBatch.CUR_OBS], + explore=False, + is_training=False, + ) + else: + raise e + # Default action-dist inputs calculation. else: dist_class = policy.dist_class dist_inputs, _ = policy.model(input_dict, state_batches, seq_lens) action_dist = dist_class(dist_inputs, policy.model) + # ADDITION 1/1 STARTS if policy.config["explore"]: # Adding that because of a bug in TorchCategorical # which modify dist_inputs through action_dist: - _, _ = policy.exploration.get_exploration_action( + _ = policy.exploration.get_exploration_action( action_distribution=action_dist, timestep=policy.global_timestep, explore=policy.config["explore"], ) action_dist = dist_class(dist_inputs, policy.model) + # ADDITION 1/1 ENDS log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) diff --git a/marltoolbox/algos/population.py b/marltoolbox/algos/population.py index 81cb229..ea5e0b2 100644 --- a/marltoolbox/algos/population.py +++ b/marltoolbox/algos/population.py @@ -1,12 +1,9 @@ import copy import random -from typing import List, Union, Optional, Dict, Tuple -from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils import override from marltoolbox.algos import hierarchical from marltoolbox.utils import miscellaneous, restore @@ -17,7 +14,7 @@ # To configure "policy_checkpoints": [], "policy_id_to_load": None, - "nested_policies": None, + # "nested_policies": None, "freeze_algo": True, "use_random_algo": True, "use_algo_in_order": False, @@ -54,12 +51,10 @@ def set_algo_to_use(self): """ Called by a callback at the start of every episode. 
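        Every switch_of_algo_every_n_epi episodes, it selects which of the
        configured policy_checkpoints to use next (at random or in
        registration order, depending on use_random_algo / use_algo_in_order).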
""" - if self.switch_of_algo_counter == self.switch_of_algo_every_n_epi: self.switch_of_algo_counter = 0 self._set_algo_to_use() - else: - self.switch_of_algo_counter += 1 + self.switch_of_algo_counter += 1 def _set_algo_to_use(self): self._select_algo_idx_to_use() @@ -75,10 +70,17 @@ def _select_algo_idx_to_use(self): else: raise ValueError() + self._to_log[ + "PopulationOfIdenticalAlgo_active_checkpoint_idx" + ] = self.active_checkpoint_idx + def _use_random_algo(self): self.active_checkpoint_idx = random.randint( 0, len(self.policy_checkpoints) - 1 ) + self._to_log[ + "PopulationOfIdenticalAlgo_active_checkpoint_idx" + ] = self.active_checkpoint_idx def _use_next_algo(self): self.active_checkpoint_idx += 1 @@ -110,22 +112,7 @@ def set_weights(self, weights): self.algorithms[self.active_algo_idx].set_weights(weights) @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - return self.algorithms[self.active_algo_idx].compute_actions(obs_batch) - - @override(hierarchical.HierarchicalTorchPolicy) - def learn_on_batch(self, samples: SampleBatch): + def _learn_on_batch(self, samples: SampleBatch): if not self.freeze_algo: # TODO maybe need to call optimizer to update the LR of the nested optimizers learner_stats = {"learner_stats": {}} @@ -165,6 +152,11 @@ def on_episode_start( **kwargs, ): self.set_algo_to_use() + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"): + self.algorithms[self.active_algo_idx].on_episode_start( + *args, + **kwargs, + ) def modify_config_to_use_population( diff --git a/marltoolbox/algos/sos.py b/marltoolbox/algos/sos.py new file mode 100644 index 0000000..dbf8079 --- /dev/null +++ b/marltoolbox/algos/sos.py @@ -0,0 +1,521 @@ +###### +# Code modified from: +# https://github.com/julianstastny/openspiel-social-dilemmas +###### +import os +import random + +import torch +import numpy as np +from ray import tune + +from marltoolbox.envs.matrix_sequential_social_dilemma import ( + IteratedPrisonersDilemma, + IteratedAsymBoS, +) + + +class SOSTrainer(tune.Trainable): + def setup(self, config: dict): + + self.config = config + self.gamma = self.config.get("gamma") + self.learning_rate = self.config.get("lr") + self.method = self.config.get("method") + self.use_single_weights = False + + self._set_environment() + self.init_weigths(std=self.config.get("inital_weights_std")) + if self.use_single_weights: + self._exact_loss_matrix_game = ( + self._exact_loss_matrix_game_two_by_two_actions + ) + else: + self._exact_loss_matrix_game = self._exact_loss_matrix_game_generic + + def _set_environment(self): + + if self.config.get("env_name") == "IteratedPrisonersDilemma": + env_class = IteratedPrisonersDilemma + payoff_matrix = env_class.PAYOFF_MATRIX + self.players_ids = env_class({}).players_ids + elif self.config.get("env_name") == "IteratedAsymBoS": + env_class = IteratedAsymBoS + payoff_matrix = env_class.PAYOFF_MATRIX + self.players_ids = env_class({}).players_ids + elif "custom_payoff_matrix" in self.config.keys(): + payoff_matrix = 
self.config["custom_payoff_matrix"] + self.players_ids = IteratedPrisonersDilemma({}).players_ids + else: + raise NotImplementedError() + + self.n_actions_p1 = np.array(payoff_matrix).shape[0] + self.n_actions_p2 = np.array(payoff_matrix).shape[1] + self.n_non_init_states = self.n_actions_p1 * self.n_actions_p2 + + if self.use_single_weights: + self.dims = [ + (self.n_non_init_states + 1,), + (self.n_non_init_states + 1,), + ] + else: + self.dims = [ + (self.n_non_init_states + 1, self.n_actions_p1), + (self.n_non_init_states + 1, self.n_actions_p2), + ] + + self.payoff_matrix_player_row = torch.tensor( + payoff_matrix[:, :, 0] + ).float() + self.payoff_matrix_player_col = torch.tensor( + payoff_matrix[:, :, 1] + ).float() + + def step(self): + losses = self.update_th() + mean_reward_player_row = -losses[0] * (1 - self.gamma) + mean_reward_player_col = -losses[1] * (1 - self.gamma) + to_log = { + f"mean_reward_{self.players_ids[0]}": mean_reward_player_row, + f"mean_reward_{self.players_ids[1]}": mean_reward_player_col, + "episodes_total": self.training_iteration, + } + return to_log + + def _exact_loss_matrix_game_two_by_two_actions(self): + + pi_player_row_init_state = torch.sigmoid( + self.weights_per_players[0][0:1] + ) + pi_player_col_init_state = torch.sigmoid( + self.weights_per_players[1][0:1] + ) + p = torch.cat( + [ + pi_player_row_init_state * pi_player_col_init_state, + pi_player_row_init_state * (1 - pi_player_col_init_state), + (1 - pi_player_row_init_state) * pi_player_col_init_state, + (1 - pi_player_row_init_state) + * (1 - pi_player_col_init_state), + ] + ) + pi_player_row_other_states = torch.reshape( + torch.sigmoid(self.weights_per_players[0][1:5]), (4, 1) + ) + pi_player_col_other_states = torch.reshape( + torch.sigmoid(self.weights_per_players[1][1:5]), (4, 1) + ) + P = torch.cat( + [ + pi_player_row_other_states * pi_player_col_other_states, + pi_player_row_other_states * (1 - pi_player_col_other_states), + (1 - pi_player_row_other_states) * pi_player_col_other_states, + (1 - pi_player_row_other_states) + * (1 - pi_player_col_other_states), + ], + dim=1, + ) + M = -torch.matmul(p, torch.inverse(torch.eye(4) - self.gamma * P)) + L_1 = torch.matmul( + M, torch.reshape(self.payoff_matrix_player_row, (4, 1)) + ) + L_2 = torch.matmul( + M, torch.reshape(self.payoff_matrix_player_col, (4, 1)) + ) + return [L_1, L_2] + + def _exact_loss_matrix_game_generic(self): + pi_player_row = torch.sigmoid(self.weights_per_players[0]) + pi_player_col = torch.sigmoid(self.weights_per_players[1]) + sum_1 = torch.sum(pi_player_row, dim=1) + sum_1 = torch.stack([sum_1 for _ in range(self.n_actions_p1)], dim=1) + sum_2 = torch.sum(pi_player_col, dim=1) + sum_2 = torch.stack([sum_2 for _ in range(self.n_actions_p2)], dim=1) + pi_player_row = pi_player_row / sum_1 + pi_player_col = pi_player_col / sum_2 + + pi_player_row_init_state = pi_player_row[:1, :] + pi_player_col_init_state = pi_player_col[:1, :] + all_initial_actions_proba_pairs = [] + for action_p1 in range(self.n_actions_p1): + for action_p2 in range(self.n_actions_p2): + all_initial_actions_proba_pairs.append( + pi_player_row_init_state[:, action_p1] + * pi_player_col_init_state[:, action_p2] + ) + p = torch.cat( + all_initial_actions_proba_pairs, + ) + + pi_player_row_other_states = pi_player_row[1:, :] + pi_player_col_other_states = pi_player_col[1:, :] + all_actions_proba_pairs = [] + for action_p1 in range(self.n_actions_p1): + for action_p2 in range(self.n_actions_p2): + all_actions_proba_pairs.append( + 
pi_player_row_other_states[:, action_p1] + * pi_player_col_other_states[:, action_p2] + ) + P = torch.stack( + all_actions_proba_pairs, + 1, + ) + + M = -torch.matmul( + p, + torch.inverse(torch.eye(self.n_non_init_states) - self.gamma * P), + ) + L_1 = torch.matmul( + M, + torch.reshape( + self.payoff_matrix_player_row, (self.n_non_init_states, 1) + ), + ) + L_2 = torch.matmul( + M, + torch.reshape( + self.payoff_matrix_player_col, (self.n_non_init_states, 1) + ), + ) + return [L_1, L_2] + + def init_weigths(self, std): + self.weights_per_players = [] + self.n_players = len(self.dims) + for i in range(self.n_players): + if std > 0: + init = torch.nn.init.normal_( + torch.empty(self.dims[i], requires_grad=True), std=std + ) + else: + init = torch.zeros(self.dims[i], requires_grad=True) + self.weights_per_players.append(init) + + def update_th( + self, + a=0.5, + b=0.1, + gam=1, + ep=0.1, + lss_lam=0.1, + ): + losses = self._exact_loss_matrix_game() + + grads = self._compute_gradients_wt_selected_method( + a, + b, + ep, + gam, + losses, + lss_lam, + ) + + self._update_weights(grads) + return losses + + def _compute_gradients_wt_selected_method( + self, a, b, ep, gam, losses, lss_lam + ): + n_players = self.n_players + + grad_L = _compute_vanilla_gradients( + losses, n_players, self.weights_per_players + ) + + if self.method == "la": + grads = _get_la_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "lola": + grads = _get_lola_exact_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "sos": + grads = _get_sos_gradients( + a, + self.learning_rate, + b, + grad_L, + n_players, + self.weights_per_players, + ) + elif self.method == "sga": + grads = _get_sga_gradients( + ep, grad_L, n_players, self.weights_per_players + ) + elif self.method == "co": + grads = _get_co_gradients( + gam, grad_L, n_players, self.weights_per_players + ) + elif self.method == "eg": + grads = _get_eg_gradients( + self._exact_loss_matrix_game, + self.learning_rate, + losses, + n_players, + self.weights_per_players, + ) + elif self.method == "cgd": # Slow implementation (matrix inversion) + grads = _get_cgd_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "lss": # Slow implementation (matrix inversion) + grads = _get_lss_gradients( + grad_L, lss_lam, n_players, self.weights_per_players + ) + elif self.method == "naive": # Naive Learning + grads = _get_naive_learning_gradients(grad_L, n_players) + else: + raise ValueError(f"algo: {self.method}") + return grads + + def _update_weights(self, grads): + with torch.no_grad(): + for weight_i, weight_grad in enumerate(grads): + self.weights_per_players[weight_i] -= ( + self.learning_rate * weight_grad + ) + + def save_checkpoint(self, checkpoint_dir): + save_path = os.path.join(checkpoint_dir, "weights.pt") + torch.save(self.weights_per_players, save_path) + return save_path + + def load_checkpoint(self, checkpoint_path): + self.weights_per_players = torch.load(checkpoint_path) + + def _get_agent_to_use(self, policy_id): + if policy_id == "player_row": + agent_n = 0 + elif policy_id == "player_col": + agent_n = 1 + else: + raise ValueError(f"policy_id {policy_id}") + return agent_n + + def _preprocess_obs(self, single_obs, agent_to_use): + single_obs = np.where(single_obs == 1)[0][0] + # because idx 0 is linked to the initial state in the weights + # but this is the last idx which is linked to the + # initial state in the 
environment obs + if single_obs == len(self.weights_per_players[0]) - 1: + single_obs = 0 + else: + single_obs += 1 + return single_obs + + def _post_process_action(self, action): + return action[None, ...] # add batch dim + + def compute_actions(self, policy_id: str, obs_batch: list): + assert ( + len(obs_batch) == 1 + ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}" + + for single_obs in obs_batch: + agent_to_use = self._get_agent_to_use(policy_id) + obs = self._preprocess_obs(single_obs, agent_to_use) + policy = torch.sigmoid(self.weights_per_players[agent_to_use]) + + if self.use_single_weights: + coop_proba = policy[obs] + if coop_proba > random.random(): + action = np.array(0) + else: + action = np.array(1) + else: + probabilities = policy[obs, :] + probabilities = torch.tensor(probabilities) + policy_for_this_state = torch.distributions.Categorical( + probs=probabilities + ) + action = np.array(policy_for_this_state.sample()) + + action = self._post_process_action(action) + + state_out = [] + extra_fetches = {} + return action, state_out, extra_fetches + + +def _compute_vanilla_gradients(losses, n, th): + grad_L = [ + [get_gradient(losses[j], th[i]) for j in range(n)] for i in range(n) + ] + return grad_L + + +def _get_la_gradients(alpha, grad_L, n, th): + terms = [ + sum( + [ + # torch.dot(grad_L[j][i], grad_L[j][j].detach()) + torch.matmul( + grad_L[j][i], + torch.transpose(grad_L[j][j].detach(), 0, 1), + ) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n) + ] + return grads + + +def _get_lola_exact_gradients(alpha, grad_L, n, th): + terms = [ + sum( + # [torch.dot(grad_L[j][i], grad_L[j][j]) for j in range(n) if j != i] + [ + torch.matmul( + grad_L[j][i], + torch.transpose(grad_L[j][j].unsqueeze(dim=1), 0, 1), + ) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n) + ] + return grads + + +def _get_sos_gradients(a, alpha, b, grad_L, n, th): + xi_0 = _get_la_gradients(alpha, grad_L, n, th) + chi = [ + get_gradient( + sum( + [ + # torch.dot(grad_L[j][i].detach(), grad_L[j][j]) + torch.matmul( + grad_L[j][i].detach(), + torch.transpose(grad_L[j][j], 0, 1), + ) + for j in range(n) + if j != i + ] + ), + th[i], + ) + for i in range(n) + ] + # Compute p + # dot = torch.dot(-alpha * torch.cat(chi), torch.cat(xi_0)) + dot = torch.matmul( + -alpha * torch.cat(chi), + torch.transpose(torch.cat(xi_0), 0, 1), + ).sum() + p1 = 1 if dot >= 0 else min(1, -a * torch.norm(torch.cat(xi_0)) ** 2 / dot) + xi = torch.cat([grad_L[i][i] for i in range(n)]) + xi_norm = torch.norm(xi) + p2 = xi_norm ** 2 if xi_norm < b else 1 + p = min(p1, p2) + grads = [xi_0[i] - p * alpha * chi[i] for i in range(n)] + return grads + + +def _get_sga_gradients(ep, grad_L, n, th): + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + H_t_xi = [get_gradient(ham, th[i]) for i in range(n)] + H_xi = [ + get_gradient( + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + ] + ), + th[i], + ) + for i in range(n) + ] + A_t_xi = [H_t_xi[i] / 2 - H_xi[i] / 2 for i in range(n)] + # Compute lambda (sga with alignment) + dot_xi = torch.dot(xi, torch.cat(H_t_xi)) + dot_A = torch.dot(torch.cat(A_t_xi), torch.cat(H_t_xi)) + d = sum([len(th[i]) for i in range(n)]) + lam = torch.sign(dot_xi * dot_A / d + ep) + grads = [grad_L[i][i] + lam * A_t_xi[i] for i in range(n)] + return grads + 
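# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): the losses minimised above are the
# exact discounted returns of the matrix game, L_i = -p @ inv(I - gamma * P) @ r_i,
# where p is the joint-action distribution in the initial state and P the
# policy-induced transition matrix over joint actions (see
# _exact_loss_matrix_game_generic above); step() then reports the average
# per-step reward as -(1 - gamma) * L_i.
# Below is a minimal sketch of how the SOSTrainer defined earlier in this file
# could be launched with Ray Tune. The hyperparameter values are illustrative
# assumptions, not values taken from this repository; only the config keys
# mirror those read in setup() and _set_environment().
#
#   from ray import tune
#   from marltoolbox.algos.sos import SOSTrainer
#
#   tune.run(
#       SOSTrainer,
#       config={
#           "env_name": "IteratedPrisonersDilemma",
#           "gamma": 0.96,              # discount factor used by the exact loss
#           "lr": 1.0,                  # learning rate of the joint update
#           "method": "sos",            # one of: la, lola, sos, sga, co, eg, cgd, lss, naive
#           "inital_weights_std": 1.0,  # key spelled as in init_weigths() above
#       },
#       stop={"training_iteration": 200},
#   )
# ---------------------------------------------------------------------------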
+ +def _get_co_gradients(gam, grad_L, n, th): + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + grads = [grad_L[i][i] + gam * get_gradient(ham, th[i]) for i in range(n)] + return grads + + +def _get_eg_gradients(Ls, alpha, losses, n, th): + th_eg = [th[i] - alpha * get_gradient(losses[i], th[i]) for i in range(n)] + losses_eg = Ls(th_eg) + grads = [get_gradient(losses_eg[i], th_eg[i]) for i in range(n)] + return grads + + +def _get_naive_learning_gradients(grad_L, n): + grads = [grad_L[i][i] for i in range(n)] + return grads + + +def _get_lss_gradients(grad_L, lss_lam, n, th): + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H = get_hessian(th, grad_L) + if torch.det(H) == 0: + inv = torch.inverse( + torch.matmul(H.T, H) + lss_lam * torch.eye(sum(dims)) + ) + H_inv = torch.matmul(inv, H.T) + else: + H_inv = torch.inverse(H) + grad = ( + torch.matmul(torch.eye(sum(dims)) + torch.matmul(H.T, H_inv), xi) / 2 + ) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + return grads + + +def _get_cgd_gradients(alpha, grad_L, n, th): + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H_o = get_hessian(th, grad_L, diag=False) + grad = torch.matmul(torch.inverse(torch.eye(sum(dims)) + alpha * H_o), xi) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + return grads + + +def get_gradient(function, param): + grad = torch.autograd.grad(torch.sum(function), param, create_graph=True)[ + 0 + ] + return grad + + +def get_hessian(th, grad_L, diag=True, off_diag=True): + n = len(th) + H = [] + for i in range(n): + row_block = [] + for j in range(n): + if (i == j and diag) or (i != j and off_diag): + block = [ + torch.unsqueeze( + get_gradient(grad_L[i][i][k], th[j]), + dim=0, + ) + for k in range(len(th[i])) + ] + row_block.append(torch.cat(block, dim=0)) + else: + row_block.append(torch.zeros(len(th[i]), len(th[j]))) + H.append(torch.cat(row_block, dim=1)) + return torch.cat(H, dim=0) diff --git a/marltoolbox/algos/stochastic_population.py b/marltoolbox/algos/stochastic_population.py new file mode 100644 index 0000000..6a76b12 --- /dev/null +++ b/marltoolbox/algos/stochastic_population.py @@ -0,0 +1,33 @@ +import torch +from ray.rllib import SampleBatch +from ray.rllib.utils import override + +from marltoolbox.algos.amTFT import base +from marltoolbox.algos.hierarchical import HierarchicalTorchPolicy + + +class StochasticPopulation(HierarchicalTorchPolicy, base.AmTFTReferenceClass): + def __init__(self, observation_space, action_space, config, **kwargs): + super().__init__(observation_space, action_space, config, **kwargs) + self.stochastic_population_policy = torch.distributions.Categorical( + probs=self.config["sampling_policy_distribution"] + ) + + def on_episode_start(self, *args, **kwargs): + self._select_algo_idx_to_use() + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"): + self.algorithms[self.active_algo_idx].on_episode_start( + *args, + **kwargs, + ) + + def _select_algo_idx_to_use(self): + policy_idx_selected = self.stochastic_population_policy.sample() + self.active_algo_idx = policy_idx_selected + self._to_log[ + "StochasticPopulation_active_algo_idx" + ] = self.active_algo_idx + + @override(HierarchicalTorchPolicy) + def _learn_on_batch(self, samples: SampleBatch): + return self.algorithms[self.active_algo_idx]._learn_on_batch(samples) diff --git a/marltoolbox/algos/welfare_coordination.py 
b/marltoolbox/algos/welfare_coordination.py index 559bc83..13516a3 100644 --- a/marltoolbox/algos/welfare_coordination.py +++ b/marltoolbox/algos/welfare_coordination.py @@ -2,6 +2,7 @@ import logging import random +import torch import numpy as np from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override @@ -27,6 +28,7 @@ "opp_player_idx": None, "own_default_welfare_fn": None, "opp_default_welfare_fn": None, + "distrib_over_welfare_sets_to_annonce": False, }, ) @@ -60,20 +62,25 @@ def setup_meta_game( self.opp_player_idx = opp_player_idx self.own_default_welfare_fn = own_default_welfare_fn self.opp_default_welfare_fn = opp_default_welfare_fn - self._list_all_welfares_fn() + self.all_welfare_fn = MetaGameSolver.list_all_welfares_fn( + self.all_welfare_pairs_wt_payoffs + ) self._list_all_set_of_welfare_fn() - def _list_all_welfares_fn(self): + @staticmethod + def list_all_welfares_fn(all_welfare_pairs_wt_payoffs): + """List all welfare functions in the given data structure""" all_welfare_fn = [] - for welfare_pairs_key in self.all_welfare_pairs_wt_payoffs.keys(): - for welfare_fn in self._key_to_pair_of_welfare_names( + for welfare_pairs_key in all_welfare_pairs_wt_payoffs.keys(): + for welfare_fn in MetaGameSolver.key_to_pair_of_welfare_names( welfare_pairs_key ): all_welfare_fn.append(welfare_fn) - self.all_welfare_fn = tuple(set(all_welfare_fn)) + return sorted(tuple(set(all_welfare_fn))) @staticmethod - def _key_to_pair_of_welfare_names(key): + def key_to_pair_of_welfare_names(key): + """Convert a key to the two welfare functions composing this key""" return key.split("-") def _list_all_set_of_welfare_fn(self): @@ -90,7 +97,7 @@ def _list_all_set_of_welfare_fn(self): frozenset(combi) for combi in combinations_object ] welfare_fn_sets.extend(combinations_set) - self.welfare_fn_sets = tuple(set(welfare_fn_sets)) + self.welfare_fn_sets = sorted(tuple(set(welfare_fn_sets))) def solve_meta_game(self, tau): """ @@ -98,7 +105,10 @@ def solve_meta_game(self, tau): :param tau: """ print("================================================") - print(f"Start solving meta game with tau={tau}") + print( + f"Player (idx={self.own_player_idx}) starts solving meta game " + f"with tau={tau}" + ) self.tau = tau self.selected_pure_policy_idx = None self.best_objective = -np.inf @@ -116,6 +126,7 @@ def solve_meta_game(self, tau): ] def _search_for_best_action(self): + self.tau_log = [] for idx, welfare_set_annonced in enumerate(self.welfare_fn_sets): optimization_objective = self._compute_optimization_objective( self.tau, welfare_set_annonced @@ -123,26 +134,40 @@ def _search_for_best_action(self): self._keep_action_if_best(optimization_objective, idx) def _compute_optimization_objective(self, tau, welfare_set_annonced): - print(f"========") - print(f"compute objective for set {welfare_set_annonced}") self._compute_payoffs_for_every_opponent_action(welfare_set_annonced) opp_best_response_idx = self._get_opp_best_response_idx() + + exploitability_term = tau * self._get_own_payoff(opp_best_response_idx) + payoffs = self._get_all_possible_payoffs_excluding_provided( + # excluding_idx=opp_best_response_idx + ) + robustness_term = (1 - tau) * sum(payoffs) / len(payoffs) + + optimization_objective = exploitability_term + robustness_term + + print(f"========") + print(f"compute objective for set {welfare_set_annonced}") print( f"opponent best response is {opp_best_response_idx} => " f"{self.welfare_fn_sets[opp_best_response_idx]}" ) - exploitability_term = tau * 
self._get_own_payoff(opp_best_response_idx) print("exploitability_term", exploitability_term) + print("robustness_term", robustness_term) + print("optimization_objective", optimization_objective) - payoffs = self._get_all_possible_payoffs_excluding_one( - excluding_idx=opp_best_response_idx + self.tau_log.append( + { + "welfare_set_annonced": welfare_set_annonced, + "opp_best_response_idx": opp_best_response_idx, + "opp_best_response_set": self.welfare_fn_sets[ + opp_best_response_idx + ], + "robustness_term": robustness_term, + "optimization_objective": optimization_objective, + } ) - robustness_term = (1 - tau) * sum(payoffs) / len(payoffs) - print("robustness_term", robustness_term) - optimization_objective = exploitability_term + robustness_term - print("optimization_objective", optimization_objective) return optimization_objective def _compute_payoffs_for_every_opponent_action(self, welfare_set_annonced): @@ -158,18 +183,18 @@ def _compute_meta_payoff(self, own_welfare_set, opp_welfare_set): welfare_fn_intersection = own_welfare_set & opp_welfare_set if len(welfare_fn_intersection) == 0: - return self._get_own_payoff_for_default_strategies() + return self._read_payoffs_for_default_strategies() else: - return self._get_own_payoff_averaging_over_welfare_intersection( + return self._read_payoffs_and_average_over_welfare_intersection( welfare_fn_intersection ) - def _get_own_payoff_for_default_strategies(self): + def _read_payoffs_for_default_strategies(self): return self._read_own_payoff_from_data( self.own_default_welfare_fn, self.opp_default_welfare_fn ) - def _get_own_payoff_averaging_over_welfare_intersection( + def _read_payoffs_and_average_over_welfare_intersection( self, welfare_fn_intersection ): payoffs_player_1 = [] @@ -182,16 +207,17 @@ def _get_own_payoff_averaging_over_welfare_intersection( payoffs_player_2.append(payoff_player_2) mean_payoff_p1 = sum(payoffs_player_1) / len(payoffs_player_1) mean_payoff_p2 = sum(payoffs_player_2) / len(payoffs_player_2) - return (mean_payoff_p1, mean_payoff_p2) + return mean_payoff_p1, mean_payoff_p2 def _read_own_payoff_from_data(self, own_welfare, opp_welfare): - welfare_pair_name = self._from_pair_of_welfare_names_to_key( + welfare_pair_name = self.from_pair_of_welfare_names_to_key( own_welfare, opp_welfare ) return self.all_welfare_pairs_wt_payoffs[welfare_pair_name] @staticmethod - def _from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set): + def from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set): + """Convert two welfare functions into a key identifying this pair""" return f"{own_welfare_set}-{opp_welfare_set}" def _get_opp_best_response_idx(self): @@ -201,13 +227,15 @@ def _get_opp_best_response_idx(self): ] return opp_payoffs.index(max(opp_payoffs)) - def _get_all_possible_payoffs_excluding_one(self, excluding_idx): + def _get_all_possible_payoffs_excluding_provided(self, excluding_idx=()): own_payoffs = [] for welfare_set_idx, _ in enumerate(self.welfare_fn_sets): - if welfare_set_idx != excluding_idx: + if welfare_set_idx not in excluding_idx: own_payoff = self._get_own_payoff(welfare_set_idx) own_payoffs.append(own_payoff) - assert len(own_payoffs) == len(self.welfare_fn_sets) - 1 + assert len(own_payoffs) == len(self.welfare_fn_sets) - len( + excluding_idx + ) return own_payoffs def _get_own_payoff(self, idx): @@ -253,14 +281,47 @@ def __init__( **kwargs, ) + assert ( + self._use_distribution_over_sets() + or self.config["solve_meta_game_after_init"] + ) if 
self.config["solve_meta_game_after_init"]: + assert not self.config["distrib_over_welfare_sets_to_annonce"] self._choose_which_welfare_set_to_annonce() + if self._use_distribution_over_sets(): + self._init_distribution_over_sets_to_announce() self._welfare_set_annonced = False self._welfare_set_in_use = None self._welfare_chosen_for_epi = False self._intersection_of_welfare_sets = None + def _init_distribution_over_sets_to_announce(self): + assert not self.config["solve_meta_game_after_init"] + self.setup_meta_game( + all_welfare_pairs_wt_payoffs=self.config[ + "all_welfare_pairs_wt_payoffs" + ], + own_player_idx=self.config["own_player_idx"], + opp_player_idx=self.config["opp_player_idx"], + own_default_welfare_fn=self.config["own_default_welfare_fn"], + opp_default_welfare_fn=self.config["opp_default_welfare_fn"], + ) + self.stochastic_announcement_policy = torch.distributions.Categorical( + probs=self.config["distrib_over_welfare_sets_to_annonce"] + ) + self._stochastic_selection_of_welfare_set_to_announce() + + def _stochastic_selection_of_welfare_set_to_announce(self): + welfare_set_idx_selected = self.stochastic_announcement_policy.sample() + self.welfare_set_to_annonce = self.welfare_fn_sets[ + welfare_set_idx_selected + ] + self._to_log["welfare_set_to_annonce"] = str( + self.welfare_set_to_annonce + ) + print("self.welfare_set_to_annonce", self.welfare_set_to_annonce) + def _choose_which_welfare_set_to_annonce(self): self.setup_meta_game( all_welfare_pairs_wt_payoffs=self.config[ @@ -272,6 +333,9 @@ def _choose_which_welfare_set_to_annonce(self): opp_default_welfare_fn=self.config["opp_default_welfare_fn"], ) self.solve_meta_game(self.config["tau"]) + self._to_log[f"tau_{self.tau}"] = self.tau_log[ + self.selected_pure_policy_idx + ] @property def policy_checkpoints(self): @@ -282,9 +346,9 @@ def policy_checkpoints(self): @policy_checkpoints.setter def policy_checkpoints(self, value): - msg = f"ignoring set self.policy_checkpoints to value {value}" - print(msg) - logger.warning(msg) + logger.warning( + f"ignoring setting self.policy_checkpoints to value {value}" + ) @override(population.PopulationOfIdenticalAlgo) def set_algo_to_use(self): @@ -310,7 +374,16 @@ def on_episode_start( if not self._welfare_chosen_for_epi: # Called only by one agent, not by both self._coordinate_on_welfare_to_use_for_epi(worker) - self.set_algo_to_use() + super().on_episode_start( + *args, + worker, + policy, + policy_id, + policy_ids, + **kwargs, + ) + assert self._welfare_set_annonced + assert self._welfare_chosen_for_epi @staticmethod def _annonce_welfare_sets( @@ -357,19 +430,35 @@ def _coordinate_on_welfare_to_use_for_epi(worker): if isinstance(policy, WelfareCoordinationTorchPolicy): if len(policy._intersection_of_welfare_sets) > 0: if welfare_to_play is None: - print( - "policy._welfare_set_in_use", - policy._welfare_set_in_use, - ) welfare_to_play = random.choice( tuple(policy._welfare_set_in_use) ) policy._welfare_in_use = welfare_to_play + msg = ( + "Policies coordinated " + "to play for the next epi: " + f"_welfare_set_in_use={policy._welfare_set_in_use}" + f"and selected: policy._welfare_in_use={policy._welfare_in_use}" + ) + logger.info(msg) + print(msg) else: policy._welfare_in_use = random.choice( tuple(policy._welfare_set_in_use) ) + msg = ( + "Policies did NOT coordinate " + "to play for the next epi: " + f"_welfare_set_in_use={policy._welfare_set_in_use}" + f"and selected: policy._welfare_in_use={policy._welfare_in_use}" + ) + logger.info(msg) + print(msg) 
policy._welfare_chosen_for_epi = True + policy._to_log["welfare_in_use"] = policy._welfare_in_use + policy._to_log[ + "welfare_set_in_use" + ] = policy._welfare_set_in_use def on_episode_end( self, @@ -377,4 +466,36 @@ def on_episode_end( **kwargs, ): self._welfare_chosen_for_epi = False - self.algorithms[self.active_algo_idx].on_episode_end(*args, **kwargs) + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_end"): + self.algorithms[self.active_algo_idx].on_episode_end( + *args, **kwargs + ) + if self._use_distribution_over_sets(): + self._stochastic_selection_of_welfare_set_to_announce() + self._welfare_set_annonced = False + + def _use_distribution_over_sets(self): + return not self.config["distrib_over_welfare_sets_to_annonce"] is False + + @property + @override(population.PopulationOfIdenticalAlgo) + def to_log(self): + to_log = { + "meta_policy": self._to_log, + "nested_policy": { + f"policy_{algo_idx}": algo.to_log + for algo_idx, algo in enumerate(self.algorithms) + if hasattr(algo, "to_log") + }, + } + return to_log + + @to_log.setter + @override(population.PopulationOfIdenticalAlgo) + def to_log(self, value): + if value == {}: + for algo in self.algorithms: + if hasattr(algo, "to_log"): + algo.to_log = {} + + self._to_log = value diff --git a/marltoolbox/envs/coin_game.py b/marltoolbox/envs/coin_game.py index 8c56f8f..4561c55 100644 --- a/marltoolbox/envs/coin_game.py +++ b/marltoolbox/envs/coin_game.py @@ -14,17 +14,22 @@ logger = logging.getLogger(__name__) -PLOT_KEYS = ["pick_speed", - "pick_own_color", - ] +PLOT_KEYS = [ + "pick_speed", + "pick_own_color", +] PLOT_ASSEMBLAGE_TAGS = [ ("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean"), ("pick_speed_player_red_mean", "pick_speed_player_blue_mean"), ("pick_own_color",), ("pick_speed", "pick_own_color"), - ("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean", - "pick_speed_player_red_mean", "pick_speed_player_blue_mean"), + ( + "pick_own_color_player_red_mean", + "pick_own_color_player_blue_mean", + "pick_speed_player_red_mean", + "pick_speed_player_blue_mean", + ), ] @@ -32,6 +37,7 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): """ Coin Game environment. 
""" + NAME = "CoinGame" NUM_AGENTS = 2 NUM_ACTIONS = 4 @@ -55,7 +61,7 @@ def __init__(self, config: Dict = {}): low=0, high=1, shape=(self.grid_size, self.grid_size, 4), - dtype="uint8" + dtype="uint8", ) self.step_count_in_current_episode = None @@ -69,17 +75,22 @@ def _validate_config(self, config): assert len(config["players_ids"]) == self.NUM_AGENTS def _load_config(self, config): - self.players_ids = \ - config.get("players_ids", ["player_red", "player_blue"]) + self.players_ids = config.get( + "players_ids", ["player_red", "player_blue"] + ) self.max_steps = config.get("max_steps", 20) self.grid_size = config.get("grid_size", 3) - self.output_additional_info = config.get("output_additional_info", - True) + self.output_additional_info = config.get( + "output_additional_info", True + ) self.asymmetric = config.get("asymmetric", False) - self.both_players_can_pick_the_same_coin = \ - config.get("both_players_can_pick_the_same_coin", True) - self.same_obs_for_each_player = \ - config.get("same_obs_for_each_player", True) + self.both_players_can_pick_the_same_coin = config.get( + "both_players_can_pick_the_same_coin", True + ) + self.same_obs_for_each_player = config.get( + "same_obs_for_each_player", True + ) + self.punishment_helped = config.get("punishment_helped", False) @override(gym.Env) def seed(self, seed=None): @@ -89,6 +100,7 @@ def seed(self, seed=None): @override(gym.Env) def reset(self): + # print("reset") self.step_count_in_current_episode = 0 if self.output_additional_info: @@ -98,18 +110,17 @@ def reset(self): self._generate_coin() obs = self._generate_observation() - return { - self.player_red_id: obs[0], - self.player_blue_id: obs[1] - } + return {self.player_red_id: obs[0], self.player_blue_id: obs[1]} def _randomize_color_and_player_positions(self): # Reset coin color and the players and coin positions self.red_coin = self.np_random.randint(low=0, high=2) - self.red_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) - self.blue_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) + self.red_pos = self.np_random.randint( + low=0, high=self.grid_size, size=(2,) + ) + self.blue_pos = self.np_random.randint( + low=0, high=self.grid_size, size=(2,) + ) # self.coin_pos = np.zeros(shape=(2,), dtype=np.int8) self._players_do_not_overlap_at_start() @@ -165,51 +176,62 @@ def _same_pos(self, x, y): return (x == y).all() def _move_players(self, actions): - self.red_pos = \ - (self.red_pos + self.MOVES[actions[0]]) % self.grid_size - self.blue_pos = \ - (self.blue_pos + self.MOVES[actions[1]]) % self.grid_size + self.red_pos = (self.red_pos + self.MOVES[actions[0]]) % self.grid_size + self.blue_pos = ( + self.blue_pos + self.MOVES[actions[1]] + ) % self.grid_size def _compute_reward(self): reward_red = 0.0 reward_blue = 0.0 generate_new_coin = False - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - False, False, False, False + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = ( + False, + False, + False, + False, + ) red_first_if_both = None if not self.both_players_can_pick_the_same_coin: - if self._same_pos(self.red_pos, self.coin_pos) and \ - self._same_pos(self.blue_pos, self.coin_pos): + if self._same_pos(self.red_pos, self.coin_pos) and self._same_pos( + self.blue_pos, self.coin_pos + ): red_first_if_both = bool(self.np_random.randint(low=0, high=2)) if self.red_coin: - if self._same_pos(self.red_pos, self.coin_pos) and \ - (red_first_if_both is None or red_first_if_both): + if self._same_pos(self.red_pos, 
self.coin_pos) and ( + red_first_if_both is None or red_first_if_both + ): generate_new_coin = True reward_red += 1 if self.asymmetric: reward_red += 3 red_pick_any = True red_pick_red = True - if self._same_pos(self.blue_pos, self.coin_pos) and \ - (red_first_if_both is None or not red_first_if_both): + if self._same_pos(self.blue_pos, self.coin_pos) and ( + red_first_if_both is None or not red_first_if_both + ): generate_new_coin = True reward_red += -2 reward_blue += 1 blue_pick_any = True + if self.asymmetric and self.punishment_helped: + reward_red -= 6 else: - if self._same_pos(self.red_pos, self.coin_pos) and \ - (red_first_if_both is None or red_first_if_both): + if self._same_pos(self.red_pos, self.coin_pos) and ( + red_first_if_both is None or red_first_if_both + ): generate_new_coin = True reward_red += 1 reward_blue += -2 if self.asymmetric: reward_red += 3 red_pick_any = True - if self._same_pos(self.blue_pos, self.coin_pos) and \ - (red_first_if_both is None or not red_first_if_both): + if self._same_pos(self.blue_pos, self.coin_pos) and ( + red_first_if_both is None or not red_first_if_both + ): generate_new_coin = True reward_blue += 1 blue_pick_blue = True @@ -217,10 +239,12 @@ def _compute_reward(self): reward_list = [reward_red, reward_blue] if self.output_additional_info: - self._accumulate_info(red_pick_any=red_pick_any, - red_pick_red=red_pick_red, - blue_pick_any=blue_pick_any, - blue_pick_blue=blue_pick_blue) + self._accumulate_info( + red_pick_any=red_pick_any, + red_pick_red=red_pick_red, + blue_pick_any=blue_pick_any, + blue_pick_blue=blue_pick_blue, + ) return reward_list, generate_new_coin @@ -260,11 +284,12 @@ def _to_RLLib_API(self, observations, rewards): self.player_blue_id: rewards[1], } - epi_is_done = (self.step_count_in_current_episode >= self.max_steps) + epi_is_done = self.step_count_in_current_episode >= self.max_steps if self.step_count_in_current_episode > self.max_steps: logger.warning( "step_count_in_current_episode > self.max_steps: " - f"{self.step_count_in_current_episode} > {self.max_steps}") + f"{self.step_count_in_current_episode} > {self.max_steps}" + ) done = { self.player_red_id: epi_is_done, @@ -283,7 +308,7 @@ def _to_RLLib_API(self, observations, rewards): return obs, rewards, done, info @override(InfoAccumulationInterface) - def _get_episode_info(self): + def _get_episode_info(self, n_steps_played=None): """ Output the following information: pick_speed is the fraction of steps during which the player picked a @@ -292,20 +317,25 @@ def _get_episode_info(self): the same color as the player. 
""" player_red_info, player_blue_info = {}, {} + if n_steps_played is None: + n_steps_played = len(self.red_pick) + assert len(self.red_pick) == len(self.blue_pick) if len(self.red_pick) > 0: red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / len(self.red_pick) + player_red_info["pick_speed"] = red_pick / n_steps_played if red_pick > 0: - player_red_info["pick_own_color"] = \ + player_red_info["pick_own_color"] = ( sum(self.red_pick_own) / red_pick + ) if len(self.blue_pick) > 0: blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / len(self.blue_pick) + player_blue_info["pick_speed"] = blue_pick / n_steps_played if blue_pick > 0: - player_blue_info["pick_own_color"] = \ + player_blue_info["pick_own_color"] = ( sum(self.blue_pick_own) / blue_pick + ) return player_red_info, player_blue_info @@ -318,7 +348,8 @@ def _reset_info(self): @override(InfoAccumulationInterface) def _accumulate_info( - self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue): + self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ): self.red_pick.append(red_pick_any) self.red_pick_own.append(red_pick_red) diff --git a/marltoolbox/envs/matrix_sequential_social_dilemma.py b/marltoolbox/envs/matrix_sequential_social_dilemma.py index 52fd4f3..05385a4 100644 --- a/marltoolbox/envs/matrix_sequential_social_dilemma.py +++ b/marltoolbox/envs/matrix_sequential_social_dilemma.py @@ -43,7 +43,7 @@ class MatrixSequentialSocialDilemma( """ A multi-agent abstract class for two player matrix games. - PAYOUT_MATRIX: Numpy array. Along the dimension N, the action of the + PAYOFF_MATRIX: Numpy array. Along the dimension N, the action of the Nth player change. The last dimension is used to select the player whose reward you want to know. @@ -56,10 +56,22 @@ class MatrixSequentialSocialDilemma( episode. """ + NUM_AGENTS = 2 + NUM_ACTIONS = None + NUM_STATES = None + ACTION_SPACE = None + OBSERVATION_SPACE = None + PAYOFF_MATRIX = None + NAME = None + def __init__(self, config: Dict = {}): assert "reward_randomness" not in config.keys() - assert self.PAYOUT_MATRIX is not None + assert self.PAYOFF_MATRIX is not None + assert self.PAYOFF_MATRIX.shape[0] == self.NUM_ACTIONS + assert self.PAYOFF_MATRIX.shape[1] == self.NUM_ACTIONS + assert self.PAYOFF_MATRIX.shape[2] == self.NUM_AGENTS + assert len(self.PAYOFF_MATRIX.shape) == 3 if "players_ids" in config: assert ( isinstance(config["players_ids"], Iterable) @@ -156,8 +168,8 @@ def _produce_observations_invariant_to_the_player_trained( def _get_players_rewards(self, action_player_0: int, action_player_1: int): return [ - self.PAYOUT_MATRIX[action_player_0][action_player_1][0], - self.PAYOUT_MATRIX[action_player_0][action_player_1][1], + self.PAYOFF_MATRIX[action_player_0][action_player_1][0], + self.PAYOFF_MATRIX[action_player_0][action_player_1][1], ] def _to_RLLib_API( @@ -205,12 +217,11 @@ class IteratedMatchingPennies( A two-agent environment for the Matching Pennies game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]]) + PAYOFF_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]]) NAME = "IMP" @@ -221,12 +232,11 @@ class IteratedPrisonersDilemma( A two-agent environment for the Prisoner's Dilemma game. 
""" - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) + PAYOFF_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) NAME = "IPD" @@ -237,12 +247,11 @@ class IteratedAsymPrisonersDilemma( A two-agent environment for the Asymmetric Prisoner's Dilemma game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) + PAYOFF_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) NAME = "IPD" @@ -253,12 +262,11 @@ class IteratedStagHunt( A two-agent environment for the Stag Hunt game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]]) + PAYOFF_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]]) NAME = "IteratedStagHunt" @@ -269,12 +277,11 @@ class IteratedChicken( A two-agent environment for the Chicken game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+0, +0], [-1.0, +1.0]], [[+1, -1], [-10, -10]]] ) NAME = "IteratedChicken" @@ -287,12 +294,11 @@ class IteratedAsymChicken( A two-agent environment for the Asymmetric Chicken game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+2.0, +0], [-1.0, +1.0]], [[+2.5, -1], [-10, -10]]] ) NAME = "AsymmetricIteratedChicken" @@ -305,12 +311,11 @@ class IteratedBoS( A two-agent environment for the BoS game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+3.0, +2.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +3.0]]] ) NAME = "IteratedBoS" @@ -323,31 +328,31 @@ class IteratedAsymBoS( A two-agent environment for the BoS game. 
""" - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]] ) - NAME = "AsymmetricIteratedBoS" + NAME = "IteratedAsymBoS" def define_greed_fear_matrix_game(greed, fear): class GreedFearGame( TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma ): - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = ( + NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 + ) ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) R = 3 P = 1 T = R + greed S = P - fear - PAYOUT_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]]) + PAYOFF_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]]) NAME = "IteratedGreedFear" def __str__(self): @@ -363,12 +368,11 @@ class IteratedBoSAndPD( A two-agent environment for the BOTS + PD game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 3 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [ [[3.5, +1], [+0, +0], [-3, +2]], [[+0.0, +0], [+1, +3], [-3, +2]], @@ -376,3 +380,28 @@ class IteratedBoSAndPD( ] ) NAME = "IteratedBoSAndPD" + + +class TwoPlayersCustomizableMatrixGame( + NPlayersNDiscreteActionsInfoMixin, MatrixSequentialSocialDilemma +): + """ + A two-agent environment for the BOTS + PD game. + """ + + NAME = "TwoPlayersCustomizableMatrixGame" + + NUM_ACTIONS = None + NUM_STATES = None + ACTION_SPACE = None + OBSERVATION_SPACE = None + PAYOFF_MATRIX = None + + def __init__(self, config: Dict): + self.PAYOFF_MATRIX = config["PAYOFF_MATRIX"] + self.NUM_ACTIONS = config["NUM_ACTIONS"] + self.ACTION_SPACE = Discrete(self.NUM_ACTIONS) + self.NUM_STATES = self.NUM_ACTIONS ** self.NUM_AGENTS + 1 + self.OBSERVATION_SPACE = Discrete(self.NUM_STATES) + + super().__init__(config) diff --git a/marltoolbox/envs/ssd_mixed_motive_coin_game.py b/marltoolbox/envs/ssd_mixed_motive_coin_game.py index 0af66f7..03dbf84 100644 --- a/marltoolbox/envs/ssd_mixed_motive_coin_game.py +++ b/marltoolbox/envs/ssd_mixed_motive_coin_game.py @@ -40,6 +40,7 @@ class SSDMixedMotiveCoinGame(coin_game.CoinGame): @override(coin_game.CoinGame) def __init__(self, config: dict = {}): super().__init__(config) + self.punishment_helped = config.get("punishment_helped", False) self.OBSERVATION_SPACE = gym.spaces.Box( low=0, @@ -137,11 +138,11 @@ def _compute_reward(self): if self._same_pos(self.red_pos, self.red_coin_pos): if self.red_coin and self._same_pos( - self.blue_pos, self.red_coin_pos + self.blue_pos, self.red_coin_pos ): # Red coin is a coop coin generate_new_coin = True - reward_red += 1.2 + reward_red += 2.0 red_pick_any = True red_pick_red = True blue_pick_any = True @@ -152,13 +153,18 @@ def _compute_reward(self): reward_red += 1.0 red_pick_any = True red_pick_red = True + if self.punishment_helped and self._same_pos( + self.blue_pos, self.red_coin_pos + ): + reward_red -= 0.75 + elif self._same_pos(self.blue_pos, self.blue_coin_pos): if not self.red_coin and self._same_pos( - self.red_pos, self.blue_coin_pos + self.red_pos, self.blue_coin_pos ): # Blue coin is a coop coin generate_new_coin = True - reward_blue += 2.0 + reward_blue += 3.0 
red_pick_any = True blue_pick_any = True blue_pick_blue = True @@ -169,6 +175,10 @@ def _compute_reward(self): reward_blue += 1.0 blue_pick_any = True blue_pick_blue = True + if self.punishment_helped and self._same_pos( + self.red_pos, self.blue_coin_pos + ): + reward_red -= 0.75 reward_list = [reward_red, reward_blue] if self.output_additional_info: @@ -185,37 +195,29 @@ def _compute_reward(self): @override(coin_game.CoinGame) def _init_info(self): - self.red_pick = [] - self.red_pick_own = [] - self.blue_pick = [] - self.blue_pick_own = [] + super()._init_info() self.picked_red_coop = [] self.picked_blue_coop = [] @override(coin_game.CoinGame) def _reset_info(self): - self.red_pick.clear() - self.red_pick_own.clear() - self.blue_pick.clear() - self.blue_pick_own.clear() + super()._reset_info() self.picked_red_coop.clear() self.picked_blue_coop.clear() @override(coin_game.CoinGame) def _accumulate_info( - self, - red_pick_any, - red_pick_red, - blue_pick_any, - blue_pick_blue, - picked_red_coop, - picked_blue_coop, + self, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, ): - - self.red_pick.append(red_pick_any) - self.red_pick_own.append(red_pick_red) - self.blue_pick.append(blue_pick_any) - self.blue_pick_own.append(blue_pick_blue) + super()._accumulate_info( + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) self.picked_red_coop.append(picked_red_coop) self.picked_blue_coop.append(picked_blue_coop) @@ -228,38 +230,25 @@ def _get_episode_info(self): pick_own_color is the fraction of coins picked by the player which have the same color as the player. """ - player_red_info, player_blue_info = {}, {} + player_red_info, player_blue_info = super()._get_episode_info() + n_steps_played = len(self.red_pick) assert n_steps_played == len(self.blue_pick) n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop) if len(self.red_pick) > 0: red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / n_steps_played - if red_pick > 0: - player_red_info["pick_own_color"] = ( - sum(self.red_pick_own) / red_pick - ) - player_red_info["red_coop_speed"] = ( - sum(self.picked_red_coop) / n_steps_played + sum(self.picked_red_coop) / n_steps_played ) - if red_pick > 0: player_red_info["red_coop_fraction"] = n_coop / red_pick if len(self.blue_pick) > 0: blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / n_steps_played - if blue_pick > 0: - player_blue_info["pick_own_color"] = ( - sum(self.blue_pick_own) / blue_pick - ) - player_blue_info["blue_coop_speed"] = ( - sum(self.picked_blue_coop) / n_steps_played + sum(self.picked_blue_coop) / n_steps_played ) - if blue_pick > 0: player_blue_info["blue_coop_fraction"] = n_coop / blue_pick diff --git a/marltoolbox/envs/vectorized_coin_game.py b/marltoolbox/envs/vectorized_coin_game.py index 5f9f4f6..0fed500 100644 --- a/marltoolbox/envs/vectorized_coin_game.py +++ b/marltoolbox/envs/vectorized_coin_game.py @@ -28,17 +28,21 @@ def __init__(self, config={}): self.batch_size = config.get("batch_size", 1) self.force_vectorized = config.get("force_vectorize", False) - assert self.grid_size == 3, \ - "hardcoded in the generate_state numba function" + self.punishment_helped = config.get("punishment_helped", False) + assert ( + self.grid_size == 3 + ), "hardcoded in the generate_state numba function" @override(coin_game.CoinGame) def _randomize_color_and_player_positions(self): # Reset coin color and the players and coin positions - self.red_coin = 
np.random.randint(2, size=self.batch_size) + self.red_coin = np.random.randint(low=0, high=2, size=self.batch_size) self.red_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.blue_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) self._players_do_not_overlap_at_start() @@ -52,18 +56,30 @@ def _players_do_not_overlap_at_start(self): @override(coin_game.CoinGame) def _generate_coin(self): generate = np.ones(self.batch_size, dtype=bool) - self.coin_pos = generate_coin_wt_numba_optimization( - self.batch_size, generate, self.red_coin, self.red_pos, - self.blue_pos, self.coin_pos, self.grid_size) + self.coin_pos, self.red_coin = generate_coin_wt_numba_optimization( + self.batch_size, + generate, + self.red_coin, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.grid_size, + ) @override(coin_game.CoinGame) def _generate_observation(self): obs = generate_observations_wt_numba_optimization( - self.batch_size, self.red_pos, self.blue_pos, self.coin_pos, - self.red_coin, self.grid_size) + self.batch_size, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.red_coin, + self.grid_size, + ) obs = self._apply_optional_invariance_to_the_player_trained(obs) obs, _ = self._optional_unvectorize(obs) + # print("env nonzero obs", np.nonzero(obs[0])) return obs def _optional_unvectorize(self, obs, rewards=None): @@ -75,49 +91,65 @@ def _optional_unvectorize(self, obs, rewards=None): @override(coin_game.CoinGame) def step(self, actions: Iterable): - + # print("step") + # print("env self.red_coin", self.red_coin) + # print("env self.red_pos", self.red_pos) + # print("env self.blue_pos", self.blue_pos) + # print("env self.coin_pos", self.coin_pos) actions = self._from_RLLib_API_to_list(actions) self.step_count_in_current_episode += 1 - (self.red_pos, self.blue_pos, rewards, self.coin_pos, observation, - self.red_coin, red_pick_any, red_pick_red, blue_pick_any, - blue_pick_blue) = vectorized_step_wt_numba_optimization( - actions, self.batch_size, self.red_pos, self.blue_pos, - self.coin_pos, self.red_coin, self.grid_size, self.asymmetric, - self.max_steps, self.both_players_can_pick_the_same_coin) + ( + self.red_pos, + self.blue_pos, + rewards, + self.coin_pos, + observation, + self.red_coin, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = vectorized_step_wt_numba_optimization( + actions, + self.batch_size, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.red_coin, + self.grid_size, + self.asymmetric, + self.max_steps, + self.both_players_can_pick_the_same_coin, + self.punishment_helped, + ) if self.output_additional_info: self._accumulate_info( - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue) + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) obs = self._apply_optional_invariance_to_the_player_trained( - observation) + observation + ) obs, rewards = self._optional_unvectorize(obs, rewards) - + # print("env actions", actions) + # print("env rewards", rewards) + # print("env self.red_coin", self.red_coin) + # print("env self.red_pos", self.red_pos) + # print("env self.blue_pos", self.blue_pos) + # print("env self.coin_pos", self.coin_pos) + # print("env nonzero obs", np.nonzero(obs[0])) return self._to_RLLib_API(obs, rewards) @override(coin_game.CoinGame) - def _get_episode_info(self): - - player_red_info, player_blue_info = {}, {} - - if 
len(self.red_pick) > 0: - red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = \ - red_pick / (len(self.red_pick) * self.batch_size) - if red_pick > 0: - player_red_info["pick_own_color"] = \ - sum(self.red_pick_own) / red_pick - - if len(self.blue_pick) > 0: - blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = \ - blue_pick / (len(self.blue_pick) * self.batch_size) - if blue_pick > 0: - player_blue_info["pick_own_color"] = \ - sum(self.blue_pick_own) / blue_pick - - return player_red_info, player_blue_info + def _get_episode_info(self, n_steps_played=None): + n_steps_played = ( + n_steps_played + if n_steps_played is not None + else len(self.red_pick) * self.batch_size + ) + return super()._get_episode_info(n_steps_played=n_steps_played) @override(coin_game.CoinGame) def _from_RLLib_API_to_list(self, actions): @@ -140,15 +172,13 @@ def _save_env(self): "grid_size": self.grid_size, "asymmetric": self.asymmetric, "batch_size": self.batch_size, - "step_count_in_current_episode": - self.step_count_in_current_episode, + "step_count_in_current_episode": self.step_count_in_current_episode, "max_steps": self.max_steps, "red_pick": self.red_pick, "red_pick_own": self.red_pick_own, "blue_pick": self.blue_pick, "blue_pick_own": self.blue_pick_own, - "both_players_can_pick_the_same_coin": - self.both_players_can_pick_the_same_coin, + "both_players_can_pick_the_same_coin": self.both_players_can_pick_the_same_coin, } return copy.deepcopy(env_save_state) @@ -170,94 +200,149 @@ def __init__(self, config={}): @jit(nopython=True) def vectorized_step_wt_numba_optimization( - actions, batch_size, red_pos, blue_pos, coin_pos, red_coin, - grid_size: int, asymmetric: bool, max_steps: int, - both_players_can_pick_the_same_coin: bool): + actions, + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + grid_size: int, + asymmetric: bool, + max_steps: int, + both_players_can_pick_the_same_coin: bool, + punishment_helped: bool, +): red_pos, blue_pos = move_players( - batch_size, actions, red_pos, blue_pos, grid_size) + batch_size, actions, red_pos, blue_pos, grid_size + ) - reward, generate, red_pick_any, red_pick_red, \ - blue_pick_any, blue_pick_blue = compute_reward( - batch_size, red_pos, blue_pos, coin_pos, red_coin, - asymmetric, both_players_can_pick_the_same_coin) + ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = compute_reward( + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + asymmetric, + both_players_can_pick_the_same_coin, + punishment_helped, + ) - coin_pos = generate_coin_wt_numba_optimization( - batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size) + coin_pos, red_coin = generate_coin_wt_numba_optimization( + batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size + ) obs = generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size) + batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size + ) - return red_pos, blue_pos, reward, coin_pos, obs, red_coin, red_pick_any, \ - red_pick_red, blue_pick_any, blue_pick_blue + return ( + red_pos, + blue_pos, + reward, + coin_pos, + obs, + red_coin, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) def move_players(batch_size, actions, red_pos, blue_pos, grid_size): - moves = List([ - np.array([0, 1]), - np.array([0, -1]), - np.array([1, 0]), - np.array([-1, 0]), - ]) + moves = List( + [ + np.array([0, 1]), + np.array([0, -1]), + np.array([1, 
0]), + np.array([-1, 0]), + ] + ) for j in prange(batch_size): - red_pos[j] = \ - (red_pos[j] + moves[actions[j, 0]]) % grid_size - blue_pos[j] = \ - (blue_pos[j] + moves[actions[j, 1]]) % grid_size + red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size + blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size return red_pos, blue_pos @jit(nopython=True) -def compute_reward(batch_size, red_pos, blue_pos, coin_pos, red_coin, - asymmetric, both_players_can_pick_the_same_coin): +def compute_reward( + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + asymmetric, + both_players_can_pick_the_same_coin, + punishment_helped, +): reward_red = np.zeros(batch_size) reward_blue = np.zeros(batch_size) generate = np.zeros(batch_size, dtype=np.bool_) - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - 0, 0, 0, 0 + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 for i in prange(batch_size): red_first_if_both = None if not both_players_can_pick_the_same_coin: - if _same_pos(red_pos[i], coin_pos[i]) and \ - _same_pos(blue_pos[i], coin_pos[i]): + if _same_pos(red_pos[i], coin_pos[i]) and _same_pos( + blue_pos[i], coin_pos[i] + ): red_first_if_both = bool(np.random.randint(low=0, high=2)) if red_coin[i]: - if _same_pos(red_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or red_first_if_both): + if _same_pos(red_pos[i], coin_pos[i]) and ( + red_first_if_both is None or red_first_if_both + ): generate[i] = True reward_red[i] += 1 if asymmetric: reward_red[i] += 3 red_pick_any += 1 red_pick_red += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or not red_first_if_both): + if _same_pos(blue_pos[i], coin_pos[i]) and ( + red_first_if_both is None or not red_first_if_both + ): generate[i] = True reward_red[i] += -2 reward_blue[i] += 1 blue_pick_any += 1 + if asymmetric and punishment_helped: + reward_red[i] -= 6 else: - if _same_pos(red_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or red_first_if_both): + if _same_pos(red_pos[i], coin_pos[i]) and ( + red_first_if_both is None or red_first_if_both + ): generate[i] = True reward_red[i] += 1 reward_blue[i] += -2 if asymmetric: reward_red[i] += 3 red_pick_any += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or not red_first_if_both): + if _same_pos(blue_pos[i], coin_pos[i]) and ( + red_first_if_both is None or not red_first_if_both + ): generate[i] = True reward_blue[i] += 1 blue_pick_any += 1 blue_pick_blue += 1 reward = [reward_red, reward_blue] - return reward, generate, \ - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + return ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) @@ -267,13 +352,13 @@ def _same_pos(x, y): @jit(nopython=True) def generate_coin_wt_numba_optimization( - batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, - grid_size): + batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size +): red_coin[generate] = 1 - red_coin[generate] for i in prange(batch_size): if generate[i]: coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], grid_size) - return coin_pos + return coin_pos, red_coin @jit(nopython=True) @@ -303,8 +388,9 @@ def _unflatten_index(pos, grid_size): @jit(nopython=True) -def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos, - coin_pos, red_coin, grid_size): +def generate_observations_wt_numba_optimization( + batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size +): 
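# (editor's note, assumption) Only the red player's channel is visible in this
# hunk; by the usual CoinGame convention (4-channel observation space declared
# in coin_game.CoinGame), the channels are expected to be one-hot grids for the
# red player, the blue player, the red coin and the blue coin respectively.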
obs = np.zeros((batch_size, grid_size, grid_size, 4)) for i in prange(batch_size): obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 diff --git a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py b/marltoolbox/envs/vectorized_mixed_motive_coin_game.py index ced01bd..a851f05 100644 --- a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py +++ b/marltoolbox/envs/vectorized_mixed_motive_coin_game.py @@ -7,8 +7,12 @@ from ray.rllib.utils import override from marltoolbox.envs import vectorized_coin_game -from marltoolbox.envs.vectorized_coin_game import \ - _flatten_index, _unflatten_index, _same_pos, move_players +from marltoolbox.envs.vectorized_coin_game import ( + _flatten_index, + _unflatten_index, + _same_pos, + move_players, +) logger = logging.getLogger(__name__) @@ -18,21 +22,23 @@ class VectMixedMotiveCG(vectorized_coin_game.VectorizedCoinGame): - @override(vectorized_coin_game.VectorizedCoinGame) def _load_config(self, config): super()._load_config(config) - assert self.both_players_can_pick_the_same_coin, \ - "both_players_can_pick_the_same_coin option must be True in " \ + assert self.both_players_can_pick_the_same_coin, ( + "both_players_can_pick_the_same_coin option must be True in " "mixed motive coin game." + ) @override(vectorized_coin_game.VectorizedCoinGame) def _randomize_color_and_player_positions(self): # Reset coin color and the players and coin positions self.red_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.blue_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) @@ -41,17 +47,29 @@ def _randomize_color_and_player_positions(self): @override(vectorized_coin_game.VectorizedCoinGame) def _generate_coin(self): generate = np.ones(self.batch_size, dtype=np.int_) * 3 - self.red_coin_pos, self.blue_coin_pos = \ - generate_coin_wt_numba_optimization( - self.batch_size, generate, self.red_coin_pos, - self.blue_coin_pos, self.red_pos, self.blue_pos, - self.grid_size) + ( + self.red_coin_pos, + self.blue_coin_pos, + ) = generate_coin_wt_numba_optimization( + self.batch_size, + generate, + self.red_coin_pos, + self.blue_coin_pos, + self.red_pos, + self.blue_pos, + self.grid_size, + ) @override(vectorized_coin_game.VectorizedCoinGame) def _generate_observation(self): obs = generate_observations_wt_numba_optimization( - self.batch_size, self.red_pos, self.blue_pos, self.red_coin_pos, - self.blue_coin_pos, self.grid_size) + self.batch_size, + self.red_pos, + self.blue_pos, + self.red_coin_pos, + self.blue_coin_pos, + self.grid_size, + ) obs = self._apply_optional_invariance_to_the_player_trained(obs) obs, _ = self._optional_unvectorize(obs) @@ -62,19 +80,35 @@ def step(self, actions: Iterable): actions = self._from_RLLib_API_to_list(actions) self.step_count_in_current_episode += 1 - (self.red_pos, self.blue_pos, rewards, self.red_coin_pos, - self.blue_coin_pos, observation, red_pick_any, red_pick_red, - blue_pick_any, blue_pick_blue) = \ - vectorized_step_wt_numba_optimization( - actions, self.batch_size, self.red_pos, self.blue_pos, - self.red_coin_pos, self.blue_coin_pos, self.grid_size) + ( + self.red_pos, + self.blue_pos, + rewards, + self.red_coin_pos, + self.blue_coin_pos, + observation, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = vectorized_step_wt_numba_optimization( + actions, 
+ self.batch_size, + self.red_pos, + self.blue_pos, + self.red_coin_pos, + self.blue_coin_pos, + self.grid_size, + ) if self.output_additional_info: self._accumulate_info( - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue) + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) obs = self._apply_optional_invariance_to_the_player_trained( - observation) + observation + ) obs, rewards = self._optional_unvectorize(obs, rewards) return self._to_RLLib_API(obs, rewards) @@ -88,8 +122,7 @@ def _save_env(self): "blue_coin_pos": self.blue_coin_pos, "grid_size": self.grid_size, "batch_size": self.batch_size, - "step_count_in_current_episode": - self.step_count_in_current_episode, + "step_count_in_current_episode": self.step_count_in_current_episode, "max_steps": self.max_steps, "red_pick": self.red_pick, "red_pick_own": self.red_pick_own, @@ -101,24 +134,55 @@ def _save_env(self): @jit(nopython=True) def vectorized_step_wt_numba_optimization( - actions, batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, - grid_size: int): + actions, + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size: int, +): red_pos, blue_pos = move_players( - batch_size, actions, red_pos, blue_pos, grid_size) + batch_size, actions, red_pos, blue_pos, grid_size + ) - reward, generate, red_pick_any, red_pick_red, \ - blue_pick_any, blue_pick_blue = compute_reward( - batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos) + ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = compute_reward( + batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos + ) red_coin_pos, blue_coin_pos = generate_coin_wt_numba_optimization( - batch_size, generate, red_coin_pos, blue_coin_pos, - red_pos, blue_pos, grid_size) + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, + ) obs = generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size) + batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size + ) - return red_pos, blue_pos, reward, red_coin_pos, blue_coin_pos, obs, \ - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + return ( + red_pos, + blue_pos, + reward, + red_coin_pos, + blue_coin_pos, + obs, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) @@ -126,20 +190,21 @@ def compute_reward(batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos): reward_red = np.zeros(batch_size) reward_blue = np.zeros(batch_size) generate = np.zeros(batch_size, dtype=np.int_) - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - 0, 0, 0, 0 + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 for i in prange(batch_size): - if _same_pos(red_pos[i], red_coin_pos[i]) and \ - _same_pos(blue_pos[i], red_coin_pos[i]): + if _same_pos(red_pos[i], red_coin_pos[i]) and _same_pos( + blue_pos[i], red_coin_pos[i] + ): generate[i] = 1 reward_red[i] += 2 reward_blue[i] += 2 red_pick_any += 1 red_pick_red += 1 blue_pick_any += 1 - elif _same_pos(red_pos[i], blue_coin_pos[i]) and \ - _same_pos(blue_pos[i], blue_coin_pos[i]): + elif _same_pos(red_pos[i], blue_coin_pos[i]) and _same_pos( + blue_pos[i], blue_coin_pos[i] + ): generate[i] = 2 reward_red[i] += 1 reward_blue[i] += 4 @@ -149,14 +214,26 @@ def compute_reward(batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos): reward = [reward_red, reward_blue] - return reward, generate, \ - red_pick_any, red_pick_red, 
blue_pick_any, blue_pick_blue + return ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) def generate_coin_wt_numba_optimization( - batch_size, generate, red_coin_pos, blue_coin_pos, red_pos, blue_pos, - grid_size): + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, +): for i in prange(batch_size): # generate:0 => no coin generation # generate:1 => red coin generation @@ -164,11 +241,13 @@ def generate_coin_wt_numba_optimization( # generate:0 => red & blue coin generation if generate[i] == 3 or generate[i] == 1: - red_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], - grid_size, blue_coin_pos[i]) + red_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, blue_coin_pos[i] + ) if generate[i] == 3 or generate[i] == 2: - blue_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], - grid_size, red_coin_pos[i]) + blue_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, red_coin_pos[i] + ) return red_coin_pos, blue_coin_pos @@ -178,19 +257,24 @@ def _place_coin(red_pos_i, blue_pos_i, grid_size, other_coin_pos_i): blue_pos_flat = _flatten_index(blue_pos_i, grid_size) other_coin_pos_flat = _flatten_index(other_coin_pos_i, grid_size) possible_coin_pos = np.array( - [x for x in range(9) - if ((x != blue_pos_flat) and - (x != red_pos_flat) and - (x != other_coin_pos_flat))] + [ + x + for x in range(9) + if ( + (x != blue_pos_flat) + and (x != red_pos_flat) + and (x != other_coin_pos_flat) + ) + ] ) flat_coin_pos = np.random.choice(possible_coin_pos) return _unflatten_index(flat_coin_pos, grid_size) @jit(nopython=True) -def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos, - red_coin_pos, blue_coin_pos, - grid_size): +def generate_observations_wt_numba_optimization( + batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size +): obs = np.zeros((batch_size, grid_size, grid_size, 4)) for i in prange(batch_size): obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 diff --git a/marltoolbox/envs/vectorized_ssd_mm_coin_game.py b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py new file mode 100644 index 0000000..ef6ee2b --- /dev/null +++ b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py @@ -0,0 +1,428 @@ +import copy +import logging +from collections import Iterable + +import gym +import numpy as np +from numba import jit, prange +from ray.rllib.utils import override + +from marltoolbox.envs import vectorized_coin_game +from marltoolbox.envs.vectorized_coin_game import ( + _same_pos, + move_players, +) +from marltoolbox.envs.vectorized_mixed_motive_coin_game import _place_coin + +logger = logging.getLogger(__name__) + +PLOT_KEYS = vectorized_coin_game.PLOT_KEYS + +PLOT_ASSEMBLAGE_TAGS = vectorized_coin_game.PLOT_ASSEMBLAGE_TAGS + + +class VectSSDMixedMotiveCG( + vectorized_coin_game.VectorizedCoinGame, +): + def __init__(self, config: dict = {}): + super().__init__(config) + self.punishment_helped = config.get("punishment_helped", False) + + self.OBSERVATION_SPACE = gym.spaces.Box( + low=0, + high=1, + shape=(self.grid_size, self.grid_size, 6), + dtype="uint8", + ) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _load_config(self, config): + super()._load_config(config) + assert self.both_players_can_pick_the_same_coin, ( + "both_players_can_pick_the_same_coin option must be True in " + "ssd mixed motive coin game." 
+ ) + assert self.same_obs_for_each_player, ( + "same_obs_for_each_player option must be True in " + "ssd mixed motive coin game." + ) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _randomize_color_and_player_positions(self): + # Reset coin color and the players and coin positions + self.red_pos = np.random.randint( + self.grid_size, size=(self.batch_size, 2) + ) + self.blue_pos = np.random.randint( + self.grid_size, size=(self.batch_size, 2) + ) + self.red_coin = np.random.randint( + 2, size=self.batch_size, dtype=np.int8 + ) + self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) + self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) + + self._players_do_not_overlap_at_start() + + @override(vectorized_coin_game.VectorizedCoinGame) + def _generate_coin(self): + self._generate_coin() + + @override(vectorized_coin_game.VectorizedCoinGame) + def _generate_coin(self): + generate = np.ones(self.batch_size, dtype=np.int_) + ( + self.red_coin_pos, + self.blue_coin_pos, + self.red_coin, + ) = generate_coin_wt_numba_optimization( + self.batch_size, + generate, + self.red_coin_pos, + self.blue_coin_pos, + self.red_pos, + self.blue_pos, + self.grid_size, + self.red_coin, + ) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _generate_observation(self): + obs = generate_observations_wt_numba_optimization( + self.batch_size, + self.red_pos, + self.blue_pos, + self.red_coin_pos, + self.blue_coin_pos, + self.grid_size, + self.red_coin, + ) + + obs = self._apply_optional_invariance_to_the_player_trained(obs) + obs, _ = self._optional_unvectorize(obs) + return obs + + @override(vectorized_coin_game.VectorizedCoinGame) + def step(self, actions: Iterable): + actions = self._from_RLLib_API_to_list(actions) + self.step_count_in_current_episode += 1 + + ( + self.red_pos, + self.blue_pos, + rewards, + self.red_coin_pos, + self.blue_coin_pos, + observation, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + self.red_coin, + ) = vectorized_step_wt_numba_optimization( + actions, + self.batch_size, + self.red_pos, + self.blue_pos, + self.red_coin_pos, + self.blue_coin_pos, + self.grid_size, + self.red_coin, + self.punishment_helped, + ) + + if self.output_additional_info: + self._accumulate_info( + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ) + + obs = self._apply_optional_invariance_to_the_player_trained( + observation + ) + obs, rewards = self._optional_unvectorize(obs, rewards) + + return self._to_RLLib_API(obs, rewards) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _save_env(self): + env_save_state = { + "red_pos": self.red_pos, + "blue_pos": self.blue_pos, + "red_coin_pos": self.red_coin_pos, + "blue_coin_pos": self.blue_coin_pos, + "red_coin": self.red_coin, + "grid_size": self.grid_size, + "batch_size": self.batch_size, + "step_count_in_current_episode": self.step_count_in_current_episode, + "max_steps": self.max_steps, + "red_pick": self.red_pick, + "red_pick_own": self.red_pick_own, + "blue_pick": self.blue_pick, + "blue_pick_own": self.blue_pick_own, + } + return copy.deepcopy(env_save_state) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _init_info(self): + super()._init_info() + self.picked_red_coop = [] + self.picked_blue_coop = [] + + @override(vectorized_coin_game.VectorizedCoinGame) + def _reset_info(self): + super()._reset_info() + self.picked_red_coop.clear() + self.picked_blue_coop.clear() + + 
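Aside: the reward rules of this SSD mixed-motive variant are easier to read outside the batched numba kernel (compute_reward, added further down in this file). Below is a minimal, unbatched restatement of those rules; the function name ssd_mm_reward_for_one_grid, the nested same() helper, and the scalar tuple interface are illustrative assumptions and not part of the toolbox, but the payoff values (2.0 / 3.0 for cooperative picks, 1.0 for selfish picks, the optional -0.75 "punishment_helped" penalty) mirror the added kernel.

    def ssd_mm_reward_for_one_grid(
        red_pos, blue_pos, red_coin_pos, blue_coin_pos,
        red_coin_is_coop, punishment_helped=False,
    ):
        """Unbatched sketch of the rules implemented by compute_reward below.

        When red_coin_is_coop is truthy (red_coin[i] == 1), the red coin is the
        cooperative coin and the blue coin is the selfish one; when falsy, the
        roles are swapped.
        """
        def same(a, b):
            return a[0] == b[0] and a[1] == b[1]

        reward_red, reward_blue = 0.0, 0.0
        if same(red_pos, red_coin_pos):
            if red_coin_is_coop and same(blue_pos, red_coin_pos):
                reward_red += 2.0  # cooperative pick: both players on the red coin
            elif not red_coin_is_coop:
                reward_red += 1.0  # selfish pick of the red coin by red alone
                if punishment_helped and same(blue_pos, red_coin_pos):
                    reward_red -= 0.75  # penalized for picking while "helped"
        elif same(blue_pos, blue_coin_pos):
            if not red_coin_is_coop and same(red_pos, blue_coin_pos):
                reward_blue += 3.0  # cooperative pick: both players on the blue coin
            elif red_coin_is_coop:
                reward_blue += 1.0  # selfish pick of the blue coin by blue alone
                if punishment_helped and same(red_pos, blue_coin_pos):
                    reward_blue -= 0.75
        return reward_red, reward_blue

    # Both players standing on the cooperative red coin:
    assert ssd_mm_reward_for_one_grid(
        (0, 0), (0, 0), (0, 0), (2, 2), red_coin_is_coop=True
    ) == (2.0, 0.0)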
@override(vectorized_coin_game.VectorizedCoinGame) + def _accumulate_info( + self, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ): + super()._accumulate_info( + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) + self.picked_red_coop.append(picked_red_coop) + self.picked_blue_coop.append(picked_blue_coop) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _get_episode_info(self): + """ + Output the following information: + pick_speed is the fraction of steps during which the player picked a + coin. + pick_own_color is the fraction of coins picked by the player which have + the same color as the player. + """ + n_steps_played = len(self.red_pick) * self.batch_size + player_red_info, player_blue_info = super()._get_episode_info( + n_steps_played=n_steps_played + ) + n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop) + + if len(self.red_pick) > 0: + red_pick = sum(self.red_pick) + player_red_info["red_coop_speed"] = ( + sum(self.picked_red_coop) / n_steps_played + ) + if red_pick > 0: + player_red_info["red_coop_fraction"] = n_coop / red_pick + + if len(self.blue_pick) > 0: + blue_pick = sum(self.blue_pick) + player_blue_info["blue_coop_speed"] = ( + sum(self.picked_blue_coop) / n_steps_played + ) + if blue_pick > 0: + player_blue_info["blue_coop_fraction"] = n_coop / blue_pick + + return player_red_info, player_blue_info + + +@jit(nopython=True) +def vectorized_step_wt_numba_optimization( + actions, + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size: int, + red_coin, + punishment_helped, +): + red_pos, blue_pos = move_players( + batch_size, actions, red_pos, blue_pos, grid_size + ) + + ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ) = compute_reward( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + red_coin, + punishment_helped, + ) + + ( + red_coin_pos, + blue_coin_pos, + red_coin, + ) = generate_coin_wt_numba_optimization( + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, + red_coin, + ) + + obs = generate_observations_wt_numba_optimization( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size, + red_coin, + ) + + return ( + red_pos, + blue_pos, + reward, + red_coin_pos, + blue_coin_pos, + obs, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + red_coin, + ) + + +@jit(nopython=True) +def compute_reward( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + red_coin, + punishment_helped, +): + reward_red = np.zeros(batch_size) + reward_blue = np.zeros(batch_size) + generate = np.zeros(batch_size, dtype=np.int_) + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 + picked_red_coop, picked_blue_coop = 0, 0 + + for i in prange(batch_size): + if _same_pos(red_pos[i], red_coin_pos[i]): + if red_coin[i] and _same_pos(blue_pos[i], red_coin_pos[i]): + # Red coin is a coop coin + generate[i] = True + reward_red[i] += 2.0 + red_pick_any += 1 + red_pick_red += 1 + blue_pick_any += 1 + picked_red_coop += 1 + elif not red_coin[i]: + # Red coin is a selfish coin + generate[i] = True + reward_red[i] += 1.0 + red_pick_any += 1 + red_pick_red += 1 + if punishment_helped and _same_pos( + blue_pos[i], red_coin_pos[i] + ): + reward_red[i] -= 0.75 + elif _same_pos(blue_pos[i], blue_coin_pos[i]): + if not 
red_coin[i] and _same_pos(red_pos[i], blue_coin_pos[i]): + # Blue coin is a coop coin + generate[i] = True + reward_blue[i] += 3.0 + red_pick_any += 1 + blue_pick_any += 1 + blue_pick_blue += 1 + picked_blue_coop += 1 + elif red_coin[i]: + # Blue coin is a selfish coin + generate[i] = True + reward_blue[i] += 1.0 + blue_pick_any += 1 + blue_pick_blue += 1 + if punishment_helped and _same_pos( + red_pos[i], blue_coin_pos[i] + ): + reward_blue[i] -= 0.75 + + reward = [reward_red, reward_blue] + + return ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ) + + +@jit(nopython=True) +def generate_observations_wt_numba_optimization( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size, + red_coin, +): + obs = np.zeros((batch_size, grid_size, grid_size, 6)) + for i in prange(batch_size): + obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 + obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1 + if red_coin[i]: + # Feature 4th is for the red cooperative coin + obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 4] = 1 + # Feature 3th is for the blue selfish coin + obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 3] = 1 + else: + # Feature 2th is for the red selfish coin + obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 2] = 1 + # Feature 5th is for the blue cooperative coin + obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 5] = 1 + return obs + + +@jit(nopython=True) +def generate_coin_wt_numba_optimization( + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, + red_coin, +): + for i in prange(batch_size): + if generate[i]: + red_coin[i] = 1 - red_coin[i] + red_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, blue_coin_pos[i] + ) + blue_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, red_coin_pos[i] + ) + return red_coin_pos, blue_coin_pos, red_coin diff --git a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb index bb7bda6..c2a13e5 100644 --- a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb +++ b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb @@ -75,8 +75,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:39.129878Z", - "start_time": "2021-04-14T09:58:39.126721Z" + "end_time": "2021-04-15T12:33:11.973283Z", + "start_time": "2021-04-15T12:33:11.970475Z" }, "id": "DsrzADRv4Itm" }, @@ -172,8 +172,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.814223Z", - "start_time": "2021-04-14T09:58:39.133920Z" + "end_time": "2021-04-15T12:33:14.647926Z", + "start_time": "2021-04-15T12:33:11.976742Z" }, "id": "BdtPEwKt-b47" }, @@ -191,6 +191,7 @@ "\n", "from marltoolbox.envs.matrix_sequential_social_dilemma import IteratedPrisonersDilemma\n", "from marltoolbox.utils.miscellaneous import check_learning_achieved\n", + "from marltoolbox.utils import log\n", "\n", "from IPython.core.display import display, HTML\n", "display(HTML(\"\"))" @@ -224,8 +225,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.821285Z", - "start_time": "2021-04-14T09:58:41.816146Z" + "end_time": "2021-04-15T12:33:14.653543Z", + "start_time": "2021-04-15T12:33:14.650069Z" }, "id": "MHNIE5wp-nV8" }, @@ -256,8 +257,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.907870Z", - 
"start_time": "2021-04-14T09:58:41.824356Z" + "end_time": "2021-04-15T12:33:14.742628Z", + "start_time": "2021-04-15T12:33:14.655685Z" }, "id": "nEC3Fbp9trEZ" }, @@ -284,8 +285,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.005535Z", - "start_time": "2021-04-14T09:58:41.910259Z" + "end_time": "2021-04-15T12:33:14.857873Z", + "start_time": "2021-04-15T12:33:14.745371Z" }, "id": "6ZyWFTfy-KSd" }, @@ -329,8 +330,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.124090Z", - "start_time": "2021-04-14T09:58:42.007028Z" + "end_time": "2021-04-15T12:33:14.926104Z", + "start_time": "2021-04-15T12:33:14.859461Z" }, "id": "DrBiGbnSBFns" }, @@ -389,8 +390,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.211967Z", - "start_time": "2021-04-14T09:58:42.127092Z" + "end_time": "2021-04-15T12:33:15.020084Z", + "start_time": "2021-04-15T12:33:14.928681Z" }, "id": "tB0ig1cuCkl7" }, @@ -480,8 +481,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.349958Z", - "start_time": "2021-04-14T09:58:42.216488Z" + "end_time": "2021-04-15T12:33:15.114115Z", + "start_time": "2021-04-15T12:33:15.022186Z" }, "id": "s4G0pHVJCpGU" }, @@ -586,8 +587,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:53.956824Z", - "start_time": "2021-04-14T09:58:42.352177Z" + "end_time": "2021-04-15T12:33:26.563552Z", + "start_time": "2021-04-15T12:33:15.119380Z" }, "id": "ST5qb68DCMI7", "scrolled": true @@ -602,6 +603,7 @@ "\n", "# stop the training after N Trainable.step (here this is equivalent to N episodes and N updates)\n", "stop_config = {\"training_iteration\": 200} \n", + "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n", "\n", "# Restart Ray defensively in case the ray connection is lost.\n", "ray.shutdown() \n", @@ -611,7 +613,7 @@ " Trainable,\n", " stop=stop_config,\n", " config=tune_config,\n", - " name=\"PG_IPD\",\n", + " name=exp_name,\n", " )\n", "\n", "ray.shutdown()\n", @@ -651,8 +653,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.311720Z", - "start_time": "2021-04-14T09:58:53.963206Z" + "end_time": "2021-04-15T12:33:39.209663Z", + "start_time": "2021-04-15T12:33:26.565556Z" }, "id": "vMbfhunq_ben" }, @@ -669,6 +671,7 @@ "# but here varying \"training_iteration\" is interesting.\n", "stop_config = {\"training_iteration\": tune.grid_search([2, 4, 8, 16, 32, 64, 128])} \n", "#####\n", + "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n", "\n", "ray.shutdown() \n", "ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", @@ -676,7 +679,7 @@ " Trainable,\n", " stop=stop_config,\n", " config=tune_config,\n", - " name=\"PG_IPD\",\n", + " name=exp_name,\n", " )\n", "ray.shutdown()\n", "\n", @@ -718,8 +721,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:00:01.360129Z", - "start_time": "2021-04-14T10:00:01.355038Z" + "end_time": "2021-04-15T12:33:39.226528Z", + "start_time": "2021-04-15T12:33:39.211410Z" }, "id": "9uHr3nuI-Lap" }, @@ -780,8 +783,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:02:10.235678Z", - "start_time": "2021-04-14T10:02:10.224537Z" + "end_time": "2021-04-15T12:33:39.355656Z", + "start_time": "2021-04-15T12:33:39.228620Z" }, "id": "mbxWsThD-vEt" }, @@ -855,8 +858,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - 
"end_time": "2021-04-14T09:59:06.631534Z", - "start_time": "2021-04-14T09:59:06.473496Z" + "end_time": "2021-04-15T12:33:39.486196Z", + "start_time": "2021-04-15T12:33:39.360619Z" }, "id": "4-VPfgGP9-XH" }, @@ -911,8 +914,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.714373Z", - "start_time": "2021-04-14T09:59:06.632964Z" + "end_time": "2021-04-15T12:33:39.578807Z", + "start_time": "2021-04-15T12:33:39.487619Z" }, "id": "gej3suhHAo0S" }, @@ -984,8 +987,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.824496Z", - "start_time": "2021-04-14T09:59:06.716292Z" + "end_time": "2021-04-15T12:33:39.713752Z", + "start_time": "2021-04-15T12:33:39.584035Z" }, "id": "aJIHVPLjbfiC" }, @@ -1047,8 +1050,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.916296Z", - "start_time": "2021-04-14T09:59:06.826239Z" + "end_time": "2021-04-15T12:33:39.901063Z", + "start_time": "2021-04-15T12:33:39.715122Z" }, "id": "qqlYU7ZiALxl" }, @@ -1091,6 +1094,8 @@ " \"framework\": \"torch\",\n", " \"batch_mode\": \"complete_episodes\",\n", " # LTFT supports only 1 worker only otherwise it would be mixing several opponents trajectories\n", + " # Number of rollout worker actors to create for parallel sampling. Setting\n", + " # this to 0 will force rollouts to be done in the trainer actor.\n", " \"num_workers\": 0,\n", " # LTFT supports only 1 env per worker only otherwise several episodes would be played at the same time\n", " \"num_envs_per_worker\": 1,\n", @@ -1116,10 +1121,11 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:04:28.482320Z", - "start_time": "2021-04-14T10:02:15.269072Z" + "end_time": "2021-04-15T12:35:51.624256Z", + "start_time": "2021-04-15T12:33:39.902523Z" }, - "id": "bfEaO43v3UBB" + "id": "bfEaO43v3UBB", + "scrolled": true }, "outputs": [], "source": [ @@ -1139,14 +1145,14 @@ " \"debug\": False,\n", "}\n", "\n", - "\n", + "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n", "rllib_config = get_rllib_config(ltft_hparameters)\n", "stop_config = get_stop_config(ltft_hparameters)\n", "ray.shutdown()\n", "ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", "tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, config=rllib_config,\n", " checkpoint_freq=0, stop=stop_config, \n", - " checkpoint_at_end=False, name=\"LTFT_exp\")\n", + " checkpoint_at_end=False, name=exp_name)\n", "ray.shutdown()\n", "\n", "check_learning_achieved(tune_results=tune_analysis_self_play, \n", @@ -1192,8 +1198,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:04:28.491013Z", - "start_time": "2021-04-14T10:04:28.484650Z" + "end_time": "2021-04-15T12:35:51.630130Z", + "start_time": "2021-04-15T12:35:51.626154Z" }, "id": "yCZzbx3Ld0vz" }, @@ -1224,8 +1230,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:05:15.895877Z", - "start_time": "2021-04-14T10:04:28.493528Z" + "end_time": "2021-04-15T12:36:53.910972Z", + "start_time": "2021-04-15T12:35:51.631535Z" }, "id": "zETjK-Y5ZKLA" }, @@ -1249,14 +1255,14 @@ " \"debug\": False,\n", "}\n", "\n", - "\n", + "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n", "rllib_config = get_rllib_config(ltft_hparameters)\n", "stop_config = get_stop_config(ltft_hparameters)\n", "ray.shutdown()\n", "ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", "tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, 
config=rllib_config,\n", " checkpoint_freq=0, stop=stop_config, \n", - " checkpoint_at_end=False, name=\"LTFT_exp\")\n", + " checkpoint_at_end=False, name=exp_name)\n", "ray.shutdown()\n", "\n", "check_learning_achieved(tune_results=tune_analysis_self_play,\n", @@ -1299,8 +1305,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:10.306079Z", - "start_time": "2021-04-14T09:58:39.178Z" + "end_time": "2021-04-15T12:36:53.923288Z", + "start_time": "2021-04-15T12:36:53.915782Z" }, "id": "u00VE8oB4Itx" }, @@ -1314,8 +1320,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:10.306786Z", - "start_time": "2021-04-14T09:58:39.180Z" + "end_time": "2021-04-15T12:36:54.052499Z", + "start_time": "2021-04-15T12:36:53.925743Z" }, "id": "NpeJySSk4Itx", "scrolled": true diff --git a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb index 5989126..8811697 100644 --- a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb +++ b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb @@ -32,8 +32,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:53.433046Z", - "start_time": "2021-04-14T10:06:53.428041Z" + "end_time": "2021-04-15T12:55:08.510232Z", + "start_time": "2021-04-15T12:55:08.501235Z" }, "id": "KMXC7aH5CD8p" }, @@ -62,8 +62,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:53.517255Z", - "start_time": "2021-04-14T10:06:53.435182Z" + "end_time": "2021-04-15T12:55:08.594695Z", + "start_time": "2021-04-15T12:55:08.512706Z" }, "id": "DsrzADRv4Itm" }, @@ -141,8 +141,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.606917Z", - "start_time": "2021-04-14T10:06:53.518906Z" + "end_time": "2021-04-15T12:55:11.544021Z", + "start_time": "2021-04-15T12:55:08.597116Z" }, "id": "BdtPEwKt-b47" }, @@ -188,8 +188,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.614247Z", - "start_time": "2021-04-14T10:06:56.609516Z" + "end_time": "2021-04-15T12:55:11.550532Z", + "start_time": "2021-04-15T12:55:11.546175Z" }, "id": "9fZtrKHovNvV" }, @@ -216,8 +216,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.722108Z", - "start_time": "2021-04-14T10:06:56.616211Z" + "end_time": "2021-04-15T13:08:41.276000Z", + "start_time": "2021-04-15T13:08:41.268429Z" }, "id": "MHNIE5wp-nV8" }, @@ -229,10 +229,10 @@ "\n", " # This modification to the policy will allow us to load each policies from different checkpoints \n", " # This will be used during evaluation.\n", - " def merged_after_init(*args, **kwargs):\n", + " def merged_before_loss_init(*args, **kwargs):\n", " setup_mixins(*args, **kwargs)\n", - " restore.after_init_load_policy_checkpoint(*args, **kwargs)\n", - " MyPPOPolicy = PPOTorchPolicy.with_updates(after_init=merged_after_init)\n", + " restore.before_loss_init_load_policy_checkpoint(*args, **kwargs)\n", + " MyPPOPolicy = PPOTorchPolicy.with_updates(before_loss_init=merged_before_loss_init)\n", "\n", " stop_config = {\n", " \"episodes_total\": hp[\"episodes_total\"],\n", @@ -311,18 +311,19 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:07:47.492122Z", - 
"start_time": "2021-04-14T10:06:56.726437Z" + "end_time": "2021-04-15T12:55:59.020889Z", + "start_time": "2021-04-15T12:55:11.652388Z" }, "id": "XijF3Q7O0FxF" }, "outputs": [], "source": [ + "exp_name, _ = log.log_in_current_day_dir(\"PPO_BoS_notebook_exp\")\n", "hyperparameters = {\n", " \"steps_per_epi\": 20,\n", " \"train_n_replicates\": 8,\n", " \"episodes_total\": 200,\n", - " \"exp_name\": \"PPO_BoS\",\n", + " \"exp_name\": exp_name,\n", " \"base_lr\": 5e-1,\n", "}\n", "\n", @@ -350,8 +351,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:07:47.498013Z", - "start_time": "2021-04-14T10:07:47.494191Z" + "end_time": "2021-04-15T12:55:59.045103Z", + "start_time": "2021-04-15T12:55:59.024866Z" }, "id": "e4qziRvL0o_E" }, @@ -375,8 +376,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:18.230859Z", - "start_time": "2021-04-14T10:07:47.499394Z" + "end_time": "2021-04-15T13:11:25.364166Z", + "start_time": "2021-04-15T13:10:45.951397Z" }, "id": "6ZyWFTfy-KSd" }, @@ -455,8 +456,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:10:12.067609Z", - "start_time": "2021-04-14T10:10:12.062918Z" + "end_time": "2021-04-15T12:56:29.454882Z", + "start_time": "2021-04-15T12:56:28.770969Z" }, "id": "wmvRIbPqUAwy" }, @@ -503,8 +504,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:19.676226Z", - "start_time": "2021-04-14T10:08:19.671968Z" + "end_time": "2021-04-15T12:56:29.461390Z", + "start_time": "2021-04-15T12:56:29.456601Z" }, "id": "bVqSBL72wHrd" }, @@ -531,8 +532,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:19.788760Z", - "start_time": "2021-04-14T10:08:19.678381Z" + "end_time": "2021-04-15T12:56:29.561534Z", + "start_time": "2021-04-15T12:56:29.463140Z" }, "id": "vMbfhunq_ben" }, @@ -540,11 +541,12 @@ "source": [ "def train_lvl0_agents(lvl0_hparameters):\n", "\n", + " exp_name, _ = log.log_in_current_day_dir(\"Lvl0_LOLAExact_notebook_exp\")\n", " tune_config = get_tune_config(lvl0_hparameters)\n", " stop_config = get_stop_config(lvl0_hparameters)\n", " ray.shutdown()\n", " ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", - " tune_analysis_lvl0 = tune.run(LOLAExact, name=\"Lvl0_LOLAExact\", config=tune_config,\n", + " tune_analysis_lvl0 = tune.run(LOLAExact, name=exp_name, config=tune_config,\n", " checkpoint_at_end=True, stop=stop_config, \n", " metric=lvl0_hparameters[\"metric\"], mode=\"max\")\n", " ray.shutdown()\n", @@ -583,8 +585,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:09:38.553623Z", - "start_time": "2021-04-14T10:08:19.792698Z" + "end_time": "2021-04-15T12:57:43.481511Z", + "start_time": "2021-04-15T12:56:29.563639Z" }, "id": "Im5nFVs7YUtU" }, @@ -635,8 +637,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:11:56.338029Z", - "start_time": "2021-04-14T10:11:56.328179Z" + "end_time": "2021-04-15T12:57:43.492585Z", + "start_time": "2021-04-15T12:57:43.483312Z" }, "id": "P2loJhuzVho6" }, @@ -644,6 +646,7 @@ "source": [ "def train_lvl1_agents(hp_lvl1, tune_analysis_lvl0):\n", "\n", + " exp_name, _ = log.log_in_current_day_dir(\"Lvl1_PG_notebook_exp\")\n", " rllib_config_lvl1, trainable_class, env_config = get_rllib_config(hp_lvl1)\n", " stop_config = get_stop_config(hp_lvl1)\n", " \n", @@ -659,7 +662,7 @@ " stop=stop_config,\n", " checkpoint_at_end=True,\n", " metric=\"episode_reward_mean\", 
mode=\"max\",\n", - " name=\"Lvl1_PG\")\n", + " name=exp_name)\n", " ray.shutdown()\n", " return tune_analysis_lvl1\n", "\n", @@ -746,8 +749,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:30.455824Z", - "start_time": "2021-04-14T10:11:58.126061Z" + "end_time": "2021-04-15T12:58:19.055673Z", + "start_time": "2021-04-15T12:57:43.493968Z" }, "id": "b2kmotZMYZBX", "scrolled": true @@ -785,8 +788,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:34.152484Z", - "start_time": "2021-04-14T10:12:33.899223Z" + "end_time": "2021-04-15T12:58:19.100252Z", + "start_time": "2021-04-15T12:58:19.056994Z" } }, "outputs": [], @@ -799,8 +802,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:36.039089Z", - "start_time": "2021-04-14T10:12:35.975270Z" + "end_time": "2021-04-15T12:58:19.256941Z", + "start_time": "2021-04-15T12:58:19.101888Z" }, "id": "M1lSyEf1D08q" }, @@ -828,8 +831,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:13:04.766311Z", - "start_time": "2021-04-14T10:13:04.716165Z" + "end_time": "2021-04-15T12:58:19.317574Z", + "start_time": "2021-04-15T12:58:19.258240Z" }, "id": "v67qWvcDD08q" }, @@ -871,8 +874,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:09:39.229687Z", - "start_time": "2021-04-14T10:06:53.463Z" + "end_time": "2021-04-15T12:58:19.401613Z", + "start_time": "2021-04-15T12:58:19.319051Z" }, "id": "u00VE8oB4Itx" }, @@ -886,8 +889,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:09:39.230757Z", - "start_time": "2021-04-14T10:06:53.465Z" + "end_time": "2021-04-15T12:58:19.501562Z", + "start_time": "2021-04-15T12:58:19.403275Z" }, "id": "NpeJySSk4Itx", "scrolled": true diff --git a/marltoolbox/examples/rllib_api/dqn_coin_game.py b/marltoolbox/examples/rllib_api/dqn_coin_game.py index ba32255..140139e 100644 --- a/marltoolbox/examples/rllib_api/dqn_coin_game.py +++ b/marltoolbox/examples/rllib_api/dqn_coin_game.py @@ -18,38 +18,36 @@ def main(debug): seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("DQN_CG") - hparams = _get_hyperparameters(seeds, debug, exp_name) + hparams = get_hyperparameters(seeds, debug, exp_name) - rllib_config, stop_config = _get_rllib_configs(hparams) + rllib_config, stop_config = get_rllib_configs(hparams) - tune_analysis = _train_dqn_and_plot_logs( - hparams, rllib_config, stop_config) + tune_analysis = train_dqn_and_plot_logs(hparams, rllib_config, stop_config) return tune_analysis -def _get_hyperparameters(seeds, debug, exp_name): +def get_hyperparameters(seeds, debug, exp_name): + """Get hyperparameters for the Coin Game env and DQN agents""" + hparams = { "seeds": seeds, "debug": debug, "exp_name": exp_name, "n_steps_per_epi": 100, - "n_epi": 2000, - "buf_frac": 0.125, - "last_exploration_temp_value": 0.1, - "bs_epi_mul": 1, - - "plot_keys": - coin_game.PLOT_KEYS + - aggregate_and_plot_tensorboard_data.PLOT_KEYS, - "plot_assemblage_tags": - coin_game.PLOT_ASSEMBLAGE_TAGS + - aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS, + "n_epi": 4000, + "buf_frac": 0.5, + "last_exploration_temp_value": 0.003, + "bs_epi_mul": 4, + "plot_keys": coin_game.PLOT_KEYS + + aggregate_and_plot_tensorboard_data.PLOT_KEYS, + "plot_assemblage_tags": coin_game.PLOT_ASSEMBLAGE_TAGS + + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS, } return hparams -def 
_get_rllib_configs(hp, env_class=None): +def get_rllib_configs(hp, env_class=None): stop_config = { "episodes_total": 2 if hp["debug"] else hp["n_epi"], } @@ -59,58 +57,70 @@ def _get_rllib_configs(hp, env_class=None): "max_steps": hp["n_steps_per_epi"], "grid_size": 3, "get_additional_info": True, + "buf_frac": hp["buf_frac"], + "bs_epi_mul": hp["bs_epi_mul"], } env_class = coin_game.CoinGame if env_class is None else env_class rllib_config = { "env": env_class, "env_config": env_config, - "multiagent": { "policies": { env_config["players_ids"][0]: ( augmented_dqn.MyDQNTorchPolicy, env_class(env_config).OBSERVATION_SPACE, env_class.ACTION_SPACE, - {}), + {}, + ), env_config["players_ids"][1]: ( augmented_dqn.MyDQNTorchPolicy, env_class(env_config).OBSERVATION_SPACE, env_class.ACTION_SPACE, - {}), + {}, + ), }, "policy_mapping_fn": lambda agent_id: agent_id, }, - # === DQN Models === - # Update the target network every `target_network_update_freq` steps. "target_network_update_freq": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * 30)), + lambda spec: int(spec.config["env_config"]["max_steps"] * 30) + ), # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then # each worker will have a replay buffer of this size. "buffer_size": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * hp["buf_frac"])), + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["buf_frac"] + ) + ), # Whether to use dueling dqn - "dueling": False, + "dueling": True, # Whether to use double dqn "double_q": True, # If True prioritized replay buffer will be used. "prioritized_replay": False, - "rollout_fragment_length": tune.sample_from( - lambda spec: spec.config["env_config"]["max_steps"]), - "training_intensity": 10, + lambda spec: spec.config["env_config"]["max_steps"] + ), + "training_intensity": tune.sample_from( + lambda spec: spec.config["num_envs_per_worker"] + * max(spec.config["num_workers"], 1) + * 40 + ), # Size of a batch sampled from replay buffer for training. Note that # if async_updates is set, then each worker returns gradients for a # batch of this size. "train_batch_size": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * - hp["bs_epi_mul"])), + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.config["env_config"]["bs_epi_mul"] + ) + ), "batch_mode": "complete_episodes", - # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). @@ -130,61 +140,85 @@ def _get_rllib_configs(hp, env_class=None): "temperature_schedule": tune.sample_from( lambda spec: PiecewiseSchedule( endpoints=[ - (0, - 2.0), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.20), - 0.5), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.60), - hp["last_exploration_temp_value"])], + (0, 1.0), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.20 + ), + 0.45, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.9 + ), + hp["last_exploration_temp_value"], + ), + ], outside_value=hp["last_exploration_temp_value"], - framework="torch")), + framework="torch", + ) + ), }, - # Size of batches collected from each worker. 
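Aside: the new exploration temperature above is built indirectly through tune.sample_from, which makes the resulting curve hard to read off the config. The sketch below resolves it by hand, assuming the non-debug values used in this file (max_steps=100, episodes_total=4000) and that the RLlib version pinned by the repo keeps the PiecewiseSchedule.value(t) accessor with linear interpolation by default; it is illustrative only, not part of the diff.

    from ray.rllib.utils.schedules import PiecewiseSchedule

    # Non-debug totals from this file: 100 steps per episode * 4000 episodes.
    total_steps = 100 * 4000

    temperature = PiecewiseSchedule(
        endpoints=[
            (0, 1.0),
            (int(total_steps * 0.20), 0.45),
            (int(total_steps * 0.9), 0.003),
        ],
        outside_value=0.003,  # held constant after the last breakpoint
        framework="torch",
    )

    for t in (0, 40_000, 80_000, 220_000, 360_000, 500_000):
        # Linearly interpolated between the surrounding breakpoints, then
        # clamped to outside_value once t passes the final breakpoint.
        print(t, temperature.value(t))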
"model": { "dim": env_config["grid_size"], # [Channel, [Kernel, Kernel], Stride]] - "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]] + "conv_filters": [[64, [3, 3], 1], [64, [3, 3], 1]], + "fcnet_hiddens": [64, 64], }, + "hiddens": [32], "gamma": 0.96, - "optimizer": {"sgd_momentum": 0.9, }, + "optimizer": { + "sgd_momentum": 0.9, + }, "lr": 0.1, "lr_schedule": tune.sample_from( lambda spec: [ (0, 0.0), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.05), - spec.config.lr), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"]), - spec.config.lr / 1e9) + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.05 + ), + spec.config.lr, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + ), + spec.config.lr / 1e9, + ), ] ), - "seed": tune.grid_search(hp["seeds"]), "callbacks": log.get_logging_callbacks_class(), "framework": "torch", - "logger_config": { "wandb": { "project": "DQN_CG", "group": hp["exp_name"], - "api_key_file": - os.path.join(os.path.dirname(__file__), - "../../../api_key_wandb"), - "log_config": True + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + "log_config": True, }, }, - + "num_envs_per_worker": 16, + "num_workers": 0, + # "log_level": "INFO", } return rllib_config, stop_config -def _train_dqn_and_plot_logs(hp, rllib_config, stop_config): - ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hp["debug"]) +def train_dqn_and_plot_logs(hp, rllib_config, stop_config): + ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"]) tune_analysis = tune.run( DQNTrainer, config=rllib_config, diff --git a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py index 10a7083..601d20c 100644 --- a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py +++ b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py @@ -10,29 +10,35 @@ def main(debug, welfare=postprocessing.WELFARE_UTILITARIAN): seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("DQN_welfare_CG") - hparams = dqn_coin_game._get_hyperparameters(seeds, debug, exp_name) - rllib_config, stop_config = dqn_coin_game._get_rllib_configs(hparams) - rllib_config = _modify_policy_to_use_welfare(rllib_config, welfare) + hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name) + rllib_config, stop_config = dqn_coin_game.get_rllib_configs(hparams) + rllib_config = modify_dqn_rllib_config_to_use_welfare( + rllib_config, welfare + ) - tune_analysis = dqn_coin_game._train_dqn_and_plot_logs( - hparams, rllib_config, stop_config) + tune_analysis = dqn_coin_game.train_dqn_and_plot_logs( + hparams, rllib_config, stop_config + ) return tune_analysis -def _modify_policy_to_use_welfare(rllib_config, welfare): - MyCoopDQNTorchPolicy = augmented_dqn.MyDQNTorchPolicy.with_updates( - postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( - postprocessing.welfares_postprocessing_fn(), - postprocess_nstep_and_prio, - ) +def modify_dqn_rllib_config_to_use_welfare(rllib_config, welfare): + DQNTorchPolicyWtWelfare = _get_policy_class_wt_welfare_preprocessing() + rllib_config = modify_rllib_config_to_use_welfare( + rllib_config, welfare, policy_class_wt_welfare=DQNTorchPolicyWtWelfare ) + return rllib_config + +def modify_rllib_config_to_use_welfare( + rllib_config, welfare, policy_class_wt_welfare, overwrite_reward=True +): policies = rllib_config["multiagent"]["policies"] 
new_policies = {} for policies_id, policy_tuple in policies.items(): new_policies[policies_id] = list(policy_tuple) - new_policies[policies_id][0] = MyCoopDQNTorchPolicy + new_policies[policies_id][0] = policy_class_wt_welfare if welfare == postprocessing.WELFARE_UTILITARIAN: new_policies[policies_id][3].update( {postprocessing.ADD_UTILITARIAN_WELFARE: True} @@ -40,7 +46,7 @@ def _modify_policy_to_use_welfare(rllib_config, welfare): elif welfare == postprocessing.WELFARE_INEQUITY_AVERSION: add_ia_w = True ia_alpha = 0.0 - ia_beta = 0.5 + ia_beta = 0.5 / 2 ia_gamma = 0.96 ia_lambda = 0.96 inequity_aversion_parameters = ( @@ -51,18 +57,30 @@ def _modify_policy_to_use_welfare(rllib_config, welfare): ia_lambda, ) new_policies[policies_id][3].update( - {postprocessing.ADD_INEQUITY_AVERSION_WELFARE: - inequity_aversion_parameters} + { + postprocessing.ADD_INEQUITY_AVERSION_WELFARE: inequity_aversion_parameters + } ) rllib_config["multiagent"]["policies"] = new_policies - rllib_config["callbacks"] = callbacks.merge_callbacks( - log.get_logging_callbacks_class(), - postprocessing.OverwriteRewardWtWelfareCallback, - ) + if overwrite_reward: + rllib_config["callbacks"] = callbacks.merge_callbacks( + log.get_logging_callbacks_class(), + postprocessing.OverwriteRewardWtWelfareCallback, + ) return rllib_config +def _get_policy_class_wt_welfare_preprocessing(): + DQNTorchPolicyWtWelfare = augmented_dqn.MyDQNTorchPolicy.with_updates( + postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( + postprocessing.welfares_postprocessing_fn(), + postprocess_nstep_and_prio, + ) + ) + return DQNTorchPolicyWtWelfare + + if __name__ == "__main__": debug_mode = True main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/inequity_aversion.py b/marltoolbox/examples/rllib_api/inequity_aversion.py index c730ea2..9e679dc 100644 --- a/marltoolbox/examples/rllib_api/inequity_aversion.py +++ b/marltoolbox/examples/rllib_api/inequity_aversion.py @@ -21,39 +21,35 @@ def main(debug): } policies = { - env_config["players_ids"][0]: - ( - None, - IteratedBoSAndPD.OBSERVATION_SPACE, - IteratedBoSAndPD.ACTION_SPACE, - {} - ), - env_config["players_ids"][1]: - ( - None, - IteratedBoSAndPD.OBSERVATION_SPACE, - IteratedBoSAndPD.ACTION_SPACE, - {} - )} + env_config["players_ids"][0]: ( + None, + IteratedBoSAndPD.OBSERVATION_SPACE, + IteratedBoSAndPD.ACTION_SPACE, + {}, + ), + env_config["players_ids"][1]: ( + None, + IteratedBoSAndPD.OBSERVATION_SPACE, + IteratedBoSAndPD.ACTION_SPACE, + {}, + ), + } rllib_config = { "env": IteratedBoSAndPD, "env_config": env_config, - "num_gpus": 0, "num_workers": 1, - "multiagent": { "policies": policies, "policy_mapping_fn": (lambda agent_id: agent_id), }, "framework": "torch", "gamma": 0.5, - "callbacks": callbacks.merge_callbacks( log.get_logging_callbacks_class(), - postprocessing.OverwriteRewardWtWelfareCallback), - + postprocessing.OverwriteRewardWtWelfareCallback, + ), } MyPGTorchPolicy = PGTorchPolicy.with_updates( @@ -63,18 +59,21 @@ def main(debug): inequity_aversion_beta=1.0, inequity_aversion_alpha=0.0, inequity_aversion_gamma=1.0, - inequity_aversion_lambda=0.5 + inequity_aversion_lambda=0.5, ), - pg_torch_policy.post_process_advantages + pg_torch_policy.post_process_advantages, ) ) - MyPGTrainer = PGTrainer.with_updates(default_policy=MyPGTorchPolicy, - get_policy_class=None) - tune_analysis = tune.run(MyPGTrainer, - stop=stop, - checkpoint_freq=10, - config=rllib_config, - name=exp_name) + MyPGTrainer = PGTrainer.with_updates( + default_policy=MyPGTorchPolicy, 
get_policy_class=None + ) + tune_analysis = tune.run( + MyPGTrainer, + stop=stop, + checkpoint_freq=10, + config=rllib_config, + name=exp_name, + ) ray.shutdown() return tune_analysis diff --git a/marltoolbox/examples/rllib_api/pg_ipd.py b/marltoolbox/examples/rllib_api/pg_ipd.py index 769843d..89bb6df 100644 --- a/marltoolbox/examples/rllib_api/pg_ipd.py +++ b/marltoolbox/examples/rllib_api/pg_ipd.py @@ -2,7 +2,8 @@ import ray from ray import tune -from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy +from ray.rllib.agents.pg.pg_torch_policy import pg_loss_stats from marltoolbox.envs.matrix_sequential_social_dilemma import ( IteratedPrisonersDilemma, @@ -10,14 +11,14 @@ from marltoolbox.utils import log, miscellaneous -def main(debug, stop_iters=300, tf=False): +def main(debug): train_n_replicates = 1 if debug else 1 seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("PG_IPD") ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) - rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf) + rllib_config, stop_config = get_rllib_config(seeds, debug) tune_analysis = tune.run( PGTrainer, config=rllib_config, @@ -30,14 +31,16 @@ def main(debug, stop_iters=300, tf=False): return tune_analysis -def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False): +def get_rllib_config(seeds, debug=False): stop_config = { - "training_iteration": 2 if debug else stop_iters, + "episodes_total": 2 if debug else 400, } + n_steps_in_epi = 20 + env_config = { "players_ids": ["player_row", "player_col"], - "max_steps": 20, + "max_steps": n_steps_in_epi, "get_additional_info": True, } @@ -63,8 +66,9 @@ def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False): }, "seed": tune.grid_search(seeds), "callbacks": log.get_logging_callbacks_class(log_full_epi=True), - "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), - "framework": "tf" if tf else "torch", + "framework": "torch", + "rollout_fragment_length": n_steps_in_epi, + "train_batch_size": n_steps_in_epi, } return rllib_config, stop_config diff --git a/marltoolbox/examples/rllib_api/ppo_coin_game.py b/marltoolbox/examples/rllib_api/ppo_coin_game.py index 78c10a1..3243b73 100644 --- a/marltoolbox/examples/rllib_api/ppo_coin_game.py +++ b/marltoolbox/examples/rllib_api/ppo_coin_game.py @@ -8,12 +8,8 @@ from marltoolbox.envs.coin_game import AsymCoinGame from marltoolbox.utils import log, miscellaneous -parser = argparse.ArgumentParser() -parser.add_argument("--tf", action="store_true") -parser.add_argument("--stop-iters", type=int, default=2000) - -def main(debug, stop_iters=2000, tf=False): +def main(debug): train_n_replicates = 1 if debug else 1 seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("PPO_AsymCG") @@ -21,7 +17,7 @@ def main(debug, stop_iters=2000, tf=False): ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) stop = { - "training_iteration": 2 if debug else stop_iters, + "training_iteration": 2 if debug else 2000, } env_class = AsymCoinGame @@ -66,7 +62,7 @@ def main(debug, stop_iters=2000, tf=False): "seed": tune.grid_search(seeds), "callbacks": log.get_logging_callbacks_class(), "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), - "framework": "tf" if tf else "torch", + "framework": "torch", "num_workers": 0, } @@ -83,6 +79,5 @@ def main(debug, stop_iters=2000, tf=False): if __name__ == "__main__": - args = parser.parse_args() debug_mode = 
True - main(debug_mode, args.stop_iters, args.tf) + main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/r2d2_coin_game.py b/marltoolbox/examples/rllib_api/r2d2_coin_game.py new file mode 100644 index 0000000..454f32d --- /dev/null +++ b/marltoolbox/examples/rllib_api/r2d2_coin_game.py @@ -0,0 +1,249 @@ +import os + +import ray +from ray import tune +from ray.rllib.agents import dqn +from ray.tune.integration.wandb import WandbLogger +from ray.tune.logger import DEFAULT_LOGGERS +from ray.rllib.agents.dqn.r2d2_torch_policy import postprocess_nstep_and_prio +from ray.rllib.utils.schedules import PiecewiseSchedule + +from marltoolbox.examples.rllib_api import dqn_coin_game, dqn_wt_welfare +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import log, miscellaneous, postprocessing +from marltoolbox.algos import augmented_r2d2 +from marltoolbox.envs import coin_game, ssd_mixed_motive_coin_game + + +def main(debug): + """Train R2D2 agent in the Coin Game environment""" + + # env = "CoinGame" + env = "SSDMixedMotiveCoinGame" + # welfare_to_use = None + welfare_to_use = postprocessing.WELFARE_UTILITARIAN + # welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION + + rllib_config, stop_config, hparams = _get_config_and_hp_for_training( + debug, env, welfare_to_use + ) + + tune_analysis = _train_dqn(hparams, rllib_config, stop_config) + + _plot_log_aggregates(hparams) + + return tune_analysis + + +def _get_config_and_hp_for_training(debug, env, welfare_to_use): + train_n_replicates = 1 if debug else 1 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("R2D2_CG") + if "SSDMixedMotiveCoinGame" in env: + env_class = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame + else: + env_class = coin_game.CoinGame + + hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name) + rllib_config, stop_config = dqn_coin_game.get_rllib_configs( + hparams, env_class=env_class + ) + rllib_config, stop_config = _adapt_configs_for_r2d2( + rllib_config, stop_config, hparams + ) + if welfare_to_use is not None: + rllib_config = modify_r2d2_rllib_config_to_use_welfare( + rllib_config, welfare_to_use + ) + return rllib_config, stop_config, hparams + + +def modify_r2d2_rllib_config_to_use_welfare(rllib_config, welfare_to_use): + r2d2_torch_policy_class_wt_welfare = ( + _get_r2d2_policy_class_wt_welfare_preprocessing() + ) + rllib_config = dqn_wt_welfare.modify_rllib_config_to_use_welfare( + rllib_config, + welfare_to_use, + policy_class_wt_welfare=r2d2_torch_policy_class_wt_welfare, + ) + return rllib_config + + +def _get_r2d2_policy_class_wt_welfare_preprocessing(): + r2d2_torch_policy_class_wt_welfare = ( + augmented_r2d2.MyR2D2TorchPolicy.with_updates( + postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( + postprocessing.welfares_postprocessing_fn(), + postprocess_nstep_and_prio, + ) + ) + ) + return r2d2_torch_policy_class_wt_welfare + + +def _adapt_configs_for_r2d2(rllib_config, stop_config, hp): + rllib_config["logger_config"]["wandb"]["project"] = "R2D2_CG" + rllib_config["model"]["use_lstm"] = True + rllib_config["burn_in"] = 0 + rllib_config["zero_init_states"] = False + rllib_config["use_h_function"] = False + + rllib_config = _replace_class_of_policies_by( + augmented_r2d2.MyR2D2TorchPolicy, + rllib_config, + ) + + if not hp["debug"]: + # rllib_config["env_config"]["training_intensity"] = 40 + # rllib_config["lr"] = 0.1 + # rllib_config["training_intensity"] = tune.sample_from( + # lambda 
spec: spec.config["num_envs_per_worker"] + # * spec.config["env_config"]["training_intensity"] + # * max(1, spec.config["num_workers"]) + # ) + stop_config["episodes_total"] = 8000 + rllib_config["model"]["lstm_cell_size"] = 16 + rllib_config["model"]["max_seq_len"] = 20 + # rllib_config["env_config"]["buf_frac"] = 0.5 + # rllib_config["hiddens"] = [32] + # rllib_config["model"]["fcnet_hiddens"] = [64, 64] + # rllib_config["model"]["conv_filters"] = tune.grid_search( + # [ + # [[32, [3, 3], 1], [32, [3, 3], 1]], + # [[64, [3, 3], 1], [64, [3, 3], 1]], + # ] + # ) + + rllib_config["env_config"]["temp_ratio"] = ( + 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0]) + ) + rllib_config["env_config"]["interm_temp_ratio"] = ( + 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0]) + ) + rllib_config["env_config"]["last_exploration_t"] = ( + 0.6 if hp["debug"] else tune.grid_search([0.9, 0.99]) + ) + rllib_config["env_config"]["last_exploration_temp_value"] = ( + 1.0 if hp["debug"] else 0.003 + ) + rllib_config["env_config"]["interm_exploration_t"] = ( + 0.2 if hp["debug"] else tune.grid_search([0.2, 0.6]) + ) + rllib_config["exploration_config"][ + "temperature_schedule" + ] = tune.sample_from( + lambda spec: PiecewiseSchedule( + endpoints=[ + (0, 1.0 * spec.config["env_config"]["temp_ratio"]), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["interm_exploration_t"] + ), + 0.6 + * spec.config["env_config"]["temp_ratio"] + * spec.config["env_config"]["interm_temp_ratio"], + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["last_exploration_t"] + ), + spec.config["env_config"][ + "last_exploration_temp_value" + ], + ), + ], + outside_value=spec.config["env_config"][ + "last_exploration_temp_value" + ], + framework="torch", + ) + ) + + # rllib_config["env_config"]["bs_epi_mul"] = ( + # 4 if hp["debug"] else tune.grid_search([4, 8, 16]) + # ) + rllib_config["env_config"]["interm_lr_ratio"] = ( + 0.5 + if hp["debug"] + else tune.grid_search([0.5 * 3, 0.5, 0.5 / 3, 0.5 / 9]) + ) + rllib_config["env_config"]["interm_lr_t"] = ( + 0.5 if hp["debug"] else tune.grid_search([0.25, 0.5, 0.75]) + ) + rllib_config["lr"] = 0.1 if hp["debug"] else 0.1 + rllib_config["lr_schedule"] = tune.sample_from( + lambda spec: [ + (0, 0.0), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.05 + ), + spec.config.lr, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["interm_lr_t"] + ), + spec.config.lr + * spec.config["env_config"]["interm_lr_ratio"], + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + ), + spec.config.lr / 1e9, + ), + ] + ) + else: + rllib_config["model"]["max_seq_len"] = 2 + rllib_config["model"]["lstm_cell_size"] = 8 + + return rllib_config, stop_config + + +def _replace_class_of_policies_by(new_policy_class, rllib_config): + policies = rllib_config["multiagent"]["policies"] + for policy_id in policies.keys(): + policy = list(policies[policy_id]) + policy[0] = new_policy_class + policies[policy_id] = tuple(policy) + return rllib_config + + +def _train_dqn(hp, rllib_config, stop_config): + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hp["debug"]) + tune_analysis = tune.run( + dqn.r2d2.R2D2Trainer, + config=rllib_config, + stop=stop_config, + name=hp["exp_name"], + log_to_file=not hp["debug"], + loggers=None if 
hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,), + ) + ray.shutdown() + return tune_analysis + + +def _plot_log_aggregates(hp): + if not hp["debug"]: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", hp["exp_name"]), + plot_keys=hp["plot_keys"], + plot_assemble_tags_in_one_plot=hp["plot_assemblage_tags"], + ) + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/r2d2_ipd.py b/marltoolbox/examples/rllib_api/r2d2_ipd.py new file mode 100644 index 0000000..5d6d0bd --- /dev/null +++ b/marltoolbox/examples/rllib_api/r2d2_ipd.py @@ -0,0 +1,68 @@ +import os + +import ray +from ray import tune +from ray.rllib.agents.dqn import R2D2Trainer + +from marltoolbox.algos import augmented_dqn +from marltoolbox.envs import matrix_sequential_social_dilemma +from marltoolbox.examples.rllib_api import pg_ipd +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import log, miscellaneous + + +def main(debug): + train_n_replicates = 1 if debug else 1 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("R2D2_IPD") + + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) + + rllib_config, stop_config = pg_ipd.get_rllib_config(seeds, debug) + rllib_config, stop_config = _adapt_configs_for_r2d2( + rllib_config, stop_config, debug + ) + + tune_analysis = tune.run( + R2D2Trainer, + config=rllib_config, + stop=stop_config, + checkpoint_at_end=True, + name=exp_name, + log_to_file=True, + ) + + if not debug: + _plot_log_aggregates(exp_name) + + ray.shutdown() + return tune_analysis + + +def _adapt_configs_for_r2d2(rllib_config, stop_config, debug): + rllib_config["model"] = {"use_lstm": True} + stop_config["episodes_total"] = 10 if debug else 600 + + return rllib_config, stop_config + + +def _plot_log_aggregates(exp_name): + plot_keys = ( + aggregate_and_plot_tensorboard_data.PLOT_KEYS + + matrix_sequential_social_dilemma.PLOT_KEYS + + augmented_dqn.PL + ) + plot_assemble_tags_in_one_plot = ( + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + + matrix_sequential_social_dilemma.PLOT_ASSEMBLAGE_TAGS + ) + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", exp_name), + plot_keys=plot_keys, + plot_assemble_tags_in_one_plot=plot_assemble_tags_in_one_plot, + ) + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/experiments/rllib_api/amtft_meta_game.py b/marltoolbox/experiments/rllib_api/amtft_meta_game.py index 225aba3..1019a08 100644 --- a/marltoolbox/experiments/rllib_api/amtft_meta_game.py +++ b/marltoolbox/experiments/rllib_api/amtft_meta_game.py @@ -1,8 +1,9 @@ +import collections import copy import json import logging import os -import random +from typing import List, Dict import numpy as np import pandas as pd @@ -11,91 +12,120 @@ from ray.rllib.agents import dqn from ray.rllib.utils import merge_dicts +from marltoolbox import utils from marltoolbox.algos import welfare_coordination from marltoolbox.experiments.rllib_api import amtft_various_env from marltoolbox.utils import ( - self_and_cross_perf, postprocessing, - miscellaneous, plot, + path, + cross_play, + tune_analysis, + restore, ) logger = logging.getLogger(__name__) def main(debug): - if debug: - test_meta_solver(debug) - - hp = _get_hyperparameters(debug) + # _extract_stats_on_welfare_announced( + # 
exp_dir="/home/maxime/dev-maxime/CLR/vm-data/instance-60-cpu-3-preemtible/amTFT/2021_04_23/12_58_42", + # players_ids=["player_row", "player_col"], + # ) + # + # if debug: + # test_meta_solver(debug) + use_r2d2 = True + hp = get_hyperparameters(debug, use_r2d2=use_r2d2) results = [] + ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"]) for tau in hp["tau_range"]: + hp["tau"] = tau ( all_rllib_config, hp_eval, env_config, stop_config, - ) = _produce_rllib_config_for_each_replicates(tau, hp) + ) = _produce_rllib_config_for_each_replicates(hp) mixed_rllib_configs = _mix_rllib_config(all_rllib_config, hp_eval) - tune_analysis = ray.tune.run( - dqn.DQNTrainer, + hp["trainer"], config=mixed_rllib_configs, verbose=1, stop=stop_config, - checkpoint_at_end=True, name=hp_eval["exp_name"], - log_to_file=not hp_eval["debug"], - # loggers=None - # if hp_eval["debug"] - # else DEFAULT_LOGGERS + (WandbLogger,), - ) - mean_player_1_payoffs, mean_player_2_payoffs = _extract_metrics( - tune_analysis, hp_eval + # log_to_file=not hp_eval["debug"], ) + ( + mean_player_1_payoffs, + mean_player_2_payoffs, + player_1_payoffs, + player_2_payoffs, + ) = extract_metrics(tune_analysis, hp_eval) results.append( ( tau, (mean_player_1_payoffs, mean_player_2_payoffs), + (player_1_payoffs, player_2_payoffs), ) ) - _save_to_json(exp_name=hp["exp_name"], object=results) - _plot_results(exp_name=hp["exp_name"], results=results, hp_eval=hp_eval) + save_to_json(exp_name=hp["exp_name"], object=results) + plot_results( + exp_name=hp["exp_name"], + results=results, + hp_eval=hp_eval, + format_fn=format_result_for_plotting, + ) + extract_stats_on_welfare_announced( + exp_name=hp["exp_name"], players_ids=env_config["players_ids"] + ) + +def get_hyperparameters(debug, use_r2d2=False): + """Get hyperparameters for meta game with amTFT policies in base game""" -def _get_hyperparameters(debug): # env = "IteratedPrisonersDilemma" env = "IteratedAsymBoS" # env = "CoinGame" hp = amtft_various_env.get_hyperparameters( - debug, train_n_replicates=1, filter_utilitarian=False, env=env + debug, + train_n_replicates=1, + filter_utilitarian=False, + env=env, + use_r2d2=use_r2d2, ) hp.update( { - "n_replicates_over_full_exp": 2, - "final_base_game_eval_over_n_epi": 2, + "n_replicates_over_full_exp": 2 if debug else 5, + "final_base_game_eval_over_n_epi": 1 if debug else 20, "tau_range": np.arange(0.0, 1.1, 0.5) if hp["debug"] else np.arange(0.0, 1.1, 0.1), "n_self_play_in_final_meta_game": 0, - "n_cross_play_in_final_meta_game": 1, + "n_cross_play_in_final_meta_game": 1 if debug else 4, + "max_n_replicates_in_base_game": 10, } ) + + if use_r2d2: + hp["trainer"] = dqn.r2d2.R2D2Trainer + else: + hp["trainer"] = dqn.DQNTrainer + return hp -def _produce_rllib_config_for_each_replicates(tau, hp): +def _produce_rllib_config_for_each_replicates(hp): all_rllib_config = [] for replicate_i in range(hp["n_replicates_over_full_exp"]): hp_eval = _load_base_game_results( copy.deepcopy(hp), load_base_replicate_i=replicate_i ) - hp_eval["tau"] = tau ( rllib_config, @@ -110,29 +140,55 @@ def _produce_rllib_config_for_each_replicates(tau, hp): rllib_config, env_config, hp_eval ) all_rllib_config.append(rllib_config) + return all_rllib_config, hp_eval, env_config, stop_config def _load_base_game_results(hp, load_base_replicate_i): - prefix = "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-2/" - # prefix = "/ray_results/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/" + 
prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-2-preemtible/" + # prefix = "~/ray_results/" + prefix = os.path.expanduser(prefix) if "CoinGame" in hp["env_name"]: hp["data_dir"] = (prefix + "amTFT/2021_04_10/19_37_20",)[ load_base_replicate_i ] elif "IteratedAsymBoS" in hp["env_name"]: hp["data_dir"] = ( - prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates - prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates - prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates - prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates - prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates + # Before fix + without R2D2 + # prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates + # prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates + # prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates + # prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates + # prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates + # After env fix + with R2D2 + # prefix + "amTFT/2021_05_03/13_44_48", # 10 replicates + # prefix + "amTFT/2021_05_03/13_45_09", # 10 replicates + # prefix + "amTFT/2021_05_03/13_47_00", # 10 replicates + # prefix + "amTFT/2021_05_03/13_48_03", # 10 replicates + # prefix + "amTFT/2021_05_03/13_49_28", # 10 replicates + # prefix + "amTFT/2021_05_03/13_49_47", # 10 replicates + # After fix env + short debit rollout + punish instead of + # selfish + 20 rllout length insead of 6 + prefix + "amTFT/2021_05_03/13_52_07", # 10 replicates + prefix + "amTFT/2021_05_03/13_52_27", # 10 replicates + prefix + "amTFT/2021_05_03/13_53_37", # 10 replicates + ### prefix + "amTFT/2021_05_03/13_54_08", # 10 replicates + prefix + "amTFT/2021_05_03/13_54_55", # 10 replicates + prefix + "amTFT/2021_05_03/13_56_35", # 10 replicates )[load_base_replicate_i] elif "IteratedPrisonersDilemma" in hp["env_name"]: hp["data_dir"] = ( "/home/maxime/dev-maxime/CLR/vm-data/" "instance-10-cpu-4/amTFT/2021_04_13/12_12_56", )[load_base_replicate_i] + else: + raise ValueError() + + assert os.path.exists( + hp["data_dir"] + ), "Path doesn't exist. 
Probably that the prefix need to be changed to fit the current machine used" hp["json_file"] = _get_results_json_file_path_in_dir(hp["data_dir"]) hp["ckpt_per_welfare"] = _get_checkpoints_for_each_welfare_in_dir( @@ -150,17 +206,17 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi): stop_config, env_config, rllib_config = amtft_various_env.get_rllib_config( hp_eval, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True ) - rllib_config = amtft_various_env.modify_config_for_evaluation( - rllib_config, hp_eval, env_config + rllib_config, stop_config = amtft_various_env.modify_config_for_evaluation( + rllib_config, hp_eval, env_config, stop_config ) hp_eval["n_self_play_per_checkpoint"] = None hp_eval["n_cross_play_per_checkpoint"] = None hp_eval[ "x_axis_metric" - ] = f"policy_reward_mean.{env_config['players_ids'][0]}" + ] = f"policy_reward_mean/{env_config['players_ids'][0]}" hp_eval[ "y_axis_metric" - ] = f"policy_reward_mean.{env_config['players_ids'][1]}" + ] = f"policy_reward_mean/{env_config['players_ids'][1]}" return rllib_config, hp_eval, env_config, stop_config @@ -168,8 +224,10 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi): def _modify_config_to_use_welfare_coordinators( rllib_config, env_config, hp_eval ): - all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs( - hp_eval, env_config["players_ids"] + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_cross_play_payoffs( + hp_eval, env_config["players_ids"] + ) ) rllib_config["multiagent"]["policies_to_train"] = ["None"] @@ -192,17 +250,11 @@ def _modify_config_to_use_welfare_coordinators( "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs, "own_player_idx": policy_idx, "opp_player_idx": opp_policy_idx, - # "own_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION - # if policy_idx - # else postprocessing.WELFARE_UTILITARIAN, - # "opp_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION - # if opp_policy_idx - # else postprocessing.WELFARE_UTILITARIAN, "own_default_welfare_fn": "inequity aversion" - if policy_idx + if policy_idx == 1 else "utilitarian", "opp_default_welfare_fn": "inequity aversion" - if opp_policy_idx + if opp_policy_idx == 1 else "utilitarian", "policy_id_to_load": policy_id, "policy_checkpoints": hp_eval["ckpt_per_welfare"], @@ -227,7 +279,7 @@ def _mix_rllib_config(all_rllib_configs, hp_eval): hp_eval["n_self_play_in_final_meta_game"] == 0 and hp_eval["n_cross_play_in_final_meta_game"] != 0 ): - master_config = _mix_configs_policies( + master_config = cross_play.utils.mix_policies_in_given_rllib_configs( all_rllib_configs, n_mix_per_config=hp_eval["n_cross_play_in_final_meta_game"], ) @@ -236,100 +288,15 @@ def _mix_rllib_config(all_rllib_configs, hp_eval): raise ValueError() -def _mix_configs_policies(all_rllib_configs, n_mix_per_config): - assert n_mix_per_config <= len(all_rllib_configs) - 1 - policy_ids = all_rllib_configs[0]["multiagent"]["policies"].keys() - assert len(policy_ids) == 2 - _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids) - - policy_config_variants = _gather_policy_variant_per_policy_id( - all_rllib_configs, policy_ids - ) - - master_config = _create_one_master_config( - all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config - ) - return master_config - - -def _create_one_master_config( - all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config -): - all_policy_mix = [] - player_1, player_2 = policy_ids - for config_idx, p1_policy_config in 
enumerate( - policy_config_variants[player_1] - ): - policies_mixes = _produce_n_mix_with_player_2_policies( - policy_config_variants, - player_2, - config_idx, - n_mix_per_config, - player_1, - p1_policy_config, - ) - all_policy_mix.extend(policies_mixes) - - master_config = copy.deepcopy(all_rllib_configs[0]) - master_config["multiagent"]["policies"] = tune.grid_search(all_policy_mix) - return master_config - - -def _produce_n_mix_with_player_2_policies( - policy_config_variants, - player_2, - config_idx, - n_mix_per_config, - player_1, - p1_policy_config, -): - p2_policy_configs_sampled = _get_p2_policies_samples_excluding_self( - policy_config_variants, player_2, config_idx, n_mix_per_config - ) - policies_mixes = [] - for p2_policy_config in p2_policy_configs_sampled: - policy_mix = { - player_1: p1_policy_config, - player_2: p2_policy_config, - } - policies_mixes.append(policy_mix) - return policies_mixes - - -def _get_p2_policies_samples_excluding_self( - policy_config_variants, player_2, config_idx, n_mix_per_config -): - p2_policy_config_variants = copy.deepcopy(policy_config_variants[player_2]) - p2_policy_config_variants.pop(config_idx) - p2_policy_configs_sampled = random.sample( - p2_policy_config_variants, n_mix_per_config - ) - return p2_policy_configs_sampled - - -def _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids): - for rllib_config in all_rllib_configs: - assert rllib_config["multiagent"]["policies"].keys() == policy_ids - - -def _gather_policy_variant_per_policy_id(all_rllib_configs, policy_ids): - policy_config_variants = {} - for policy_id in policy_ids: - policy_config_variants[policy_id] = [] - for rllib_config in all_rllib_configs: - policy_config_variants[policy_id].append( - copy.deepcopy( - rllib_config["multiagent"]["policies"][policy_id] - ) - ) - return policy_config_variants - - -def _extract_metrics(tune_analysis, hp_eval): - player_1_payoffs = miscellaneous.extract_metric_values_per_trials( +def extract_metrics(tune_analysis, hp_eval): + # TODO PROBLEM using last result but there are several episodes with + # different welfare functions !! => BUT plot is good + # Metric from the last result means the metric average over the last + # training iteration and we have only one training iteration here. 
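+    # One payoff value per trial is read from `tune_analysis`, then the
+    # values are averaged across trials to give each player's mean payoff
+    # for the current value of tau.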
+ player_1_payoffs = utils.tune_analysis.extract_metrics_for_each_trials( tune_analysis, metric=hp_eval["x_axis_metric"] ) - player_2_payoffs = miscellaneous.extract_metric_values_per_trials( + player_2_payoffs = utils.tune_analysis.extract_metrics_for_each_trials( tune_analysis, metric=hp_eval["y_axis_metric"] ) mean_player_1_payoffs = sum(player_1_payoffs) / len(player_1_payoffs) @@ -339,33 +306,50 @@ def _extract_metrics(tune_analysis, hp_eval): mean_player_1_payoffs, mean_player_2_payoffs, ) - return mean_player_1_payoffs, mean_player_2_payoffs + return ( + mean_player_1_payoffs, + mean_player_2_payoffs, + player_1_payoffs, + player_2_payoffs, + ) -def _save_to_json(exp_name, object): - exp_dir = os.path.join("~/ray_results", exp_name) - exp_dir = os.path.expanduser(exp_dir) - json_file = os.path.join(exp_dir, "final_eval_in_base_game.json") +class NumpyEncoder(json.JSONEncoder): + """ Special json encoder for numpy types """ + + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + +def save_to_json(exp_name, object, filename="final_eval_in_base_game.json"): + exp_dir = _get_exp_dir_from_exp_name(exp_name) + json_file = os.path.join(exp_dir, filename) + if not os.path.exists(exp_dir): + os.makedirs(exp_dir) with open(json_file, "w") as outfile: json.dump(object, outfile) -def _plot_results(exp_name, results, hp_eval): - exp_dir = os.path.join("~/ray_results", exp_name) - exp_dir = os.path.expanduser(exp_dir) +def plot_results(exp_name, results, hp_eval, format_fn, jitter=0.0): + exp_dir = _get_exp_dir_from_exp_name(exp_name) + data_groups_per_mode = format_fn(results) - data_groups_per_mode = _format_result_for_plotting(results) + background_area_coord = None + if "CoinGame" not in hp_eval["env_name"] and "env_class" in hp_eval.keys(): + background_area_coord = hp_eval["env_class"].PAYOFF_MATRIX - if "CoinGame" in hp_eval["env_name"]: - background_area_coord = None - else: - background_area_coord = hp_eval["env_class"].PAYOUT_MATRIX plot_config = plot.PlotConfig( save_dir_path=exp_dir, xlim=hp_eval["x_limits"], ylim=hp_eval["y_limits"], markersize=5, - jitter=hp_eval["jitter"], + jitter=jitter, xlabel="player 1 payoffs", ylabel="player 2 payoffs", x_scale_multiplier=hp_eval["plot_axis_scale_multipliers"][0], @@ -377,12 +361,9 @@ def _plot_results(exp_name, results, hp_eval): plot_helper.plot_dots(data_groups_per_mode) -def _format_result_for_plotting(results): +def format_result_for_plotting(results): data_groups_per_mode = {} - for ( - tau, - (mean_player_1_payoffs, mean_player_2_payoffs), - ) in results: + for (tau, (mean_player_1_payoffs, mean_player_2_payoffs), _) in results: df_row_dict = {} df_row_dict[f"mean"] = ( mean_player_1_payoffs, @@ -394,7 +375,7 @@ def _format_result_for_plotting(results): def test_meta_solver(debug): - hp = _get_hyperparameters(debug) + hp = get_hyperparameters(debug) hp = _load_base_game_results(hp, load_base_replicate_i=0) stop, env_config, rllib_config = amtft_various_env.get_rllib_config( hp, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True @@ -422,8 +403,10 @@ def test_meta_solver(debug): ) ) - all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs( - hp, env_config["players_ids"] + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_cross_play_payoffs( + hp, env_config["players_ids"] + ) ) print("all_welfare_pairs_wt_payoffs", 
all_welfare_pairs_wt_payoffs) player_meta_policy.setup_meta_game( @@ -469,120 +452,96 @@ def _get_results_json_file_path_in_dir(data_dir): def _eval_dir(data_dir): eval_dir_parents_2 = os.path.join(data_dir, "eval") - eval_dir_parents_1 = _get_unique_child_dir(eval_dir_parents_2) - eval_dir = _get_unique_child_dir(eval_dir_parents_1) + eval_dir_parents_1 = path.get_unique_child_dir(eval_dir_parents_2) + eval_dir = path.get_unique_child_dir(eval_dir_parents_1) return eval_dir def _get_json_results_path(eval_dir): - all_files_filtered = _get_children_paths_wt_selecting_filter( - eval_dir, _filter=self_and_cross_perf.RESULTS_SUMMARY_FILENAME_PREFIX + all_files_filtered = path.get_children_paths_wt_selecting_filter( + eval_dir, + _filter=cross_play.evaluator.RESULTS_SUMMARY_FILENAME_PREFIX, ) all_files_filtered = [ file_path for file_path in all_files_filtered if ".json" in file_path ] + all_files_filtered = [ + file_path + for file_path in all_files_filtered + if os.path.split(file_path)[-1].startswith("plotself") + or not os.path.split(file_path)[-1].startswith("plot") + ] assert len(all_files_filtered) == 1, f"{all_files_filtered}" json_result_path = os.path.join(eval_dir, all_files_filtered[0]) return json_result_path -def _get_unique_child_dir(_dir): - list_dir = os.listdir(_dir) - assert len(list_dir) == 1, f"{list_dir}" - unique_child_dir = os.path.join(_dir, list_dir[0]) - return unique_child_dir - - def _get_checkpoints_for_each_welfare_in_dir(data_dir, hp): + """Get the checkpoints of the base game policies in self-play""" ckpt_per_welfare = {} for welfare_fn, welfare_name in hp["welfare_functions"]: welfare_training_save_dir = os.path.join(data_dir, welfare_fn, "coop") - all_replicates_save_dir = _get_dir_of_each_replicate( - welfare_training_save_dir + all_replicates_save_dir = _get_replicates_dir_for_all_possible_config( + hp, welfare_training_save_dir ) - ckpts = _get_checkpoint_for_each_replicates(all_replicates_save_dir) + all_replicates_save_dir = _filter_checkpoints_if_utilitarians( + all_replicates_save_dir, hp + ) + ckpts = restore.get_checkpoint_for_each_replicates( + all_replicates_save_dir + ) + print("len(ckpts)", len(ckpts)) ckpt_per_welfare[welfare_name.replace("_", " ")] = ckpts return ckpt_per_welfare -def _get_dir_of_each_replicate(welfare_training_save_dir): - return _get_children_paths_wt_selecting_filter( - welfare_training_save_dir, _filter="DQN_" - ) - - -def _get_checkpoint_for_each_replicates(all_replicates_save_dir): - ckpt_dir_per_replicate = [] - for replicate_dir_path in all_replicates_save_dir: - ckpt_dir_path = _get_ckpt_dir_for_one_replicate(replicate_dir_path) - ckpt_path = _get_ckpt_fro_ckpt_dir(ckpt_dir_path) - ckpt_dir_per_replicate.append(ckpt_path) - return ckpt_dir_per_replicate - - -def _get_ckpt_dir_for_one_replicate(replicate_dir_path): - partialy_filtered_ckpt_dir = _get_children_paths_wt_selecting_filter( - replicate_dir_path, _filter="checkpoint_" - ) - ckpt_dir = [ - file_path - for file_path in partialy_filtered_ckpt_dir - if ".is_checkpoint" not in file_path - ] - assert len(ckpt_dir) == 1, f"{ckpt_dir}" - return ckpt_dir[0] +def _get_replicates_dir_for_all_possible_config(hp, welfare_training_save_dir): + if "use_r2d2" in hp and hp["use_r2d2"]: + all_replicates_save_dir = get_dir_of_each_replicate( + welfare_training_save_dir, str_in_dir="R2D2_" + ) + else: + all_replicates_save_dir = get_dir_of_each_replicate( + welfare_training_save_dir, str_in_dir="DQN_" + ) + return all_replicates_save_dir -def 
_get_ckpt_fro_ckpt_dir(ckpt_dir_path): - partialy_filtered_ckpt_path = _get_children_paths_wt_discarding_filter( - ckpt_dir_path, _filter="tune_metadata" - ) - ckpt_path = [ - file_path - for file_path in partialy_filtered_ckpt_path - if ".is_checkpoint" not in file_path +def _filter_checkpoints_if_utilitarians(all_replicates_save_dir, hp): + util_or_inequity_aversion = [ + path.split("/")[-3] for path in all_replicates_save_dir ] - assert len(ckpt_path) == 1, f"{ckpt_path}" - return ckpt_path[0] - - -def _get_children_paths_wt_selecting_filter(parent_dir_path, _filter): - return _get_children_paths_filters( - parent_dir_path, selecting_filter=_filter - ) + print("util_or_inequity_aversion", util_or_inequity_aversion) + if all( + welfare == "utilitarian_welfare" + for welfare in util_or_inequity_aversion + ): + all_replicates_save_dir = ( + utils.path.filter_list_of_replicates_by_results( + all_replicates_save_dir, + filter_key="episode_reward_mean", + filter_mode=tune_analysis.ABOVE, + filter_threshold=95, + ) + ) + if len(all_replicates_save_dir) > hp["max_n_replicates_in_base_game"]: + all_replicates_save_dir = all_replicates_save_dir[ + : hp["max_n_replicates_in_base_game"] + ] + else: + print( + "not filtering ckpts. all_replicates_save_dir[0]", + all_replicates_save_dir[0], + ) + return all_replicates_save_dir -def _get_children_paths_wt_discarding_filter(parent_dir_path, _filter): - return _get_children_paths_filters( - parent_dir_path, discarding_filter=_filter +def get_dir_of_each_replicate(welfare_training_save_dir, str_in_dir="DQN_"): + return path.get_children_paths_wt_selecting_filter( + welfare_training_save_dir, _filter=str_in_dir ) -def _get_children_paths_filters( - parent_dir_path: str, - selecting_filter: str = None, - discarding_filter: str = None, -): - filtered_children = os.listdir(parent_dir_path) - if selecting_filter is not None: - filtered_children = [ - filename - for filename in filtered_children - if selecting_filter in filename - ] - if discarding_filter is not None: - filtered_children = [ - filename - for filename in filtered_children - if discarding_filter not in filename - ] - filtered_children_path = [ - os.path.join(parent_dir_path, filename) - for filename in filtered_children - ] - return filtered_children_path - - def _convert_policy_config_to_meta_policy_config( hp, policy_id, policy_config, policy_class ): @@ -601,18 +560,18 @@ def _convert_policy_config_to_meta_policy_config( return meta_policy_config -def _get_all_welfare_pairs_wt_payoffs(hp, player_ids): +def _get_all_welfare_pairs_wt_cross_play_payoffs(hp, player_ids): with open(hp["json_file"]) as json_file: json_data = json.load(json_file) - cross_play_data = _keep_only_cross_play_values(json_data) + cross_play_data = keep_only_cross_play_values(json_data) cross_play_means = _keep_only_mean_values(cross_play_data) all_welfare_pairs_wt_payoffs = _order_players(cross_play_means, player_ids) return all_welfare_pairs_wt_payoffs -def _keep_only_cross_play_values(json_data): +def keep_only_cross_play_values(json_data): return { _format_eval_mode(eval_mode): v for eval_mode, v in json_data.items() @@ -623,7 +582,7 @@ def _keep_only_cross_play_values(json_data): def _format_eval_mode(eval_mode): k_wtout_kind_of_play = eval_mode.split(":")[-1].strip() both_welfare_fn = k_wtout_kind_of_play.split(" vs ") - return welfare_coordination.WelfareCoordinationTorchPolicy._from_pair_of_welfare_names_to_key( + return welfare_coordination.WelfareCoordinationTorchPolicy.from_pair_of_welfare_names_to_key( 
*both_welfare_fn ) @@ -648,6 +607,118 @@ def _order_players(cross_play_means, player_ids): } +def extract_stats_on_welfare_announced( + players_ids, exp_name=None, exp_dir=None, nested_info=False +): + if exp_dir is None: + exp_dir = _get_exp_dir_from_exp_name(exp_name) + all_in_exp_dir = utils.path.get_children_paths_wt_discarding_filter( + exp_dir, _filter=None + ) + all_dirs = utils.path.keep_dirs_only(all_in_exp_dir) + dir_name_by_tau = _group_dir_name_by_tau(all_dirs, nested_info) + dir_name_by_tau = _order_by_tau(dir_name_by_tau) + _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir) + + +def _get_exp_dir_from_exp_name(exp_name: str): + exp_dir = os.path.join("~/ray_results", exp_name) + exp_dir = os.path.expanduser(exp_dir) + return exp_dir + + +def _group_dir_name_by_tau(all_dirs: List[str], nested_info) -> Dict: + dirs_by_tau = {} + for trial_dir in all_dirs: + tau, welfare_set_annonced = _get_tau_value(trial_dir, nested_info) + if tau is None: + continue + if tau not in dirs_by_tau.keys(): + dirs_by_tau[tau] = [] + dirs_by_tau[tau].append((trial_dir, welfare_set_annonced)) + return dirs_by_tau + + +def _get_tau_value(trial_dir_path, nested_info=False): + full_epi_file_path = os.path.join( + trial_dir_path, "full_episodes_logs.json" + ) + if os.path.exists(full_epi_file_path): + full_epi_logs = utils.path._read_all_lines_of_file(full_epi_file_path) + first_epi_first_step = json.loads(full_epi_logs[1]) + tau = [] + welfare_set_annonced = {} + for policy_id, policy_info in first_epi_first_step.items(): + print('policy_info["info"]', policy_info["info"]) + if nested_info: + meta_policy_info = policy_info["info"][ + f"meta_policy/{policy_id}" + ] + else: + meta_policy_info = policy_info["info"][f"meta_policy"] + for k, v in meta_policy_info.items(): + if k.startswith("tau_"): + tau_value = float(k.split("_")[-1]) + tau.append(tau_value) + welfare_set_annonced[policy_id] = v["welfare_set_annonced"] + tau = set(tau) + assert len(tau) == 1, f"tau {tau}" + tau = list(tau)[0] + assert len(welfare_set_annonced.keys()) == 2 + return tau, welfare_set_annonced + else: + return None, None + + +def _order_by_tau(dir_name_by_tau): + return collections.OrderedDict(sorted(dir_name_by_tau.items())) + + +def _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir): + file_path = os.path.join(exp_dir, "welfare_announced_by_tau.txt") + with open(file_path, "w") as f: + for tau, dirs_data in dir_name_by_tau.items(): + all_welfares_player_1 = [] + all_welfares_player_2 = [] + for dir, welfare_announced in dirs_data: + all_welfares_player_1.append(welfare_announced[players_ids[0]]) + all_welfares_player_2.append(welfare_announced[players_ids[1]]) + all_welfares_player_1 = _format_in_same_order( + all_welfares_player_1 + ) + all_welfares_player_2 = _format_in_same_order( + all_welfares_player_2 + ) + count_announced_p1 = collections.Counter(all_welfares_player_1) + count_announced_p2 = collections.Counter(all_welfares_player_2) + msg = ( + f"===== Welfare sets announced with tau = {tau} =====\n" + f"Player 1: {count_announced_p1}\n" + f"Player 2: {count_announced_p2}\n" + ) + print(msg) + f.write(msg) + + +def _format_in_same_order(all_welfares_player_n): + formatted_welfare_announced = [] + for welfare_announced in all_welfares_player_n: + formated_name = "" + if "utilitarian" in welfare_announced: + formated_name += "utilitarian + " + if ( + "inequity" in welfare_announced + or "egalitarian" in welfare_announced + ): + formated_name += "egalitarian + " + if "mixed" in welfare_announced: + 
formated_name += "mixed" + if formated_name.endswith(" + "): + formated_name = formated_name[:-3] + formatted_welfare_announced.append(formated_name) + return formatted_welfare_announced + + if __name__ == "__main__": - debug_mode = False + debug_mode = True main(debug_mode) diff --git a/marltoolbox/experiments/rllib_api/amtft_various_env.py b/marltoolbox/experiments/rllib_api/amtft_various_env.py index f6900c8..83e7042 100644 --- a/marltoolbox/experiments/rllib_api/amtft_various_env.py +++ b/marltoolbox/experiments/rllib_api/amtft_various_env.py @@ -5,17 +5,17 @@ import ray from ray import tune from ray.rllib.agents import dqn -from ray.rllib.agents.dqn.dqn_torch_policy import postprocess_nstep_and_prio +from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio from ray.rllib.utils import merge_dicts from ray.rllib.utils.schedules import PiecewiseSchedule -from ray.tune.integration.wandb import WandbLogger -from ray.tune.logger import DEFAULT_LOGGERS +from ray.tune.integration.wandb import WandbLoggerCallback -from marltoolbox.algos import amTFT +from marltoolbox import utils +from marltoolbox.algos import amTFT, augmented_r2d2 from marltoolbox.envs import ( matrix_sequential_social_dilemma, - vectorized_coin_game, - vectorized_mixed_motive_coin_game, + coin_game, + mixed_motive_coin_game, ssd_mixed_motive_coin_game, ) from marltoolbox.envs.utils.wrappers import ( @@ -28,21 +28,30 @@ postprocessing, miscellaneous, plot, - self_and_cross_perf, callbacks, + cross_play, + config_helper, ) logger = logging.getLogger(__name__) -def main(debug, train_n_replicates=None, filter_utilitarian=None, env=None): +def main( + debug, + train_n_replicates=None, + filter_utilitarian=None, + env=None, + use_r2d2=False, +): hparams = get_hyperparameters( - debug, train_n_replicates, filter_utilitarian, env + debug, train_n_replicates, filter_utilitarian, env, use_r2d2=use_r2d2 ) if hparams["load_plot_data"] is None: ray.init( - num_cpus=os.cpu_count(), num_gpus=0, local_mode=hparams["debug"] + num_gpus=0, + num_cpus=os.cpu_count(), + local_mode=hparams["debug"], ) # Train @@ -76,13 +85,14 @@ def get_hyperparameters( filter_utilitarian=None, env=None, reward_uncertainty=0.0, + use_r2d2=False, ): if debug: train_n_replicates = 2 n_times_more_utilitarians_seeds = 1 elif train_n_replicates is None: n_times_more_utilitarians_seeds = 4 - train_n_replicates = 4 + train_n_replicates = 10 else: n_times_more_utilitarians_seeds = 4 @@ -91,8 +101,31 @@ def get_hyperparameters( ) pool_of_seeds = miscellaneous.get_random_seeds(n_seeds_to_prepare) exp_name, _ = log.log_in_current_day_dir("amTFT") + # prefix = "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-1/amTFT/2021_04_29/12_16_29/" + prefix = ( + "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-1/amTFT" + "/2021_04_30/22_41_32/" + ) + # prefix = "~/ray_results/amTFT/2021_04_29/12_16_29/" + util_prefix = os.path.join(prefix, "utilitarian_welfare/coop") + inequity_aversion_prefix = os.path.join( + prefix, "inequity_aversion_welfare/coop" + ) + util_load_data_list = [ + "R2D2_AsymVectorizedCoinGame_05122_00000_0_buffer_size=400000,temperature_schedule=