diff --git a/.gitignore b/.gitignore index d0f4500..ebbbe1c 100644 --- a/.gitignore +++ b/.gitignore @@ -44,9 +44,12 @@ pip-delete-this-directory.txt rllib_bugs/ *.log +*.png +*.out .ipynb_checkpoints/ .coverage marltoolbox/examples/Tutorial_*.py marltoolbox/experiments/Tutorial_*.py api_key_wandb wandb/ +requirements.txt diff --git a/marltoolbox/algos/amTFT/__init__.py b/marltoolbox/algos/amTFT/__init__.py index a8d5f52..d3d651e 100644 --- a/marltoolbox/algos/amTFT/__init__.py +++ b/marltoolbox/algos/amTFT/__init__.py @@ -19,3 +19,21 @@ AmTFTRolloutsTorchPolicy, ) from marltoolbox.algos.amTFT.train_helper import train_amtft + + +__all__ = [ + "train_amtft", + "AmTFTRolloutsTorchPolicy", + "amTFTQValuesTorchPolicy", + "Level1amTFTExploiterTorchPolicy", + "observation_fn", + "AmTFTCallbacks", + "WORKING_STATES", + "WORKING_STATES_IN_EVALUATION", + "AmTFTReferenceClass", + "DEFAULT_NESTED_POLICY_COOP", + "DEFAULT_NESTED_POLICY_SELFISH", + "PLOT_ASSEMBLAGE_TAGS", + "PLOT_KEYS", + "DEFAULT_CONFIG", +] diff --git a/marltoolbox/algos/amTFT/base.py b/marltoolbox/algos/amTFT/base.py index 6c383a3..6fe20e7 100644 --- a/marltoolbox/algos/amTFT/base.py +++ b/marltoolbox/algos/amTFT/base.py @@ -4,33 +4,40 @@ from ray.rllib.utils import merge_dicts from marltoolbox.algos import hierarchical, augmented_dqn -from marltoolbox.utils import \ - postprocessing, miscellaneous +from marltoolbox.utils import postprocessing, miscellaneous + logger = logging.getLogger(__name__) APPROXIMATION_METHOD_Q_VALUE = "amTFT_use_Q_net" APPROXIMATION_METHOD_ROLLOUTS = "amTFT_use_rollout" -APPROXIMATION_METHODS = (APPROXIMATION_METHOD_Q_VALUE, - APPROXIMATION_METHOD_ROLLOUTS) -WORKING_STATES = ("train_coop", - "train_selfish", - "eval_amtft", - "eval_naive_selfish", - "eval_naive_coop") -WORKING_STATES_IN_EVALUATION = WORKING_STATES[2:] +APPROXIMATION_METHODS = ( + APPROXIMATION_METHOD_Q_VALUE, + APPROXIMATION_METHOD_ROLLOUTS, +) +WORKING_STATES = ( + "train_coop", + "train_selfish", + "eval_amtft", + "eval_naive_selfish", + "eval_naive_coop", + "use_true_selfish", +) +WORKING_STATES_IN_EVALUATION = WORKING_STATES[2:5] OWN_COOP_POLICY_IDX = 0 OWN_SELFISH_POLICY_IDX = 1 OPP_COOP_POLICY_IDX = 2 OPP_SELFISH_POLICY_IDX = 3 +TRUE_SELFISH_POLICY_IDX = 4 DEFAULT_NESTED_POLICY_SELFISH = augmented_dqn.MyDQNTorchPolicy DEFAULT_NESTED_POLICY_COOP = DEFAULT_NESTED_POLICY_SELFISH.with_updates( postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( postprocessing.welfares_postprocessing_fn( - add_utilitarian_welfare=True, ), - postprocess_nstep_and_prio + add_utilitarian_welfare=True, + ), + postprocess_nstep_and_prio, ) ) @@ -44,31 +51,36 @@ "rollout_length": 40, "n_rollout_replicas": 20, "last_k": 1, + "punish_instead_of_selfish": False, + "min_punish_steps": 0, # TODO use log level of RLLib instead of mine "verbose": 1, "auto_load_checkpoint": True, - # To configure "own_policy_id": None, "opp_policy_id": None, "callbacks": None, # One from marltoolbox.utils.postprocessing.WELFARES "welfare_key": postprocessing.WELFARE_UTILITARIAN, - - 'nested_policies': [ + "nested_policies": [ # Here the trainer need to be a DQNTrainer # to provide the config for the 3 DQNTorchPolicy - {"Policy_class": DEFAULT_NESTED_POLICY_COOP, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_COOP, - "config_update": {}}, - {"Policy_class": DEFAULT_NESTED_POLICY_SELFISH, - "config_update": {}}, + {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}}, + { + 
"Policy_class": DEFAULT_NESTED_POLICY_SELFISH, + "config_update": {}, + }, + {"Policy_class": DEFAULT_NESTED_POLICY_COOP, "config_update": {}}, + { + "Policy_class": DEFAULT_NESTED_POLICY_SELFISH, + "config_update": {}, + }, ], - "optimizer": {"sgd_momentum": 0.0, }, - } + "optimizer": { + "sgd_momentum": 0.0, + }, + "batch_mode": "complete_episodes", + }, ) PLOT_KEYS = [ @@ -76,7 +88,8 @@ "debit", "debit_threshold", "summed_debit", - "summed_n_steps_to_punish" + "summed_n_steps_to_punish", + "reset_rnn_state", ] PLOT_ASSEMBLAGE_TAGS = [ @@ -85,6 +98,7 @@ ("debit_threshold",), ("summed_debit",), ("summed_n_steps_to_punish",), + ("reset_rnn_state",), ] diff --git a/marltoolbox/algos/amTFT/base_policy.py b/marltoolbox/algos/amTFT/base_policy.py index 33c63d0..1fe94b7 100644 --- a/marltoolbox/algos/amTFT/base_policy.py +++ b/marltoolbox/algos/amTFT/base_policy.py @@ -1,14 +1,20 @@ import copy import logging -from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING +from typing import Dict, TYPE_CHECKING import numpy as np +import torch from ray.rllib.env import BaseEnv from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import override -from ray.rllib.utils.typing import TensorType, PolicyID +from ray.rllib.utils.threading import with_lock +from ray.rllib.utils.torch_ops import ( + convert_to_torch_tensor, +) +from ray.rllib.utils.typing import PolicyID +from ray.rllib.evaluation import postprocessing as postprocessing_rllib if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -23,6 +29,7 @@ OPP_COOP_POLICY_IDX, OWN_COOP_POLICY_IDX, OWN_SELFISH_POLICY_IDX, + TRUE_SELFISH_POLICY_IDX, ) logger = logging.getLogger(__name__) @@ -37,21 +44,32 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.total_debit = 0 self.n_steps_to_punish = 0 self.observed_n_step_in_current_epi = 0 - self._first_fake_step_played = False + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX self.opp_previous_obs = None self.opp_new_obs = None self.own_previous_obs = None self.own_new_obs = None self.both_previous_raw_obs = None self.both_new_raw_obs = None + self.coop_own_rnn_state_before_last_act = None + self.coop_opp_rnn_state_before_last_act = None # notation T in the paper self.debit_threshold = config["debit_threshold"] # notation alpha in the paper self.punishment_multiplier = config["punishment_multiplier"] self.working_state = config["working_state"] + assert ( + self.working_state in WORKING_STATES + ), f"self.working_state {self.working_state}" self.verbose = config["verbose"] self.welfare_key = config["welfare_key"] self.auto_load_checkpoint = config.get("auto_load_checkpoint", True) + self.punish_instead_of_selfish = config.get( + "punish_instead_of_selfish", False + ) + self.punish_instead_of_selfish_key = ( + postprocessing.OPPONENT_NEGATIVE_REWARD + ) if self.working_state in WORKING_STATES_IN_EVALUATION: self._set_models_for_evaluation() @@ -60,45 +78,33 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.auto_load_checkpoint and restore.LOAD_FROM_CONFIG_KEY in config.keys() ): - restore.after_init_load_policy_checkpoint(self) + print("amTFT going to load checkpoint") + restore.before_loss_init_load_policy_checkpoint(self) def _set_models_for_evaluation(self): for algo in self.algorithms: algo.model.eval() + @with_lock @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: 
Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - - self._select_witch_algo_to_use() - actions, state_out, extra_fetches = self.algorithms[ - self.active_algo_idx - ].compute_actions( - obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): + state_batches = self._select_witch_algo_to_use(state_batches) + + self._track_last_coop_rnn_state(state_batches) + actions, state_out, extra_fetches = super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep ) - if self.verbose > 2: - print(f"self.active_algo_idx {self.active_algo_idx}") + if self.verbose > 1: + print("algo idx", self.active_algo_idx, "action", actions) + print("extra_fetches", extra_fetches) + print("state_batches (in)", state_batches) + print("state_out", state_out) + return actions, state_out, extra_fetches - def _select_witch_algo_to_use(self): + def _select_witch_algo_to_use(self, state_batches): if ( self.working_state == WORKING_STATES[0] or self.working_state == WORKING_STATES[4] @@ -109,14 +115,19 @@ def _select_witch_algo_to_use(self): or self.working_state == WORKING_STATES[3] ): self.active_algo_idx = OWN_SELFISH_POLICY_IDX + elif self.working_state == WORKING_STATES[2]: - self._select_algo_to_use_in_eval() + state_batches = self._select_algo_to_use_in_eval(state_batches) + elif self.working_state == WORKING_STATES[5]: + self.active_algo_idx = TRUE_SELFISH_POLICY_IDX else: raise ValueError( f'config["working_state"] ' f"must be one of {WORKING_STATES}" ) - def _select_algo_to_use_in_eval(self): + return state_batches + + def _select_algo_to_use_in_eval(self, state_batches): if self.n_steps_to_punish == 0: self.active_algo_idx = OWN_COOP_POLICY_IDX elif self.n_steps_to_punish > 0: @@ -125,30 +136,76 @@ def _select_algo_to_use_in_eval(self): else: raise ValueError("self.n_steps_to_punish can't be below zero") + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_eval" + ) + + return state_batches + + def _check_for_rnn_state_reset(self, state_batches, last_algo_idx: str): + if getattr(self, last_algo_idx) != self.active_algo_idx: + state_batches = self._get_initial_rnn_state(state_batches) + self._to_log["reset_rnn_state"] = self.active_algo_idx + setattr(self, last_algo_idx, self.active_algo_idx) + if self.verbose > 0: + print("reset_rnn_state") + else: + if "reset_rnn_state" in self._to_log.keys(): + self._to_log.pop("reset_rnn_state") + return state_batches + + def _get_initial_rnn_state(self, state_batches): + if "model" in self.config.keys() and self.config["model"]["use_lstm"]: + initial_state = self.algorithms[ + self.active_algo_idx + ].get_initial_state() + initial_state = [ + convert_to_torch_tensor(s, self.device) for s in initial_state + ] + initial_state = [s.unsqueeze(0) for s in initial_state] + msg = ( + f"self.active_algo_idx {self.active_algo_idx} " + f"state_batches {state_batches} reset to initial rnn state" + ) + # print(msg) + logger.info(msg) + return 
initial_state + else: + return state_batches + + def _track_last_coop_rnn_state(self, state_batches): + if self.active_algo_idx == OWN_COOP_POLICY_IDX: + self.coop_own_rnn_state_before_last_act = state_batches + @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): - # working_state_idx = WORKING_STATES.index(self.working_state) - # assert working_state_idx == OWN_COOP_POLICY_IDX \ - # or working_state_idx == OWN_SELFISH_POLICY_IDX, \ - # f"current working_state is {self.working_state} " \ - # f"but must be one of " \ - # f"[{WORKING_STATES[OWN_COOP_POLICY_IDX]}, " \ - # f"{WORKING_STATES[OWN_SELFISH_POLICY_IDX]}]" - if self.working_state == WORKING_STATES[0]: algo_idx_to_train = OWN_COOP_POLICY_IDX elif self.working_state == WORKING_STATES[1]: algo_idx_to_train = OWN_SELFISH_POLICY_IDX + elif self.working_state == WORKING_STATES[5]: + algo_idx_to_train = TRUE_SELFISH_POLICY_IDX else: - raise ValueError() + raise ValueError( + f"self.working_state must be one of " + f"{WORKING_STATES[0:2]+[WORKING_STATES[5]]}" + ) + # print("base amTFT samples['advantages'] before recomputing", samples[ + # "advantages"]) samples = self._modify_batch_for_policy(algo_idx_to_train, samples) - algo_to_train = self.algorithms[algo_idx_to_train] + # print("base amTFT samples['rewards']", samples["rewards"]) + # print("base amTFT samples['advantages']", samples["advantages"]) + learner_stats_original = algo_to_train.learn_on_batch(samples) learner_stats = {"learner_stats": {}} learner_stats["learner_stats"][ f"algo{algo_idx_to_train}" - ] = algo_to_train.learn_on_batch(samples) + ] = learner_stats_original + if "kl" in learner_stats_original["learner_stats"]: + learner_stats["learner_stats"]["kl"] = learner_stats_original[ + "learner_stats" + ]["kl"] if self.verbose > 1: print(f"learn_on_batch WORKING_STATES " f"{self.working_state}") @@ -156,16 +213,33 @@ def _learn_on_batch(self, samples: SampleBatch): return learner_stats def _modify_batch_for_policy(self, algo_idx_to_train, samples): + if algo_idx_to_train == OWN_COOP_POLICY_IDX: samples = samples.copy() - samples = self._overwrite_reward_for_policy_in_use(samples) + samples = self._overwrite_reward_for_policy_in_use( + samples, self.welfare_key + ) + elif ( + self.punish_instead_of_selfish + and algo_idx_to_train == OWN_SELFISH_POLICY_IDX + ): + samples = samples.copy() + samples = self._overwrite_reward_for_policy_in_use( + samples, self.punish_instead_of_selfish_key + ) + + if postprocessing_rllib.Postprocessing.ADVANTAGES in samples.keys(): + samples = self.algorithms[ + algo_idx_to_train + ].postprocess_trajectory(samples) + return samples - def _overwrite_reward_for_policy_in_use(self, samples_copy): + def _overwrite_reward_for_policy_in_use(self, samples_copy, welfare_key): samples_copy[samples_copy.REWARDS] = np.array( - samples_copy.data[self.welfare_key] + samples_copy.data[welfare_key] ) - logger.debug(f"overwrite reward with {self.welfare_key}") + logger.debug(f"overwrite reward with {welfare_key}") return samples_copy def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): @@ -173,7 +247,7 @@ def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): # observation produced by this action. 
But we need the # observation that cause the agent to play this action # thus the observation n-1 - if self._first_fake_step_played: + if self.own_new_obs is not None: self.own_previous_obs = self.own_new_obs self.opp_previous_obs = self.opp_new_obs self.both_previous_raw_obs = self.both_new_raw_obs @@ -193,20 +267,17 @@ def on_episode_step( *args, **kwargs, ): - if self._first_fake_step_played: - opp_obs, raw_obs, opp_a = self._get_information_from_opponent( - policy_id, policy_ids, episode - ) + opp_obs, raw_obs, opp_a = self._get_information_from_opponent( + policy_id, policy_ids, episode + ) - # Ignored the first step in epi because the - # actions provided are fake (they were not played) - self._on_episode_step( - opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index - ) + # Ignored the first step in epi because the + # actions provided are fake (they were not played) + self._on_episode_step( + opp_obs, raw_obs, opp_a, worker, base_env, episode, env_index + ) - self.observed_n_step_in_current_epi += 1 - else: - self._first_fake_step_played = True + self.observed_n_step_in_current_epi += 1 def _get_information_from_opponent(self, agent_id, agent_ids, episode): opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][ @@ -243,11 +314,27 @@ def _on_episode_step( coop_opp_simulated_action = ( self._simulate_action_from_cooperative_opponent(opp_obs) ) + assert ( + len(worker.async_env.env_states) == 1 + ), "amTFT in eval only works with one env not vector of envs" + assert ( + worker.env.step_count_in_current_episode + == worker.async_env.env_states[ + 0 + ].env.step_count_in_current_episode + ) + assert ( + worker.env.step_count_in_current_episode + == self._base_env_at_last_step.get_unwrapped()[ + 0 + ].step_count_in_current_episode + + 1 + ) self._update_total_debit( last_obs, opp_action, worker, - base_env, + self._base_env_at_last_step, episode, env_index, coop_opp_simulated_action, @@ -257,20 +344,53 @@ def _on_episode_step( opp_action, coop_opp_simulated_action, worker, last_obs ) + self._base_env_at_last_step = copy.deepcopy(base_env) + self._to_log["n_steps_to_punish"] = self.n_steps_to_punish self._to_log["debit_threshold"] = self.debit_threshold def _is_punishment_planned(self): return self.n_steps_to_punish > 0 + def on_episode_start(self, *args, **kwargs): + if self.working_state in WORKING_STATES_IN_EVALUATION: + self._base_env_at_last_step = copy.deepcopy(kwargs["base_env"]) + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX + self.coop_opp_rnn_state_after_last_act = self.algorithms[ + OPP_COOP_POLICY_IDX + ].get_initial_state() + def _simulate_action_from_cooperative_opponent(self, opp_obs): + if self.verbose > 1: + print("opp_obs for opp coop simu nonzero obs", np.nonzero(opp_obs)) + for i, algo in enumerate(self.algorithms): + print("algo", i, algo) + self.coop_opp_rnn_state_before_last_act = ( + self.coop_opp_rnn_state_after_last_act + ) ( coop_opp_simulated_action, - _, - self.coop_opp_extra_fetches, - ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_actions([opp_obs]) - # Returned a list - coop_opp_simulated_action = coop_opp_simulated_action[0] + self.coop_opp_rnn_state_after_last_act, + coop_opp_extra_fetches, + ) = self.algorithms[OPP_COOP_POLICY_IDX].compute_single_action( + obs=opp_obs, + state=self.coop_opp_rnn_state_after_last_act, + ) + if self.verbose > 1: + print( + coop_opp_simulated_action, + "coop_opp_extra_fetches", + coop_opp_extra_fetches, + ) + print( + "state before simu coop opp", + self.coop_opp_rnn_state_before_last_act, + 
) + print( + "state after simu coop opp", + self.coop_opp_rnn_state_after_last_act, + ) + return coop_opp_simulated_action def _update_total_debit( @@ -283,8 +403,19 @@ def _update_total_debit( env_index, coop_opp_simulated_action, ): - + if self.verbose > 1: + print( + self.own_policy_id, + self.config[restore.LOAD_FROM_CONFIG_KEY][0].split("/")[-5], + ) if coop_opp_simulated_action != opp_action: + if self.verbose > 0: + print( + self.own_policy_id, + "coop_opp_simulated_action != opp_action:", + coop_opp_simulated_action, + opp_action, + ) if ( worker.env.step_count_in_current_episode >= worker.env.max_steps @@ -319,20 +450,22 @@ def _update_total_debit( ) if coop_opp_simulated_action != opp_action: if self.verbose > 0: - print( - "coop_opp_simulated_action != opp_action:", - coop_opp_simulated_action, - opp_action, - ) print(f"debit {debit}") print( f"self.total_debit {self.total_debit}, previous was {tmp}" ) + if self.verbose > 0: + print("_update_total_debit") + print(" debit", debit) + print(" self.total_debit", self.total_debit) + print(" self.summed_debit", self._to_log["summed_debit"]) + print(" self.performing_rollouts", self.performing_rollouts) + print(" self.use_opponent_policies", self.use_opponent_policies) def _is_starting_new_punishment_required(self, manual_threshold=None): if manual_threshold is not None: - return self.total_debit > manual_threshold - return self.total_debit > self.debit_threshold + return self.total_debit >= manual_threshold + return self.total_debit >= self.debit_threshold def _plan_punishment( self, opp_action, coop_opp_simulated_action, worker, last_obs @@ -355,7 +488,11 @@ def _plan_punishment( print(f"reset self.total_debit to 0 since planned punishement") def on_episode_end(self, *args, **kwargs): - if self.working_state not in WORKING_STATES_IN_EVALUATION: + self._defensive_check_observed_n_opp_moves(*args, **kwargs) + self._if_in_eval_reset_debit_and_punish() + + def _defensive_check_observed_n_opp_moves(self, *args, **kwargs): + if self.working_state in WORKING_STATES_IN_EVALUATION: assert ( self.observed_n_step_in_current_epi == kwargs["base_env"].get_unwrapped()[0].max_steps @@ -365,13 +502,16 @@ def on_episode_end(self, *args, **kwargs): f"{kwargs['base_env'].get_unwrapped()[0].max_steps} " "steps per episodes." 
) - self.observed_n_step_in_current_epi = 0 + self.observed_n_step_in_current_epi = 0 - if self.working_state == WORKING_STATES[2]: + def _if_in_eval_reset_debit_and_punish(self): + if self.working_state in WORKING_STATES_IN_EVALUATION: self.total_debit = 0 self.n_steps_to_punish = 0 if self.verbose > 0: - print(f"reset self.total_debit to 0 since end of episode") + logger.info( + "reset self.total_debit to 0 since end of " "episode" + ) def _compute_debit( self, @@ -390,6 +530,27 @@ def _compute_punishment_duration( ): raise NotImplementedError() + def _get_last_rnn_states_before_rollouts(self): + if self.config["model"]["use_lstm"]: + return { + self.own_policy_id: self._squeezes_rnn_state( + self.coop_own_rnn_state_before_last_act + ), + self.opp_policy_id: self.coop_opp_rnn_state_before_last_act, + } + + else: + return None + + @staticmethod + def _squeezes_rnn_state(state): + return [ + s.squeeze(0) + if torch and isinstance(s, torch.Tensor) + else np.squeeze(s, 0) + for s in state + ] + class AmTFTCallbacks(callbacks.PolicyCallbacks): def on_train_result(self, trainer, *args, **kwargs): diff --git a/marltoolbox/algos/amTFT/inversed_policy.py b/marltoolbox/algos/amTFT/inversed_policy.py index fef207e..578dcd8 100644 --- a/marltoolbox/algos/amTFT/inversed_policy.py +++ b/marltoolbox/algos/amTFT/inversed_policy.py @@ -4,14 +4,15 @@ class InversedAmTFTRolloutsTorchPolicy( - policy_using_rollouts.AmTFTRolloutsTorchPolicy): + policy_using_rollouts.AmTFTRolloutsTorchPolicy +): """ Instead of simulating the opponent, simulate our own policy and act as if it was the opponent. """ - def _init_for_rollout(self, config): - super()._init_for_rollout(config) + def _init(self, config): + super()._init(config) self.ag_id_rollout_reward_to_read = self.own_policy_id @override(base_policy.AmTFTPolicyBase) @@ -26,86 +27,3 @@ def _switch_own_and_opp(self, agent_id): output = super()._switch_own_and_opp(agent_id) self.use_opponent_policies = not self.use_opponent_policies return output - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _select_algo_to_use_in_eval(self): - # assert self.performing_rollouts - # - # if not self.use_opponent_policies: - # if self.n_steps_to_punish == 0: - # self.active_algo_idx = base.OPP_COOP_POLICY_IDX - # elif self.n_steps_to_punish > 0: - # self.active_algo_idx = base.OPP_SELFISH_POLICY_IDX - # self.n_steps_to_punish -= 1 - # else: - # raise ValueError("self.n_steps_to_punish can't be below zero") - # else: - # # assert self.performing_rollouts - # if self.n_steps_to_punish_opponent == 0: - # self.active_algo_idx = base.OWN_COOP_POLICY_IDX - # elif self.n_steps_to_punish_opponent > 0: - # self.active_algo_idx = base.OWN_SELFISH_POLICY_IDX - # self.n_steps_to_punish_opponent -= 1 - # else: - # raise ValueError("self.n_steps_to_punish_opp " - # "can't be below zero") - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _init_for_rollout(self, config): - # super()._init_for_rollout(config) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # self.use_opponent_policies = False - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _prepare_to_perform_virtual_rollouts_in_env(self, worker): - # outputs = super()._prepare_to_perform_virtual_rollouts_in_env( - # worker) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # self.use_opponent_policies = True - # return outputs - - # 
@override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def _stop_performing_virtual_rollouts_in_env(self, n_steps_to_punish): - # super()._stop_performing_virtual_rollouts_in_env(n_steps_to_punish) - # # the policies stored as opponent_policies are our own policy - # # (not the opponent's policies) - # self.use_opponent_policies = True - - # @override(policy_using_rollouts.AmTFTRolloutsTorchPolicy) - # def compute_actions( - # self, - # obs_batch: Union[List[TensorType], TensorType], - # state_batches: Optional[List[TensorType]] = None, - # prev_action_batch: Union[List[TensorType], TensorType] = None, - # prev_reward_batch: Union[List[TensorType], TensorType] = None, - # info_batch: Optional[Dict[str, list]] = None, - # episodes: Optional[List["MultiAgentEpisode"]] = None, - # explore: Optional[bool] = None, - # timestep: Optional[int] = None, - # **kwargs) -> \ - # Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - # - # # Option to overwrite action during internal rollouts - # if self.use_opponent_policies: - # if len(self.overwrite_action) > 0: - # actions, state_out, extra_fetches = \ - # self.overwrite_action.pop(0) - # if self.verbose > 1: - # print("overwrite actions", actions, type(actions)) - # return actions, state_out, extra_fetches - # - # return super().compute_actions( - # obs_batch, state_batches, prev_action_batch, prev_reward_batch, - # info_batch, episodes, explore, timestep, **kwargs) - -# debit = self._compute_debit( -# last_obs, opp_action, worker, base_env, -# episode, env_index, coop_opp_simulated_action) - -# self.n_steps_to_punish = self._compute_punishment_duration( -# opp_action, -# coop_opp_simulated_action, -# worker, -# last_obs) diff --git a/marltoolbox/algos/amTFT/level_1_exploiter.py b/marltoolbox/algos/amTFT/level_1_exploiter.py index c395ded..7f5c52d 100644 --- a/marltoolbox/algos/amTFT/level_1_exploiter.py +++ b/marltoolbox/algos/amTFT/level_1_exploiter.py @@ -3,6 +3,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos import hierarchical, amTFT from marltoolbox.algos.amTFT import base @@ -32,7 +33,7 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.exploiter_activated = config.get("exploiter_activated") if restore.LOAD_FROM_CONFIG_KEY in config.keys(): - restore.after_init_load_policy_checkpoint(self) + restore.before_loss_init_load_policy_checkpoint(self) if self.working_state in base.WORKING_STATES_IN_EVALUATION: if self.exploiter_activated: @@ -48,12 +49,16 @@ def working_state(self, value): def on_episode_start(self, worker, *args, **kwargs): self.worker = worker + self.simulated_opponent.on_episode_start(*args, **kwargs) def on_episode_end(self, *args, **kwargs): self.simulated_opponent.on_episode_end(*args, **kwargs) - def on_observation_fn(self, *args, **kwargs): - self.simulated_opponent.on_observation_fn(*args, **kwargs) + def on_observation_fn(self, own_new_obs, opp_new_obs, both_new_raw_obs): + self.simulated_opponent.on_observation_fn( + own_new_obs, opp_new_obs, both_new_raw_obs + ) + self.both_new_raw_obs = both_new_raw_obs def on_episode_step(self, *args, **kwargs): self.simulated_opponent.on_episode_step(*args, **kwargs) @@ -78,12 +83,14 @@ def _train_dual_policies(self, samples: SampleBatch): ) return learner_stats + @with_lock @override(level_1.Level1ExploiterTorchPolicy) - def compute_actions(self, *args, **kwargs): - + def 
_compute_action_helper(self, *args, **kwargs): # use the simulated amTFT opponent when performing rollouts if self.simulated_opponent.performing_rollouts: - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) self._to_log["use_simulated_opponent"] = 0.0 self._to_log["use_exploiter_cooperating"] = 0.0 @@ -93,16 +100,19 @@ def compute_actions(self, *args, **kwargs): if self.working_state not in base.WORKING_STATES_IN_EVALUATION: self._to_log["use_simulated_opponent"] = 1.0 self.active_algo_idx = self.LEVEL_1_POLICY_IDX - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) # Use the simulated amTFT opponent when not using the exploiter if not self.exploiter_activated: self._to_log["use_simulated_opponent"] = 1.0 self.active_algo_idx = self.LEVEL_1_POLICY_IDX - return self.simulated_opponent.compute_actions(*args, **kwargs) + return self.simulated_opponent._compute_action_helper( + *args, **kwargs + ) - outputs = super().compute_actions(*args, **kwargs) - return outputs + return super()._compute_action_helper(*args, **kwargs) def _is_cooperation_needed_to_prevent_punishement( self, selfish_action_tuple, *args, **kwargs @@ -117,9 +127,14 @@ def _lookahead_for_opponent_punishing( self, selfish_action, *args, **kwargs ) -> bool: selfish_action = selfish_action[0] - last_obs = kwargs["episodes"][0]._agent_to_last_raw_obs selfish_step_debit = self.simulated_opponent._compute_debit( - last_obs, selfish_action, self.worker, None, None, None, None + self.both_new_raw_obs, + selfish_action, + self.worker, + None, + None, + None, + None, ) if ( "expl_min_anticipated_gain" not in self._to_log.keys() diff --git a/marltoolbox/algos/amTFT/policy_using_rollouts.py b/marltoolbox/algos/amTFT/policy_using_rollouts.py index d6ba59b..4609483 100644 --- a/marltoolbox/algos/amTFT/policy_using_rollouts.py +++ b/marltoolbox/algos/amTFT/policy_using_rollouts.py @@ -1,15 +1,13 @@ -from typing import List, Union, Optional, Dict, Tuple - import numpy as np -from ray.rllib.evaluation import MultiAgentEpisode -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos.amTFT.base import ( OWN_COOP_POLICY_IDX, OWN_SELFISH_POLICY_IDX, OPP_SELFISH_POLICY_IDX, OPP_COOP_POLICY_IDX, - WORKING_STATES, + WORKING_STATES_IN_EVALUATION, ) from marltoolbox.algos.amTFT.base_policy import AmTFTPolicyBase from marltoolbox.utils import rollout @@ -18,9 +16,9 @@ class AmTFTRolloutsTorchPolicy(AmTFTPolicyBase): def __init__(self, observation_space, action_space, config, **kwargs): super().__init__(observation_space, action_space, config, **kwargs) - self._init_for_rollout(self.config) + self._init(self.config) - def _init_for_rollout(self, config): + def _init(self, config): self.last_k = config["last_k"] self.use_opponent_policies = False self.rollout_length = config["rollout_length"] @@ -31,49 +29,36 @@ def _init_for_rollout(self, config): self.opp_policy_id = config["opp_policy_id"] self.n_steps_to_punish_opponent = 0 self.ag_id_rollout_reward_to_read = self.opp_policy_id + self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX + self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX + self.use_short_debit_rollout = config.get( + "use_short_debit_rollout", False + ) - # Don't support LSTM (at least because of action - # overwriting needed in the rollouts) - if 
"model" in config.keys(): - if "use_lstm" in config["model"].keys(): - assert not config["model"]["use_lstm"] - - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - + @with_lock + @override(AmTFTPolicyBase) + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): # Option to overwrite action during internal rollouts if self.use_opponent_policies: if len(self.overwrite_action) > 0: actions, state_out, extra_fetches = self.overwrite_action.pop( 0 ) + self.last_opp_algo_idx_in_rollout = OPP_SELFISH_POLICY_IDX if self.verbose > 1: - print("overwritten actions", actions, type(actions)) + print( + "overwritten actions", actions, "state_out", state_out + ) return actions, state_out, extra_fetches - return super().compute_actions( - obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + return super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep ) - def _select_algo_to_use_in_eval(self): + @override(AmTFTPolicyBase) + def _select_algo_to_use_in_eval(self, state_batches): if not self.use_opponent_policies: if self.n_steps_to_punish == 0: self.active_algo_idx = OWN_COOP_POLICY_IDX @@ -82,6 +67,16 @@ def _select_algo_to_use_in_eval(self): self.n_steps_to_punish -= 1 else: raise ValueError("self.n_steps_to_punish can't be below zero") + + if self.performing_rollouts: + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_rollout" + ) + else: + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_own_algo_idx_in_eval" + ) + else: assert self.performing_rollouts if self.n_steps_to_punish_opponent == 0: @@ -94,6 +89,17 @@ def _select_algo_to_use_in_eval(self): "self.n_steps_to_punish_opp " "can't be below zero" ) + state_batches = self._check_for_rnn_state_reset( + state_batches, "last_opp_algo_idx_in_rollout" + ) + + return state_batches + + @override(AmTFTPolicyBase) + def _track_last_coop_rnn_state(self, state_batches): + if not self.performing_rollouts and not self.use_opponent_policies: + super()._track_last_coop_rnn_state(state_batches) + def _on_episode_step( self, opp_obs, @@ -104,16 +110,16 @@ def _on_episode_step( episode, env_index, ): - if not self.performing_rollouts: - super()._on_episode_step( - opp_obs, - last_obs, - opp_action, - worker, - base_env, - episode, - env_index, - ) + assert not self.performing_rollouts + super()._on_episode_step( + opp_obs, + last_obs, + opp_action, + worker, + base_env, + episode, + env_index, + ) def _compute_debit( self, @@ -126,11 +132,13 @@ def _compute_debit( coop_opp_simulated_action, ): approximated_debit = self._compute_debit_using_rollouts( - last_obs, opp_action, worker + last_obs, opp_action, worker, base_env ) return approximated_debit - def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): + def _compute_debit_using_rollouts( + self, last_obs, opp_action, worker, base_env + ): ( n_steps_to_punish, policy_map, @@ -138,6 +146,9 @@ def _compute_debit_using_rollouts(self, 
last_obs, opp_action, worker): ) = self._prepare_to_perform_virtual_rollouts_in_env(worker) # Cooperative rollouts + if self.verbose > 1: + print("Compute debit") + print("Cooperative rollouts") ( mean_total_reward_for_totally_coop_opp, _, @@ -148,8 +159,12 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): partially_coop=False, opp_action=None, last_obs=last_obs, + rollout_length=1 if self.use_short_debit_rollout else None, + base_env=base_env, ) # Cooperative rollouts with first action as the real one + if self.verbose > 1: + print("Parital cooperative rollouts") ( mean_total_reward_for_partially_coop_opp, _, @@ -160,6 +175,8 @@ def _compute_debit_using_rollouts(self, last_obs, opp_action, worker): partially_coop=True, opp_action=opp_action, last_obs=last_obs, + rollout_length=1 if self.use_short_debit_rollout else None, + base_env=base_env, ) if self.verbose > 0: @@ -209,9 +226,13 @@ def _switch_own_and_opp(self, agent_id): def _compute_punishment_duration( self, opp_action, coop_opp_simulated_action, worker, last_obs ): - return self._compute_punishment_duration_from_rollouts( + punishment_duration = self._compute_punishment_duration_from_rollouts( worker, last_obs ) + punishment_duration = max( + punishment_duration, self.config["min_punish_steps"] + ) + return punishment_duration def _compute_punishment_duration_from_rollouts(self, worker, last_obs): ( @@ -489,24 +510,30 @@ def _compute_opp_mean_total_reward( opp_action, last_obs, k_to_explore=0, + rollout_length=None, + base_env=None, ): opp_total_rewards = [] for i in range(self.n_rollout_replicas // 2): + self.last_opp_algo_idx_in_rollout = OPP_COOP_POLICY_IDX + self.last_own_algo_idx_in_rollout = OWN_COOP_POLICY_IDX self.n_steps_to_punish = k_to_explore self.n_steps_to_punish_opponent = k_to_explore if partially_coop: - assert len(self.overwrite_action) == 0 - self.overwrite_action = [ - (np.array([opp_action]), [], {}), - ] + self._set_overwrite_action(opp_action) + last_rnn_states = self._get_last_rnn_states_before_rollouts() coop_rollout = rollout.internal_rollout( worker, - num_steps=self.rollout_length, + num_steps=self.rollout_length + if rollout_length is None + else rollout_length, policy_map=policy_map, last_obs=last_obs, policy_agent_mapping=policy_agent_mapping, reset_env_before=False, num_episodes=1, + last_rnn_states=last_rnn_states, + base_env=base_env, ) assert ( coop_rollout._num_episodes == 1 @@ -536,8 +563,33 @@ def _compute_opp_mean_total_reward( opp_mean_total_reward = sum(opp_total_rewards) / len(opp_total_rewards) return opp_mean_total_reward, n_steps_played + def _set_overwrite_action(self, opp_action): + assert len(self.overwrite_action) == 0 + if self.config["model"]["use_lstm"]: + # When we play the real opponent action, don't break the sequence + # of rnn state for the cooperative opponent + # rnn_state = [[el] for el in self.coop_opp_rnn_state_after_last_act] + rnn_state = [ + self._get_initial_rnn_state(None) + for el in self.coop_opp_rnn_state_after_last_act + ] + + else: + rnn_state = [] + self.overwrite_action = [ + (np.array([opp_action]), rnn_state, {}), + ] + + @override(AmTFTPolicyBase) def on_episode_end(self, *args, **kwargs): - if self.working_state == WORKING_STATES[2]: - self.total_debit = 0 - self.n_steps_to_punish = 0 + assert not self.performing_rollouts + if self.working_state in WORKING_STATES_IN_EVALUATION: self.n_steps_to_punish_opponent = 0 + super().on_episode_end(*args, **kwargs) + + @override(AmTFTPolicyBase) + def on_episode_start(self, *args, **kwargs): 
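+ # New evaluation episode: amTFT starts out cooperating, so reset the tracked
+ # active-algo index before the base class re-creates the episode bookkeeping
+ # (env snapshot and the simulated opponent's initial RNN state).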
+ assert not self.performing_rollouts + if self.working_state in WORKING_STATES_IN_EVALUATION: + self.last_own_algo_idx_in_eval = OWN_COOP_POLICY_IDX + super().on_episode_start(*args, **kwargs) diff --git a/marltoolbox/algos/amTFT/train_helper.py b/marltoolbox/algos/amTFT/train_helper.py index 9d6109b..5d95e4e 100644 --- a/marltoolbox/algos/amTFT/train_helper.py +++ b/marltoolbox/algos/amTFT/train_helper.py @@ -5,22 +5,22 @@ from ray import tune from ray.rllib.agents.dqn import DQNTrainer +from marltoolbox import utils from marltoolbox.algos import amTFT -from marltoolbox.scripts.aggregate_and_plot_tensorboard_data import ( - add_summary_plots, -) -from marltoolbox.utils import miscellaneous, restore +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import miscellaneous, restore, exp_analysis def train_amtft( stop_config, rllib_config, name, - do_not_load=[], + do_not_load=(), TrainerClass=DQNTrainer, - plot_keys=[], - plot_assemblage_tags=[], + plot_keys=(), + plot_assemblage_tags=(), debug=False, + punish_instead_of_selfish=False, **kwargs, ): """ @@ -46,56 +46,153 @@ def train_amtft( :param plot_assemblage_tags: arg for add_summary_plots :param debug: debug mode :param kwargs: kwargs for ray.tune.run - :return: tune_analysis containing the checkpoints of the pair of amTFT - policies + :return: experiment_analysis containing the checkpoints of the pair of + amTFT policies """ - selfish_name = os.path.join(name, "selfish") - tune_analysis_selfish_policies = _train_selfish_policies_inside_amtft( - stop_config, rllib_config, selfish_name, TrainerClass, **kwargs - ) - plot_keys, plot_assemblage_tags = _get_plot_keys( - plot_keys, plot_assemblage_tags - ) - if not debug: - add_summary_plots( - main_path=os.path.join("~/ray_results/", selfish_name), - plot_keys=plot_keys, - plot_assemble_tags_in_one_plot=plot_assemblage_tags, + if "loggers" in kwargs.keys() and "logger_config" in rllib_config.keys(): + using_wandb = True + else: + using_wandb = False + + if punish_instead_of_selfish: + policies = list(rllib_config["multiagent"]["policies"].keys()) + assert len(policies) == 2 + + selfish_name = os.path.join(name, "punisher_" + policies[0]) + if using_wandb: + rllib_config["logger_config"]["wandb"][ + "group" + ] += f"+_punisher_{policies[0]}" + experiment_analysis_selfish_policies = ( + _train_selfish_policies_inside_amtft( + stop_config, + rllib_config, + selfish_name, + TrainerClass, + punisher=policies[0], + true_selfish=policies[1], + **kwargs, + ) ) + plot_keys, plot_assemblage_tags = _get_plot_keys( + plot_keys, plot_assemblage_tags + ) + if not debug: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", selfish_name), + plot_keys=plot_keys, + plot_assemble_tags_in_one_plot=plot_assemblage_tags, + ) + + seed_to_selfish_checkpoints = _extract_selfish_policies_checkpoints( + experiment_analysis_selfish_policies + ) + rllib_config = _modify_config_to_load_selfish_policies_in_amtft( + rllib_config, do_not_load, seed_to_selfish_checkpoints + ) + + selfish_name = os.path.join(name, "punisher_" + policies[1]) + if using_wandb: + rllib_config["logger_config"]["wandb"]["group"] = ( + rllib_config["logger_config"]["wandb"]["group"].split("+")[0] + + f"+_punisher_{policies[1]}" + ) + experiment_analysis_selfish_policies = ( + _train_selfish_policies_inside_amtft( + stop_config, + rllib_config, + selfish_name, + TrainerClass, + punisher=policies[1], + true_selfish=policies[0], + **kwargs, + ) + ) + plot_keys, 
plot_assemblage_tags = _get_plot_keys( + plot_keys, plot_assemblage_tags + ) + if not debug: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", selfish_name), + plot_keys=plot_keys, + plot_assemble_tags_in_one_plot=plot_assemblage_tags, + ) + else: + selfish_name = os.path.join(name, "selfish") + if using_wandb: + rllib_config["logger_config"]["wandb"]["group"] = ( + rllib_config["logger_config"]["wandb"]["group"].split("+")[0] + + f"+_selfish" + ) + experiment_analysis_selfish_policies = ( + _train_selfish_policies_inside_amtft( + stop_config, rllib_config, selfish_name, TrainerClass, **kwargs + ) + ) + plot_keys, plot_assemblage_tags = _get_plot_keys( + plot_keys, plot_assemblage_tags + ) + if not debug: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", selfish_name), + plot_keys=plot_keys, + plot_assemble_tags_in_one_plot=plot_assemblage_tags, + ) seed_to_selfish_checkpoints = _extract_selfish_policies_checkpoints( - tune_analysis_selfish_policies + experiment_analysis_selfish_policies ) rllib_config = _modify_config_to_load_selfish_policies_in_amtft( rllib_config, do_not_load, seed_to_selfish_checkpoints ) + if using_wandb: + rllib_config["logger_config"]["wandb"]["group"] = ( + rllib_config["logger_config"]["wandb"]["group"].split("+")[0] + + "+_coop" + ) coop_name = os.path.join(name, "coop") - tune_analysis_amTFT_policies = _train_cooperative_policies_inside_amtft( - stop_config, rllib_config, coop_name, TrainerClass, **kwargs + experiment_analysis_amTFT_policies = ( + _train_cooperative_policies_inside_amtft( + stop_config, rllib_config, coop_name, TrainerClass, **kwargs + ) ) - if not debug: - add_summary_plots( + if not debug or True: + aggregate_and_plot_tensorboard_data.add_summary_plots( main_path=os.path.join("~/ray_results/", coop_name), plot_keys=plot_keys, plot_assemble_tags_in_one_plot=plot_assemblage_tags, ) - return tune_analysis_amTFT_policies + return experiment_analysis_amTFT_policies def _train_selfish_policies_inside_amtft( - stop_config, rllib_config, name, trainer_class, **kwargs + stop_config, + rllib_config, + name, + trainer_class, + punisher=None, + true_selfish=None, + **kwargs, ): rllib_config = copy.deepcopy(rllib_config) stop_config = copy.deepcopy(stop_config) - for policy_id in rllib_config["multiagent"]["policies"].keys(): - rllib_config["multiagent"]["policies"][policy_id][3][ + if punisher is not None: + rllib_config["multiagent"]["policies"][punisher][3][ "working_state" ] = "train_selfish" + rllib_config["multiagent"]["policies"][true_selfish][3][ + "working_state" + ] = "use_true_selfish" + else: + for policy_id in rllib_config["multiagent"]["policies"].keys(): + rllib_config["multiagent"]["policies"][policy_id][3][ + "working_state" + ] = "train_selfish" print("==============================================") print("amTFT starting to train the selfish policy") - tune_analysis_selfish_policies = ray.tune.run( + experiment_analysis_selfish_policies = ray.tune.run( trainer_class, config=rllib_config, stop=stop_config, @@ -105,7 +202,7 @@ def _train_selfish_policies_inside_amtft( mode="max", **kwargs, ) - return tune_analysis_selfish_policies + return experiment_analysis_selfish_policies def _get_plot_keys(plot_keys, plot_assemblage_tags): @@ -116,12 +213,14 @@ def _get_plot_keys(plot_keys, plot_assemblage_tags): return plot_keys, plot_assemble_tags_in_one_plot -def _extract_selfish_policies_checkpoints(tune_analysis_selfish_policies): - checkpoints = 
miscellaneous.extract_checkpoints( - tune_analysis_selfish_policies +def _extract_selfish_policies_checkpoints( + experiment_analysis_selfish_policies, +): + checkpoints = restore.extract_checkpoints_from_experiment_analysis( + experiment_analysis_selfish_policies ) - seeds = miscellaneous.extract_config_values_from_tune_analysis( - tune_analysis_selfish_policies, "seed" + seeds = exp_analysis.extract_config_values_from_experiment_analysis( + experiment_analysis_selfish_policies, "seed" ) seed_to_checkpoint = {} for seed, checkpoint in zip(seeds, checkpoints): @@ -157,7 +256,7 @@ def _train_cooperative_policies_inside_amtft( print("==============================================") print("amTFT starting to train the cooperative policy") - tune_analysis_amTFT_policies = ray.tune.run( + experiment_analysis_amtft_policies = ray.tune.run( trainer_class, config=rllib_config, stop=stop_config, @@ -167,4 +266,4 @@ def _train_cooperative_policies_inside_amtft( mode="max", **kwargs, ) - return tune_analysis_amTFT_policies + return experiment_analysis_amtft_policies diff --git a/marltoolbox/algos/amTFT/weights_exchanger.py b/marltoolbox/algos/amTFT/weights_exchanger.py index c296a05..380e812 100644 --- a/marltoolbox/algos/amTFT/weights_exchanger.py +++ b/marltoolbox/algos/amTFT/weights_exchanger.py @@ -44,6 +44,7 @@ def _share_weights_during_training(trainer): local_policy_map ) if in_training: + print("amTFT _share_weights_during_training") WeightsExchanger._check_only_amTFT_policies(local_policy_map) policies_weights = trainer.get_weights() policies_weights = ( diff --git a/marltoolbox/algos/augmented_dqn.py b/marltoolbox/algos/augmented_dqn.py index 0cbf986..e5bbe7c 100644 --- a/marltoolbox/algos/augmented_dqn.py +++ b/marltoolbox/algos/augmented_dqn.py @@ -1,7 +1,6 @@ -from typing import Dict - import torch import torch.nn.functional as F +from ray.rllib.agents.dqn import dqn_torch_policy from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy @@ -11,6 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.torch_ops import FLOAT_MIN from ray.rllib.utils.typing import TensorType +from typing import Dict from marltoolbox.utils import log, optimizers, policy @@ -20,36 +20,34 @@ def build_q_losses_wt_additional_logs( ) -> TensorType: """ Copy of build_q_losses with additional values saved into the policy - Made only 2 changes, see in comments. + Made only 1 change, see in comments. """ config = policy.config # Q-network evaluation. - q_t, q_logits_t, q_probs_t = compute_q_values( + q_t, q_logits_t, q_probs_t, _ = compute_q_values( policy, - policy.q_model, - train_batch[SampleBatch.CUR_OBS], + model, + {"obs": train_batch[SampleBatch.CUR_OBS]}, explore=False, is_training=True, ) - # Addition 1 out of 2 - policy.last_q_t = q_t.clone() - # Target Q-network evaluation. - q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values( + q_tp1, _, q_probs_tp1, _ = compute_q_values( policy, policy.target_q_model, - train_batch[SampleBatch.NEXT_OBS], + {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True, ) - # Addition 2 out of 2 - policy.last_target_q_t = q_tp1.clone() + # Only additions + policy.last_q_t = q_t.clone().detach() + policy.last_target_q_t = q_tp1.clone().detach() # Q scores for actions which we know were selected in the given state. 
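+ # torch.where zeroes out invalid (action-masked) entries of q_t before the
+ # one-hot selection; summing over the action axis then yields Q(s, a_taken)
+ # for each transition in the batch.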
one_hot_selection = F.one_hot( - train_batch[SampleBatch.ACTIONS], policy.action_space.n + train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n ) q_t_selected = torch.sum( torch.where( @@ -68,10 +66,11 @@ def build_q_losses_wt_additional_logs( q_tp1_using_online_net, q_logits_tp1_using_online_net, q_dist_tp1_using_online_net, + _, ) = compute_q_values( policy, - policy.q_model, - train_batch[SampleBatch.NEXT_OBS], + model, + {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True, ) @@ -108,20 +107,12 @@ def build_q_losses_wt_additional_logs( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 ) - if PRIO_WEIGHTS not in train_batch.keys(): - assert config["prioritized_replay"] is False - prio_weights = torch.tensor( - [1.0] * len(train_batch[SampleBatch.REWARDS]) - ).to(policy.device) - else: - prio_weights = train_batch[PRIO_WEIGHTS] - policy.q_loss = QLoss( q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, - prio_weights, + train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS], train_batch[SampleBatch.DONES].float(), config["gamma"], @@ -134,28 +125,25 @@ def build_q_losses_wt_additional_logs( return policy.q_loss.loss -def build_q_stats_wt_addtional_log( - policy: Policy, batch -) -> Dict[str, TensorType]: - entropy_avg, entropy_single = log._compute_entropy_from_raw_q_values( - policy, policy.last_q_t.clone() - ) +def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + q_stats = dqn_torch_policy.build_q_stats(policy, batch) - return dict( + entropy_avg, _ = log.compute_entropy_from_raw_q_values( + policy, policy.last_q_t + ) + q_stats.update( { "entropy_avg": entropy_avg, - "cur_lr": policy.cur_lr, - }, - **policy.q_loss.stats, + } ) + return q_stats + MyDQNTorchPolicy = DQNTorchPolicy.with_updates( optimizer_fn=optimizers.sgd_optimizer_dqn, loss_fn=build_q_losses_wt_additional_logs, - stats_fn=log.augment_stats_fn_wt_additionnal_logs( - build_q_stats_wt_addtional_log - ), + stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats), before_init=policy.my_setup_early_mixins, mixins=[ TargetNetworkMixin, diff --git a/marltoolbox/algos/augmented_ppo.py b/marltoolbox/algos/augmented_ppo.py new file mode 100644 index 0000000..cf46a3a --- /dev/null +++ b/marltoolbox/algos/augmented_ppo.py @@ -0,0 +1,56 @@ +""" +PyTorch policy class used for PPO. +""" +import logging + +import gym +from ray.rllib.agents.ppo import PPOTorchPolicy +from ray.rllib.agents.ppo.ppo_torch_policy import ( + ValueNetworkMixin, + KLCoeffMixin, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import EntropyCoeffSchedule +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TrainerConfigDict + +from marltoolbox.utils.policy import MyLearningRateSchedule + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +def setup_mixins( + policy: Policy, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, +) -> None: + """Call all mixin classes' constructors before PPOPolicy initialization. + + Args: + policy (Policy): The Policy object. + obs_space (gym.spaces.Space): The Policy's observation space. + action_space (gym.spaces.Space): The Policy's action space. + config (TrainerConfigDict): The Policy's config. 
+ """ + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + EntropyCoeffSchedule.__init__( + policy, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + MyLearningRateSchedule.__init__( + policy, config["lr"], config["lr_schedule"] + ) + + +MyPPOTorchPolicy = PPOTorchPolicy.with_updates( + before_loss_init=setup_mixins, + mixins=[ + MyLearningRateSchedule, + EntropyCoeffSchedule, + KLCoeffMixin, + ValueNetworkMixin, + ], +) diff --git a/marltoolbox/algos/augmented_r2d2.py b/marltoolbox/algos/augmented_r2d2.py new file mode 100644 index 0000000..0a46d65 --- /dev/null +++ b/marltoolbox/algos/augmented_r2d2.py @@ -0,0 +1,224 @@ +"""PyTorch policy class used for R2D2.""" + +import torch +import torch.nn.functional as F +from ray.rllib.agents.dqn import r2d2_torch_policy +from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS +from ray.rllib.agents.dqn.dqn_torch_policy import ComputeTDErrorMixin +from ray.rllib.agents.dqn.dqn_torch_policy import compute_q_values +from ray.rllib.agents.dqn.r2d2_torch_policy import ( + R2D2TorchPolicy, + h_function, + h_inverse, +) +from ray.rllib.agents.dqn.simple_q_torch_policy import TargetNetworkMixin +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.torch_ops import FLOAT_MIN +from ray.rllib.utils.torch_ops import huber_loss, sequence_mask, l2_loss +from ray.rllib.utils.typing import TensorType +from typing import Dict + +from marltoolbox.utils import log, optimizers, policy + + +def my_r2d2_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + """Constructs the loss for R2D2TorchPolicy. + + Args: + policy (Policy): The Policy to calculate the loss for. + model (ModelV2): The Model to calculate the loss for. + train_batch (SampleBatch): The training data. + + Returns: + TensorType: A single loss tensor. + """ + config = policy.config + + # Construct internal state inputs. + i = 0 + state_batches = [] + while "state_in_{}".format(i) in train_batch: + state_batches.append(train_batch["state_in_{}".format(i)]) + i += 1 + assert state_batches + + # Q-network evaluation (at t). + q, _, _, _ = compute_q_values( + policy, + model, + train_batch, + state_batches=state_batches, + seq_lens=train_batch.get("seq_lens"), + explore=False, + is_training=True, + ) + + # Target Q-network evaluation (at t+1). + q_target, _, _, _ = compute_q_values( + policy, + policy.target_q_model, + train_batch, + state_batches=state_batches, + seq_lens=train_batch.get("seq_lens"), + explore=False, + is_training=True, + ) + + # Only additions + policy.last_q = q.clone().detach() + policy.last_q_target = q_target.clone().detach() + + actions = train_batch[SampleBatch.ACTIONS].long() + dones = train_batch[SampleBatch.DONES].float() + rewards = train_batch[SampleBatch.REWARDS] + weights = train_batch[PRIO_WEIGHTS] + + B = state_batches[0].shape[0] + T = q.shape[0] // B + + # Q scores for actions which we know were selected in the given state. 
+ one_hot_selection = F.one_hot(actions, policy.action_space.n) + q_selected = torch.sum( + torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device)) + * one_hot_selection, + 1, + ) + + if config["double_q"]: + best_actions = torch.argmax(q, dim=1) + else: + best_actions = torch.argmax(q_target, dim=1) + + best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n) + q_target_best = torch.sum( + torch.where( + q_target > FLOAT_MIN, + q_target, + torch.tensor(0.0, device=policy.device), + ) + * best_actions_one_hot, + dim=1, + ) + + if config["num_atoms"] > 1: + raise ValueError("Distributional R2D2 not supported yet!") + else: + q_target_best_masked_tp1 = (1.0 - dones) * torch.cat( + [q_target_best[1:], torch.tensor([0.0], device=policy.device)] + ) + + if config["use_h_function"]: + h_inv = h_inverse( + q_target_best_masked_tp1, config["h_function_epsilon"] + ) + target = h_function( + rewards + config["gamma"] ** config["n_step"] * h_inv, + config["h_function_epsilon"], + ) + else: + target = ( + rewards + + config["gamma"] ** config["n_step"] + * q_target_best_masked_tp1 + ) + + # Seq-mask all loss-related terms. + seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1] + # Mask away also the burn-in sequence at the beginning. + burn_in = policy.config["burn_in"] + if burn_in > 0 and burn_in < T: + seq_mask[:, :burn_in] = False + + num_valid = torch.sum(seq_mask) + + def reduce_mean_valid(t): + return torch.sum(t[seq_mask]) / num_valid + + # Make sure use the correct time indices: + # Q(t) - [gamma * r + Q^(t+1)] + q_selected = q_selected.reshape([B, T])[:, :-1] + td_error = q_selected - target.reshape([B, T])[:, :-1].detach() + td_error = td_error * seq_mask + weights = weights.reshape([B, T])[:, :-1] + + return weights, td_error, q_selected, reduce_mean_valid + + +def my_r2d2_loss_wt_huber_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + + weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss( + policy, model, _, train_batch + ) + + policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) + policy._td_error = td_error.reshape([-1]) + policy._loss_stats = { + "mean_q": reduce_mean_valid(q_selected), + "min_q": torch.min(q_selected), + "max_q": torch.max(q_selected), + "mean_td_error": reduce_mean_valid(td_error), + } + + return policy._total_loss + + +def my_r2d2_loss_wtout_huber_loss( + policy: Policy, model, _, train_batch: SampleBatch +) -> TensorType: + + weights, td_error, q_selected, reduce_mean_valid = my_r2d2_loss( + policy, model, _, train_batch + ) + + policy._total_loss = reduce_mean_valid(weights * l2_loss(td_error)) + policy._td_error = td_error.reshape([-1]) + policy._loss_stats = { + "mean_q": reduce_mean_valid(q_selected), + "min_q": torch.min(q_selected), + "max_q": torch.max(q_selected), + "mean_td_error": reduce_mean_valid(td_error), + } + + return policy._total_loss + + +def my_build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: + q_stats = r2d2_torch_policy.build_q_stats(policy, batch) + + entropy_avg, _ = log.compute_entropy_from_raw_q_values( + policy, policy.last_q.clone() + ) + q_stats.update( + { + "entropy_avg": entropy_avg, + } + ) + + return q_stats + + +MyR2D2TorchPolicy = R2D2TorchPolicy.with_updates( + optimizer_fn=optimizers.sgd_optimizer_dqn, + loss_fn=my_r2d2_loss_wt_huber_loss, + stats_fn=log.augment_stats_fn_wt_additionnal_logs(my_build_q_stats), + before_init=policy.my_setup_early_mixins, + mixins=[ + TargetNetworkMixin, + ComputeTDErrorMixin, + 
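+ # learning-rate-schedule mixin from marltoolbox.utils.policy, also used by
+ # MyPPOTorchPolicy above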
policy.MyLearningRateSchedule, + ], +) + +MyR2D2TorchPolicyWtMSELoss = MyR2D2TorchPolicy.with_updates( + loss_fn=my_r2d2_loss_wtout_huber_loss, +) + +MyAdamDQNTorchPolicy = MyR2D2TorchPolicy.with_updates( + optimizer_fn=optimizers.adam_optimizer_dqn, +) diff --git a/marltoolbox/algos/exploiters/dual_behavior.py b/marltoolbox/algos/exploiters/dual_behavior.py index 31c8cfd..13e1527 100644 --- a/marltoolbox/algos/exploiters/dual_behavior.py +++ b/marltoolbox/algos/exploiters/dual_behavior.py @@ -4,6 +4,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock from marltoolbox.algos import hierarchical from marltoolbox.utils import postprocessing @@ -34,6 +35,7 @@ def __init__(self, observation_space, action_space, config, **kwargs): super().__init__(observation_space, action_space, config, **kwargs) self.active_algo_idx = self.SELFISH_POLICY_IDX + @with_lock @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): return self._train_dual_policies(samples) diff --git a/marltoolbox/algos/exploiters/level_1.py b/marltoolbox/algos/exploiters/level_1.py index cb2807a..b5bd8d8 100644 --- a/marltoolbox/algos/exploiters/level_1.py +++ b/marltoolbox/algos/exploiters/level_1.py @@ -1,10 +1,11 @@ import logging -from typing import TYPE_CHECKING from ray.rllib import SampleBatch -from ray.rllib.policy import Policy +from ray.rllib.policy import TorchPolicy from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override +from ray.rllib.utils.threading import with_lock +from typing import TYPE_CHECKING from marltoolbox.algos.exploiters import dual_behavior @@ -17,7 +18,7 @@ dual_behavior.DEFAULT_CONFIG, { "lookahead_n_times": int(1), - } + }, ) @@ -34,6 +35,7 @@ class Level1ExploiterTorchPolicy(dual_behavior.DualBehaviorTorchPolicy): """ + LEVEL_1_POLICY_IDX = 2 def __init__(self, observation_space, action_space, config, **kwargs): @@ -50,13 +52,14 @@ def _learn_on_batch(self, samples: SampleBatch): def _train_level_1_policy(self, learner_stats, samples): raise NotImplementedError() - @override(Policy) - def compute_actions(self, *args, **kwargs): - + @with_lock + @override(TorchPolicy) + def _compute_action_helper(self, *args, **kwargs): selfish_action_tuple = self._compute_selfish_action(*args, **kwargs) coop_needed = self._is_cooperation_needed_to_prevent_punishement( - selfish_action_tuple, *args, **kwargs) + selfish_action_tuple, *args, **kwargs + ) self._to_log["coop_needed"] = coop_needed @@ -64,7 +67,7 @@ def compute_actions(self, *args, **kwargs): self._to_log["use_exploiter_cooperating"] = 1.0 self._to_log["use_exploiter_selfish"] = 0.0 self.active_algo_idx = self.COOP_POLICY_IDX - return super().compute_actions(*args, **kwargs) + return super()._compute_action_helper(*args, **kwargs) else: self._to_log["use_exploiter_cooperating"] = 0.0 self._to_log["use_exploiter_selfish"] = 1.0 @@ -74,12 +77,13 @@ def compute_actions(self, *args, **kwargs): def _compute_selfish_action(self, *args, **kwargs): tmp_active_algo_idx = self.active_algo_idx self.active_algo_idx = self.SELFISH_POLICY_IDX - selfish_action_tuple = super().compute_actions(*args, **kwargs) + selfish_action_tuple = super()._compute_action_helper(*args, **kwargs) self.active_algo_idx = tmp_active_algo_idx return selfish_action_tuple def _is_cooperation_needed_to_prevent_punishement( - self, selfish_action_tuple, *args, **kwargs): + self, selfish_action_tuple, 
*args, **kwargs + ): selfish_action = selfish_action_tuple[0] @@ -87,23 +91,26 @@ def _is_cooperation_needed_to_prevent_punishement( is_going_to_punish_at_next_step = False for _ in range(self.lookahead_n_times): - is_going_to_punish_at_next_step = \ + is_going_to_punish_at_next_step = ( self._lookahead_for_opponent_punishing( - selfish_action, - *args, - **kwargs) + selfish_action, *args, **kwargs + ) + ) if is_going_to_punish_at_next_step: break self._to_log["simu_opp_is_punishing"] = is_punishing - self._to_log["simu_opp_is_going_to_punish_at_next_step"] = \ - is_going_to_punish_at_next_step - return (is_punishing or is_going_to_punish_at_next_step) + self._to_log[ + "simu_opp_is_going_to_punish_at_next_step" + ] = is_going_to_punish_at_next_step + return is_punishing or is_going_to_punish_at_next_step def _lookahead_for_opponent_punishing( - self, selfish_action, *args, **kwargs) -> bool: + self, selfish_action, *args, **kwargs + ) -> bool: raise NotImplementedError() + # class ObserveOpponentCallbacks(DefaultCallbacks): # # def on_postprocess_trajectory( diff --git a/marltoolbox/algos/hierarchical.py b/marltoolbox/algos/hierarchical.py index aa6ab30..c7c3880 100644 --- a/marltoolbox/algos/hierarchical.py +++ b/marltoolbox/algos/hierarchical.py @@ -1,13 +1,14 @@ import copy import logging -from typing import List, Union, Optional, Dict, Tuple, Iterable + +from ray.rllib.utils.threading import with_lock +from typing import List, Union, Iterable import torch from ray import rllib -from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.annotations import override from marltoolbox.utils import log @@ -41,9 +42,9 @@ def __init__( updated_config.update(nested_config["config_update"]) if nested_config["Policy_class"] is None: raise ValueError( - f"You must specify the classes for the nested Policies " - f'in config["nested_config"]["Policy_class"] ' - f'current value is {nested_config["Policy_class"]}.' + "You must specify the classes for the nested Policies " + 'in config["nested_config"]["Policy_class"]. ' + f"nested_config: {nested_config}." ) Policy = nested_config["Policy_class"] policy = Policy( @@ -70,6 +71,21 @@ def __init__( self._to_log = {} self._already_printed_warnings = [] + self._merge_all_view_requirements() + + def _merge_all_view_requirements(self): + self.view_requirements = {} + for algo in self.algorithms: + for k, v in algo.view_requirements.items(): + self._add_in_view_requirement_or_check_are_equals(k, v) + + def _add_in_view_requirement_or_check_are_equals(self, k, v): + if k not in self.view_requirements.keys(): + self.view_requirements[k] = v + else: + assert vars(self.view_requirements[k]) == vars( + v + ), f"{vars(self.view_requirements[k])} must equal {vars(v)}" def __getattribute__(self, attr): """ @@ -87,10 +103,8 @@ def __getattribute__(self, attr): f"printing this message." 
) logger.info(msg) - # from warnings import warn - # warn(msg) - return object.__getattribute__( - self.algorithms[self.active_algo_idx], attr + return self.algorithms[self.active_algo_idx].__getattribute__( + attr ) except AttributeError as secondary: raise type(initial)(f"{initial.args} and {secondary.args}") @@ -113,7 +127,6 @@ def to_log(self, value): for algo in self.algorithms: if hasattr(algo, "to_log"): algo.to_log = {} - # setattr(algo, "to_log", {}) self._to_log = value @@ -125,6 +138,14 @@ def model(self): def model(self, value): self.algorithms[self.active_algo_idx].model = value + @property + def exploration(self): + return self.algorithms[self.active_algo_idx].exploration + + @exploration.setter + def exploration(self, value): + self.algorithms[self.active_algo_idx].exploration = value + @property def dist_class(self): return self.algorithms[self.active_algo_idx].dist_class @@ -142,9 +163,11 @@ def global_timestep(self, value): for algo in self.algorithms: algo.global_timestep = value + @override(rllib.policy.Policy) def on_global_var_update(self, global_vars): for algo in self.algorithms: algo.on_global_var_update(global_vars) + super().on_global_var_update(global_vars) @property def update_target(self): @@ -155,12 +178,14 @@ def nested_update_target(): return nested_update_target + @override(rllib.policy.TorchPolicy) def get_weights(self): return { self.nested_key(i): algo.get_weights() for i, algo in enumerate(self.algorithms) } + @override(rllib.policy.TorchPolicy) def set_weights(self, weights): for i, algo in enumerate(self.algorithms): algo.set_weights(weights[self.nested_key(i)]) @@ -168,27 +193,22 @@ def set_weights(self, weights): def nested_key(self, i): return f"nested_{i}" - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - + @with_lock + @override(rllib.policy.TorchPolicy) + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): actions, state_out, extra_fetches = self.algorithms[ self.active_algo_idx - ].compute_actions(obs_batch) + ]._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep + ) return actions, state_out, extra_fetches + @with_lock def learn_on_batch(self, samples: SampleBatch): - self._update_lr_in_all_optimizers() + stats = self._learn_on_batch(samples) self._log_learning_rates() return stats @@ -217,22 +237,12 @@ def optimizer( The local PyTorch optimizer(s) to use for this Policy. 
""" - # TODO find a clean solution to update the LR when using a LearningRateSchedule - # TODO this will probably no more be needed when moving to RLLib>1.0.0 - for algo in self.algorithms: - if hasattr(algo, "cur_lr"): - for opt in algo._optimizers: - for p in opt.param_groups: - p["lr"] = algo.cur_lr - all_optimizers = [] for algo_n, algo in enumerate(self.algorithms): opt = algo.optimizer() all_optimizers.extend(opt) return all_optimizers - # TODO Move this in helper functions - def postprocess_trajectory( self, sample_batch, other_agent_batches=None, episode=None ): @@ -245,14 +255,13 @@ def _log_learning_rates(self): Use to log LR to check that they are really updated as configured. """ for algo_idx, algo in enumerate(self.algorithms): - self.to_log[f"algo{algo_idx}"] = log._log_learning_rate(algo) + self.to_log[f"algo{algo_idx}"] = log.log_learning_rate(algo) def set_state(self, state: object) -> None: state = state.copy() # shallow copy # Set optimizer vars first. optimizer_vars = state.pop("_optimizer_variables", None) if optimizer_vars: - print("self", self) assert len(optimizer_vars) == len( self._optimizers ), f"{len(optimizer_vars)} {len(self._optimizers)}" @@ -260,3 +269,12 @@ def set_state(self, state: object) -> None: o.load_state_dict(s) # Then the Policy's (NN) weights. super().set_state(state) + + def _get_dummy_batch_from_view_requirements( + self, batch_size: int = 1 + ) -> SampleBatch: + dummy_sample_batch = super()._get_dummy_batch_from_view_requirements( + batch_size + ) + dummy_sample_batch[dummy_sample_batch.DONES][-1] = True + return dummy_sample_batch diff --git a/marltoolbox/algos/lola/__init__.py b/marltoolbox/algos/lola/__init__.py index 7e97b24..553ba18 100644 --- a/marltoolbox/algos/lola/__init__.py +++ b/marltoolbox/algos/lola/__init__.py @@ -1,3 +1,3 @@ from marltoolbox.algos.lola.train_cg_tune_class_API import LOLAPGCG -from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExact +from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice diff --git a/marltoolbox/algos/lola/networks.py b/marltoolbox/algos/lola/networks.py index dd01eaf..d17c8f8 100644 --- a/marltoolbox/algos/lola/networks.py +++ b/marltoolbox/algos/lola/networks.py @@ -100,7 +100,7 @@ def __init__( self.sample_reward_bis = tf.placeholder( shape=[None, trace_length], dtype=tf.float32, - name="sample_reward", + name="sample_reward_bis", ) self.j = tf.placeholder(shape=[None], dtype=tf.float32, name="j") @@ -239,6 +239,7 @@ def __init__( shape=[None, 1], dtype=tf.float32, name="next_value" ) self.next_v = tf.matmul(self.next_value, self.gamma_array_inverse) + if use_critic: self.target = self.sample_reward_bis + self.next_v else: @@ -325,6 +326,7 @@ def __init__( entropy_coeff * self.entropy + weigth_decay * self.weigths_norm ) * self.loss_multiplier + self.updateModel = self.trainer.minimize( total_loss, var_list=self.value_params ) diff --git a/marltoolbox/algos/lola/train_cg_tune_class_API.py b/marltoolbox/algos/lola/train_cg_tune_class_API.py index 7a5f903..ac93645 100644 --- a/marltoolbox/algos/lola/train_cg_tune_class_API.py +++ b/marltoolbox/algos/lola/train_cg_tune_class_API.py @@ -21,7 +21,7 @@ from marltoolbox.algos.lola.networks import Pnetwork from marltoolbox.algos.lola.utils import get_monte_carlo, make_cube from marltoolbox.envs.vectorized_coin_game import AsymVectorizedCoinGame -from marltoolbox.utils.full_epi_logger import FullEpisodeLogger +from 
marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger PLOT_KEYS = [ "player_1_loss", @@ -141,7 +141,7 @@ def _init_lola( grid_size, gamma, hidden, - bs_mul, + global_lr_divider, lr, env_config, mem_efficient=True, @@ -168,6 +168,14 @@ def _init_lola( use_destabilizer=False, adding_scaled_weights=False, always_train_PG=False, + use_normalized_rewards=False, + use_centered_reward=False, + use_rolling_avg_actor_grad=False, + process_reward_after_rolling=False, + only_process_reward=False, + use_rolling_avg_reward=False, + reward_processing_bais=False, + center_and_normalize_with_rolling_avg=False, **kwargs, ): @@ -194,7 +202,6 @@ def _init_lola( self.grid_size = grid_size self.gamma = gamma self.hidden = hidden - self.bs_mul = bs_mul self.lr = lr self.mem_efficient = mem_efficient self.asymmetry = env_class == AsymVectorizedCoinGame @@ -232,10 +239,47 @@ def _init_lola( assert self.adding_scaled_weights > 0.0 self.always_train_PG = always_train_PG self.last_term_to_use = 0.0 + self.use_normalized_rewards = use_normalized_rewards + self.use_centered_reward = use_centered_reward + self.use_rolling_avg_actor_grad = use_rolling_avg_actor_grad + self.process_reward_after_rolling = process_reward_after_rolling + self.only_process_reward = only_process_reward + self.use_rolling_avg_reward = use_rolling_avg_reward + self.reward_processing_bais = reward_processing_bais + self.center_and_normalize_with_rolling_avg = ( + center_and_normalize_with_rolling_avg + ) + self._reward_rolling_avg = {} + self._reward_rolling_avg_length = 100 + + if self.use_rolling_avg_reward and self.use_rolling_avg_actor_grad: + self.use_rolling_avg_actor_grad = False + print("self.use_normalized_rewards", self.use_normalized_rewards) + print("self.use_centered_reward", self.use_centered_reward) + print( + "self.use_rolling_avg_actor_grad", self.use_rolling_avg_actor_grad + ) + + global_lr_divider_mul = 1.0 + if self.use_normalized_rewards: + global_lr_divider_mul *= 2.0 + if self.use_centered_reward: + global_lr_divider_mul *= 1.0 + if self.use_rolling_avg_actor_grad: + global_lr_divider_mul *= 1 / 20.0 + if self.use_rolling_avg_reward: + global_lr_divider_mul *= 1 / 10000.0 + + self.global_lr_divider = global_lr_divider * global_lr_divider_mul + self.global_lr_divider_original = ( + global_lr_divider * global_lr_divider_mul + ) self.obs_batch = deque(maxlen=self.batch_size) self.full_episode_logger = FullEpisodeLogger( - logdir=self._logdir, log_interval=100, log_ful_epi_one_hot_obs=True + logdir=self._logdir, + log_interval=100, + convert_one_hot_obs_to_idx=True, ) # Setting the training parameters @@ -249,6 +293,10 @@ def _init_lola( self.max_epLength = ( trace_length + 1 ) # The max allowed length of our episode. 
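+        # Per-agent rolling average (of actor-gradient magnitudes or rewards,
+        # depending on the rolling-average options above); it is read by
+        # _update_bs_mul_wt_rolling_avg() to rescale the effective lr divider.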
+ self.actor_rolling_avg = [ + 10.0 / global_lr_divider_mul + for agent in range(self.total_n_agents) + ] graph = tf.Graph() @@ -301,69 +349,23 @@ def _init_lola( use_critic=use_critic, ) ) - # Clones of the opponents - if opp_model: - self.mainPN_clone = [] - for agent in range(self.total_n_agents): - self.mainPN_clone.append( - Pnetwork( - f"clone_{agent}", - self.h_size[agent], - agent, - self.env, - trace_length=trace_length, - batch_size=batch_size, - changed_config=changed_config, - ac_lr=ac_lr, - use_MAE=use_MAE, - clip_loss_norm=clip_loss_norm, - sess=self.sess, - entropy_coeff=entropy_coeff, - use_critic=use_critic, - ) - ) if not mem_efficient: self.cube, self.cube_ops = make_cube(trace_length) else: self.cube, self.cube_ops = None, None - if not opp_model: - corrections_func( - self.mainPN, - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - else: - corrections_func( - [self.mainPN[0], self.mainPN_clone[1]], - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - corrections_func( - [self.mainPN[1], self.mainPN_clone[0]], - batch_size, - trace_length, - corrections, - self.cube, - clip_lola_update_norm=clip_lola_update_norm, - lola_correction_multiplier=self.lola_correction_multiplier, - clip_lola_correction_norm=clip_lola_correction_norm, - clip_lola_actor_norm=clip_lola_actor_norm, - ) - clone_update(self.mainPN_clone) + corrections_func( + self.mainPN, + batch_size, + trace_length, + corrections, + self.cube, + clip_lola_update_norm=clip_lola_update_norm, + lola_correction_multiplier=self.lola_correction_multiplier, + clip_lola_correction_norm=clip_lola_correction_norm, + clip_lola_actor_norm=clip_lola_actor_norm, + ) if self.use_PG_exploiter: simple_actor_training_func( @@ -553,19 +555,24 @@ def step(self): pow_series = np.arange(self.trace_length) discount = np.array([pow(self.gamma, item) for item in pow_series]) + reward_mean_p0 = np.array(trainBatch0[2]).mean() + reward_mean_p1 = np.array(trainBatch1[2]).mean() + reward_std_p0 = np.array(trainBatch0[2]).std() + reward_std_p1 = np.array(trainBatch1[2]).std() + ( sample_return0, sample_reward0, sample_reward0_bis, ) = self.compute_centered_discounted_r( - rewards=trainBatch0[2], discount=discount + rewards=trainBatch0[2], discount=discount, player_key="player_0" ) ( sample_return1, sample_reward1, sample_reward1_bis, ) = self.compute_centered_discounted_r( - rewards=trainBatch1[2], discount=discount + rewards=trainBatch1[2], discount=discount, player_key="player_1" ) state_input0 = np.concatenate(trainBatch0[0], axis=0) @@ -607,34 +614,6 @@ def step(self): }, ) - if self.opp_model: - ## update local clones - update_clone = [ - self.mainPN_clone[0].update, - self.mainPN_clone[1].update, - ] - feed_dict = { - self.mainPN_clone[0].state_input: state_input1, - self.mainPN_clone[0].actions: actions1, - self.mainPN_clone[0].sample_return: sample_return1, - self.mainPN_clone[0].sample_reward: sample_reward1, - self.mainPN_clone[1].state_input: state_input0, - self.mainPN_clone[1].actions: actions0, - self.mainPN_clone[1].sample_return: sample_return0, - self.mainPN_clone[1].sample_reward: 
sample_reward0, - self.mainPN_clone[0].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[1].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[0].is_training: True, - self.mainPN_clone[1].is_training: True, - } - num_loops = 50 if self.timestep == 0 else 1 - for _ in range(num_loops): - self.sess.run(update_clone, feed_dict=feed_dict) - if self.lr_decay: lr_decay = (self.num_episodes - self.timestep) / self.num_episodes else: @@ -667,25 +646,6 @@ def step(self): self.mainPN[0].is_training: True, self.mainPN[1].is_training: True, } - if self.opp_model: - feed_dict.update( - { - self.mainPN_clone[0].state_input: state_input1, - self.mainPN_clone[0].actions: actions1, - self.mainPN_clone[0].sample_return: sample_return1, - self.mainPN_clone[0].sample_reward: sample_reward1, - self.mainPN_clone[1].state_input: state_input0, - self.mainPN_clone[1].actions: actions0, - self.mainPN_clone[1].sample_return: sample_return0, - self.mainPN_clone[1].sample_reward: sample_reward0, - self.mainPN_clone[0].gamma_array: np.reshape( - discount, [1, -1] - ), - self.mainPN_clone[1].gamma_array: np.reshape( - discount, [1, -1] - ), - } - ) lola_training_list = [ self.mainPN[0].value, @@ -728,103 +688,46 @@ def step(self): ) ( # Player_red - values, + values_p1, updateModel_1, update1, player_1_value, player_1_target, player_1_loss, - entropy_p_0, - v_0_log, - actor_target_error_0, - actor_loss_0, - parameters_norm_0, - second_order0, - v_0_grad_theta_0, - second_order0_sum, - actor_grad_sum_0, + entropy_p1, + v_0_log_p1, + actor_target_error_p1, + actor_loss_p1, + parameters_norm_p1, + second_order_p1, + v_0_grad_theta_p1, + second_order_sum_p1, + actor_grad_sum_p1, v_0_grad_01, - multiply0, + multiply_p1, # Player_blue - values_1, - updateModel_2, - update2, + values_p2, + updateModel_p2, + update_p2, player_2_value, player_2_target, player_2_loss, - entropy_p_1, - v_1_log, - actor_target_error_1, - actor_loss_1, - parameters_norm_1, - second_order1, - v_1_grad_theta_1, - second_order1_sum, - actor_grad_sum_1, + entropy_p2, + v_1_log_p2, + actor_target_error_p2, + actor_loss_p2, + parameters_norm_p2, + second_order_p2, + v_1_grad_theta_p2, + second_order_sum_p2, + actor_grad_sum_p2, ) = self.sess.run(lola_training_list, feed_dict=feed_dict) - if self.warmup: - update1 = update1 * self.warmup_step_n / self.warmup - update2 = update2 * self.warmup_step_n / self.warmup - if self.lr_decay: - update1 = update1 * lr_decay - update2 = update2 * lr_decay - - update1_sum = sum(update1) / self.bs_mul - update2_sum = sum(update2) / self.bs_mul - - update( - self.mainPN, - self.lr, - update1 / self.bs_mul, - update2 / self.bs_mul, - use_actions_from_exploiter, + update1_sum, update2_sum = self._update_players( + update1, update_p2, lr_decay, use_actions_from_exploiter ) if self.use_PG_exploiter: - # Update policy networks - feed_dict = { - self.mainPN[0].state_input: state_input0, - self.mainPN[0].sample_return: sample_return0, - self.mainPN[0].actions: actions0, - self.mainPN[2].state_input: state_input1, - self.mainPN[2].sample_return: sample_return1, - self.mainPN[2].actions: actions1, - self.mainPN[0].sample_reward: sample_reward0, - self.mainPN[2].sample_reward: sample_reward1, - self.mainPN[0].sample_reward_bis: sample_reward0_bis, - self.mainPN[2].sample_reward_bis: sample_reward1_bis, - self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]), - self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]), - self.mainPN[0].next_value: value_0_next, - self.mainPN[2].next_value: 
expl_value_next, - self.mainPN[0].gamma_array_inverse: np.reshape( - self.discount_array, [1, -1] - ), - self.mainPN[2].gamma_array_inverse: np.reshape( - self.discount_array, [1, -1] - ), - self.mainPN[0].loss_multiplier: [lr_decay], - self.mainPN[2].loss_multiplier: [lr_decay], - self.mainPN[0].is_training: True, - self.mainPN[2].is_training: True, - } - - lola_training_list = [ - self.mainPN[2].value, - self.mainPN[2].updateModel, - self.mainPN[2].delta, - self.mainPN[2].value, - self.mainPN[2].target, - self.mainPN[2].loss, - self.mainPN[2].entropy, - self.mainPN[2].v_0_log, - self.mainPN[2].actor_target_error, - self.mainPN[2].actor_loss, - self.mainPN[2].weigths_norm, - self.mainPN[2].grad, - self.mainPN[2].grad_sum, - ] ( pg_expl_values, pg_expl_updateModel, @@ -839,32 +742,405 @@ def step(self): pg_expl_parameters_norm, pg_expl_v_grad_theta, pg_expl_actor_grad_sum, - ) = self.sess.run(lola_training_list, feed_dict=feed_dict) + pg_expl_update_sum, + ) = self._update_pg_exploiter( + state_input0, + sample_return0, + actions0, + state_input1, + sample_return1, + actions1, + sample_reward0, + sample_reward1, + sample_reward0_bis, + sample_reward1_bis, + discount, + value_0_next, + expl_value_next, + lr_decay, + ) + else: + pg_expl_player_loss = None + pg_expl_v_log = None + pg_expl_entropy = None + pg_expl_actor_loss = None + pg_expl_parameters_norm = None + pg_expl_update_sum = None + pg_expl_actor_grad_sum = None + + self._update_actor_rolling_avg( + actor_grad_sum_p1, + actor_grad_sum_p2, + trainBatch0[2], + trainBatch1[2], + ) - if self.warmup: - pg_expl_update = ( - pg_expl_update * self.warmup_step_n / self.warmup + print("update params") + to_report = self._prepare_logs( + to_report, + last_info, + player_1_loss, + player_2_loss, + v_0_log_p1, + v_1_log_p2, + entropy_p1, + entropy_p2, + actor_loss_p1, + actor_loss_p2, + parameters_norm_p1, + parameters_norm_p2, + second_order_sum_p1, + second_order_sum_p2, + update1_sum, + update2_sum, + actor_grad_sum_p1, + actor_grad_sum_p2, + pg_expl_player_loss, + pg_expl_v_log, + pg_expl_entropy, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_update_sum, + pg_expl_actor_grad_sum, + reward_mean_p0, + reward_mean_p1, + reward_std_p0, + reward_std_p1, + sample_return0, + sample_return1, + sample_reward0, + sample_reward1, + values_p1, + values_p2, + player_1_target, + player_2_target, + ) + return to_report + + def compute_centered_discounted_r(self, rewards, discount, player_key=""): + if not self.only_process_reward: + rewards = self._process_rewards( + rewards, preprocess=True, key=player_key + "_rewards" + ) + + sample_return = np.reshape( + get_monte_carlo( + rewards, self.y, self.trace_length, self.batch_size + ), + [self.batch_size, -1], + ) + + if self.only_process_reward: + rewards = self._process_rewards( + rewards, preprocess=True, key=player_key + "_rewards" + ) + + if self.correction_reward_baseline_per_step: + sample_reward = discount * np.reshape( + rewards - np.mean(np.array(rewards), axis=0), + [-1, self.trace_length], + ) + else: + sample_reward = discount * np.reshape( + rewards - np.mean(rewards), [-1, self.trace_length] + ) + sample_reward_bis = discount * np.reshape( + rewards, [-1, self.trace_length] + ) + + if not self.only_process_reward: + sample_return = self._process_rewards( + sample_return, + preprocess=False, + key=player_key + "_sample_return", + ) + sample_reward = self._process_rewards( + sample_reward, preprocess=False, key=player_key + "_sample_reward" + ) + sample_reward_bis = 
self._process_rewards( + sample_reward_bis, + preprocess=False, + key=player_key + "_sample_reward_bis", + ) + return sample_return, sample_reward, sample_reward_bis + + def _process_rewards(self, rewards, preprocess, key: str): + postprocess = not preprocess + preprocessing = preprocess and not self.process_reward_after_rolling + postprocessing = postprocess and self.process_reward_after_rolling + if preprocessing or postprocessing: + if self.use_normalized_rewards: + return self._normalize_rewards(rewards, key) + elif self.use_centered_reward: + return self._center_reward(rewards, key) + return rewards + + def _normalize_rewards(self, rewards, key: str): + r_array = np.array(rewards) + if self.center_and_normalize_with_rolling_avg: + mean_key = key + "_mean" + std_key = key + "_std" + self._update_reward_rolling_avg(r_array.mean(), mean_key) + self._update_reward_rolling_avg(r_array.std(), std_key) + rewards = ( + r_array - self._get_reward_rolling_avg(mean_key) + ) / self._get_reward_rolling_avg(std_key) + else: + rewards = (r_array - r_array.mean()) / r_array.std() + if self.reward_processing_bais: + rewards += self.reward_processing_bais + return rewards.tolist() + + def _center_reward(self, rewards, key: str): + r_array = np.array(rewards) + if self.center_and_normalize_with_rolling_avg: + self._update_reward_rolling_avg(r_array.mean(), key) + rewards = r_array - self._get_reward_rolling_avg(key) + else: + rewards = r_array - r_array.mean() + if self.reward_processing_bais: + rewards += self.reward_processing_bais + return rewards.tolist() + + def _update_reward_rolling_avg(self, value, key): + if key in self._reward_rolling_avg: + self._reward_rolling_avg[key] = value + self._reward_rolling_avg[ + key + ] * (1 - (1 / self._reward_rolling_avg_length)) + else: + self._reward_rolling_avg[key] = ( + value * self._reward_rolling_avg_length + ) + + def _get_reward_rolling_avg(self, key): + return self._reward_rolling_avg[key] / self._reward_rolling_avg_length + + def _update_bs_mul_wt_rolling_avg(self): + rolling_avg = None + if self.use_rolling_avg_actor_grad: + rolling_avg = self.use_rolling_avg_actor_grad + if self.use_rolling_avg_reward: + rolling_avg = self.use_rolling_avg_reward + + if rolling_avg is not None: + self.global_lr_divider = [ + self.global_lr_divider_original * v / rolling_avg + for v in self.actor_rolling_avg + ] + + def _update_actor_rolling_avg( + self, + actor_grad_sum_0, + actor_grad_sum_1, + reward_player_0, + reward_player_1, + ): + if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward: + roll_avg_p0 = self.actor_rolling_avg[0] + roll_avg_p1 = self.actor_rolling_avg[1] + if self.use_rolling_avg_actor_grad: + roll_avg_p0 = np.abs(actor_grad_sum_0) + roll_avg_p0 * ( + 1 - (1 / self.use_rolling_avg_actor_grad) ) - if self.lr_decay: - pg_expl_update = pg_expl_update * lr_decay + roll_avg_p1 = np.abs(actor_grad_sum_1) + roll_avg_p1 * ( + 1 - (1 / self.use_rolling_avg_actor_grad) + ) + if self.use_rolling_avg_reward: + roll_avg_p0 = np.sum(np.abs(reward_player_0)) + roll_avg_p0 * ( + 1 - (1 / self.use_rolling_avg_reward) + ) + roll_avg_p1 = np.sum(np.abs(reward_player_1)) + roll_avg_p1 * ( + 1 - (1 / self.use_rolling_avg_reward) + ) + self.actor_rolling_avg[0] = roll_avg_p0 + self.actor_rolling_avg[1] = roll_avg_p1 - # pg_expl_update_to_log = pg_expl_update - pg_expl_update_sum = sum(pg_expl_update) / self.bs_mul + def _update_players( + self, update1, update2, lr_decay, use_actions_from_exploiter + ): + if self.warmup: + update1 = update1 * 
self.warmup_step_n / self.warmup + update2 = update2 * self.warmup_step_n / self.warmup + if self.lr_decay: + update1 = update1 * lr_decay + update2 = update2 * lr_decay - update_single( - self.mainPN[2], self.lr, pg_expl_update / self.bs_mul + if self.use_rolling_avg_actor_grad or self.use_rolling_avg_reward: + self._update_bs_mul_wt_rolling_avg() + update1_sum = sum(update1) / self.global_lr_divider[0] + update2_sum = sum(update2) / self.global_lr_divider[1] + update( + self.mainPN, + self.lr, + update1 / self.global_lr_divider[0], + update2 / self.global_lr_divider[1], + use_actions_from_exploiter, ) + else: + update1_sum = sum(update1) / self.global_lr_divider + update2_sum = sum(update2) / self.global_lr_divider + update( + self.mainPN, + self.lr, + update1 / self.global_lr_divider, + update2 / self.global_lr_divider, + use_actions_from_exploiter, + ) + return update1_sum, update2_sum - if self.timestep >= self.start_using_exploiter_at_update_n: - if self.timestep % self.every_n_updates_copy_weights == 0: - copy_weigths( - from_policy=self.mainPN[2], - to_policy=self.mainPN[1], - adding_scaled_weights=self.adding_scaled_weights, - ) + def _update_pg_exploiter( + self, + state_input0, + sample_return0, + actions0, + state_input1, + sample_return1, + actions1, + sample_reward0, + sample_reward1, + sample_reward0_bis, + sample_reward1_bis, + discount, + value_0_next, + expl_value_next, + lr_decay, + ): + # Update policy networks + feed_dict = { + self.mainPN[0].state_input: state_input0, + self.mainPN[0].sample_return: sample_return0, + self.mainPN[0].actions: actions0, + self.mainPN[2].state_input: state_input1, + self.mainPN[2].sample_return: sample_return1, + self.mainPN[2].actions: actions1, + self.mainPN[0].sample_reward: sample_reward0, + self.mainPN[2].sample_reward: sample_reward1, + self.mainPN[0].sample_reward_bis: sample_reward0_bis, + self.mainPN[2].sample_reward_bis: sample_reward1_bis, + self.mainPN[0].gamma_array: np.reshape(discount, [1, -1]), + self.mainPN[2].gamma_array: np.reshape(discount, [1, -1]), + self.mainPN[0].next_value: value_0_next, + self.mainPN[2].next_value: expl_value_next, + self.mainPN[0].gamma_array_inverse: np.reshape( + self.discount_array, [1, -1] + ), + self.mainPN[2].gamma_array_inverse: np.reshape( + self.discount_array, [1, -1] + ), + self.mainPN[0].loss_multiplier: [lr_decay], + self.mainPN[2].loss_multiplier: [lr_decay], + self.mainPN[0].is_training: True, + self.mainPN[2].is_training: True, + } - print("update params") + lola_training_list = [ + self.mainPN[2].value, + self.mainPN[2].updateModel, + self.mainPN[2].delta, + self.mainPN[2].value, + self.mainPN[2].target, + self.mainPN[2].loss, + self.mainPN[2].entropy, + self.mainPN[2].v_0_log, + self.mainPN[2].actor_target_error, + self.mainPN[2].actor_loss, + self.mainPN[2].weigths_norm, + self.mainPN[2].grad, + self.mainPN[2].grad_sum, + ] + ( + pg_expl_values, + pg_expl_updateModel, + pg_expl_update, + pg_expl_player_value, + pg_expl_player_target, + pg_expl_player_loss, + pg_expl_entropy, + pg_expl_v_log, + pg_expl_actor_target_error, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_v_grad_theta, + pg_expl_actor_grad_sum, + ) = self.sess.run(lola_training_list, feed_dict=feed_dict) + + if self.warmup: + pg_expl_update = pg_expl_update * self.warmup_step_n / self.warmup + if self.lr_decay: + pg_expl_update = pg_expl_update * lr_decay + # pg_expl_update_to_log = pg_expl_update + pg_expl_update_sum = sum(pg_expl_update) / self.global_lr_divider + + update_single( + self.mainPN[2], 
self.lr, pg_expl_update / self.global_lr_divider + ) + + if self.timestep >= self.start_using_exploiter_at_update_n: + if self.timestep % self.every_n_updates_copy_weights == 0: + copy_weigths( + from_policy=self.mainPN[2], + to_policy=self.mainPN[1], + adding_scaled_weights=self.adding_scaled_weights, + ) + + return ( + pg_expl_values, + pg_expl_updateModel, + pg_expl_update, + pg_expl_player_value, + pg_expl_player_target, + pg_expl_player_loss, + pg_expl_entropy, + pg_expl_v_log, + pg_expl_actor_target_error, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_v_grad_theta, + pg_expl_actor_grad_sum, + pg_expl_update_sum, + ) + + def _prepare_logs( + self, + to_report, + last_info, + player_1_loss, + player_2_loss, + v_0_log, + v_1_log, + entropy_p_0, + entropy_p_1, + actor_loss_0, + actor_loss_1, + parameters_norm_0, + parameters_norm_1, + second_order0_sum, + second_order1_sum, + update1_sum, + update2_sum, + actor_grad_sum_0, + actor_grad_sum_1, + pg_expl_player_loss, + pg_expl_v_log, + pg_expl_entropy, + pg_expl_actor_loss, + pg_expl_parameters_norm, + pg_expl_update_sum, + pg_expl_actor_grad_sum, + reward_mean_p0, + reward_mean_p1, + reward_std_p0, + reward_std_p1, + sample_return0, + sample_return1, + sample_reward0, + sample_reward1, + values, + values_1, + player_1_target, + player_2_target, + ): rlog = np.sum(self.rList[-self.summary_len :], 0) to_plot = {} @@ -891,6 +1167,10 @@ def step(self): last_info.pop("available_actions", None) training_info = { + "reward_mean_p0": reward_mean_p0, + "reward_mean_p1": reward_mean_p1, + "reward_std_p0": reward_std_p0, + "reward_std_p1": reward_std_p1, "player_1_loss": player_1_loss, "player_2_loss": player_2_loss, "v_0_log": v_0_log, @@ -911,16 +1191,18 @@ def step(self): / self.num_episodes, } # Logging distribution (can be a speed bottleneck) - # training_info.update({ - # "sample_return0": sample_return0, - # "sample_return1": sample_return1, - # "sample_reward0": sample_reward0, - # "sample_reward1": sample_reward1, - # "player_1_values": values, - # "player_2_values": values_1, - # "player_1_target": player_1_target, - # "player_2_target": player_2_target, - # }) + # training_info.update( + # { + # "sample_return0": sample_return0, + # "sample_return1": sample_return1, + # "sample_reward0": sample_reward0, + # "sample_reward1": sample_reward1, + # "player_1_values": values, + # "player_2_values": values_1, + # "player_1_target": player_1_target, + # "player_2_target": player_2_target, + # } + # ) # self.update1_list.clear() # self.update2_list.clear() @@ -952,30 +1234,14 @@ def step(self): + to_report.get("player_red_pick_own_color", 0.0) ) + if self.use_rolling_avg_actor_grad: + to_report["rolling_avg_actor_grad_p0"] = self.actor_rolling_avg[0] + to_report["rolling_avg_actor_grad_p1"] = self.actor_rolling_avg[1] + if self.use_rolling_avg_reward: + to_report["rolling_avg_reward_p0"] = self.actor_rolling_avg[0] + to_report["rolling_avg_reward_p1"] = self.actor_rolling_avg[1] return to_report - def compute_centered_discounted_r(self, rewards, discount): - sample_return = np.reshape( - get_monte_carlo( - rewards, self.y, self.trace_length, self.batch_size - ), - [self.batch_size, -1], - ) - - if self.correction_reward_baseline_per_step: - sample_reward = discount * np.reshape( - rewards - np.mean(np.array(rewards), axis=0), - [-1, self.trace_length], - ) - else: - sample_reward = discount * np.reshape( - rewards - np.mean(rewards), [-1, self.trace_length] - ) - sample_reward_bis = discount * np.reshape( - rewards, [-1, 
self.trace_length] - ) - return sample_return, sample_reward, sample_reward_bis - def _log_one_step_in_full_episode(self, s, r, actions, obs, info): self.full_episode_logger.on_episode_step( step_data={ diff --git a/marltoolbox/algos/lola/train_exact_tune_class_API.py b/marltoolbox/algos/lola/train_exact_tune_class_API.py index ef844ab..472c00a 100644 --- a/marltoolbox/algos/lola/train_exact_tune_class_API.py +++ b/marltoolbox/algos/lola/train_exact_tune_class_API.py @@ -11,7 +11,7 @@ import json import os import random - +import torch import numpy as np import tensorflow as tf import tensorflow.contrib.layers as layers @@ -29,22 +29,35 @@ class Qnetwork: Q-network that is either a look-up table or an MLP with 1 hidden layer. """ - def __init__(self, myScope, num_hidden, sess, simple_net=True): + def __init__( + self, myScope, num_hidden, sess, simple_net=True, n_actions=2, std=1.0 + ): with tf.variable_scope(myScope): - self.input_place = tf.placeholder(shape=[5], dtype=tf.int32) + # self.input_place = tf.placeholder(shape=[5], dtype=tf.int32) + # if simple_net: + # self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=3.0)) + # else: + # act = tf.nn.tanh( + # layers.fully_connected( + # tf.one_hot(self.input_place, 5, dtype=tf.float32), + # num_outputs=num_hidden, + # activation_fn=None, + # ) + # ) + # self.p_act = layers.fully_connected( + # act, num_outputs=1, activation_fn=None + # ) + self.input_place = tf.placeholder( + shape=[n_actions ** 2 + 1], dtype=tf.int32 + ) if simple_net: - self.p_act = tf.Variable(tf.random_normal([5, 1], stddev=1.0)) - else: - act = tf.nn.tanh( - layers.fully_connected( - tf.one_hot(self.input_place, 5, dtype=tf.float32), - num_outputs=num_hidden, - activation_fn=None, + self.p_act = tf.Variable( + tf.random_normal( + [n_actions ** 2 + 1, n_actions], stddev=std ) ) - self.p_act = layers.fully_connected( - act, num_outputs=1, activation_fn=None - ) + else: + raise ValueError() self.parameters = [] for i in tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=myScope @@ -63,51 +76,135 @@ def update(mainQN, lr, final_delta_1_v, final_delta_2_v): ) -def corrections_func(mainQN, corrections, gamma, pseudo, reg): +def corrections_func(mainQN, corrections, gamma, pseudo, reg, n_actions=2): + print_opts = [] + mainQN[0].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32) mainQN[1].lr_correction = tf.placeholder(shape=[1], dtype=tf.float32) theta_1_all = mainQN[0].p_act theta_2_all = mainQN[1].p_act - theta_1 = tf.slice(theta_1_all, [0, 0], [4, 1]) - theta_2 = tf.slice(theta_2_all, [0, 0], [4, 1]) - - theta_1_0 = tf.slice(theta_1_all, [4, 0], [1, 1]) - theta_2_0 = tf.slice(theta_2_all, [4, 0], [1, 1]) - p_1 = tf.nn.sigmoid(theta_1) - p_2 = tf.nn.sigmoid(theta_2) - mainQN[0].policy = tf.nn.sigmoid(theta_1_all) - mainQN[1].policy = tf.nn.sigmoid(theta_2_all) - - p_1_0 = tf.nn.sigmoid(theta_1_0) - p_2_0 = tf.nn.sigmoid(theta_2_0) + # Using sigmoid + normalize to keep values similar to the official + # implementation + pi_player_1_for_all_states = tf.nn.sigmoid(theta_1_all) + pi_player_2_for_all_states = tf.nn.sigmoid(theta_2_all) + print_opts.append( + tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states) + ) + print_opts.append( + tf.print("pi_player_2_for_all_states", pi_player_2_for_all_states) + ) + sum_1 = tf.reduce_sum(pi_player_1_for_all_states, axis=1) + sum_1 = tf.stack([sum_1 for _ in range(n_actions)], axis=1) + sum_2 = tf.reduce_sum(pi_player_2_for_all_states, axis=1) + sum_2 = tf.stack([sum_2 for _ in range(n_actions)], 
axis=1) + pi_player_1_for_all_states = pi_player_1_for_all_states / sum_1 + pi_player_2_for_all_states = pi_player_2_for_all_states / sum_2 + print_opts.append( + tf.print("pi_player_1_for_all_states", pi_player_1_for_all_states) + ) + mainQN[0].policy = pi_player_1_for_all_states + mainQN[1].policy = pi_player_2_for_all_states + n_states = int(pi_player_1_for_all_states.shape[0]) + print_opts.append(tf.print("mainQN[0].policy", mainQN[0].policy)) + print_opts.append(tf.print("mainQN[1].policy", mainQN[1].policy)) + pi_player_1_for_states_in_game = tf.slice( + pi_player_1_for_all_states, + [0, 0], + [n_states - 1, n_actions], + ) + pi_player_2_for_states_in_game = tf.slice( + pi_player_2_for_all_states, + [0, 0], + [n_states - 1, n_actions], + ) + pi_player_1_for_initial_state = tf.slice( + pi_player_1_for_all_states, + [n_states - 1, 0], + [1, n_actions], + ) + pi_player_2_for_initial_state = tf.slice( + pi_player_2_for_all_states, + [n_states - 1, 0], + [1, n_actions], + ) + pi_player_1_for_initial_state = tf.transpose(pi_player_1_for_initial_state) + pi_player_2_for_initial_state = tf.transpose(pi_player_2_for_initial_state) + print_opts.append( + tf.print( + "pi_player_1_for_initial_state", pi_player_1_for_initial_state + ) + ) - p_1_0_v = tf.concat([p_1_0, (1 - p_1_0)], 0) - p_2_0_v = tf.concat([p_2_0, (1 - p_2_0)], 0) + s_0 = tf.reshape( + tf.matmul( + pi_player_1_for_initial_state, + tf.transpose(pi_player_2_for_initial_state), + ), + [-1, 1], + ) - s_0 = tf.reshape(tf.matmul(p_1_0_v, tf.transpose(p_2_0_v)), [-1, 1]) + pi_p1 = tf.reshape(pi_player_1_for_states_in_game, [n_states - 1, -1, 1]) + pi_p2 = tf.reshape(pi_player_2_for_states_in_game, [n_states - 1, -1, 1]) - # CC, CD, DC, DD + all_actions_proba_pairs = [] + for action_p1 in range(n_actions): + for action_p2 in range(n_actions): + all_actions_proba_pairs.append( + tf.multiply(pi_p1[:, action_p1], pi_p2[:, action_p2]) + ) P = tf.concat( - [ - tf.multiply(p_1, p_2), - tf.multiply(p_1, 1 - p_2), - tf.multiply(1 - p_1, p_2), - tf.multiply(1 - p_1, 1 - p_2), - ], + all_actions_proba_pairs, 1, ) - R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32) - R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32) - - I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma + # if n_actions == 2: + # # CC, CD, DC, DD + # + # P = tf.concat( + # [ + # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]), + # ], + # 1, + # ) + # elif n_actions == 3: + # # CC, CD, CN, DC, DD, DN, NC, ND, NN + # P = tf.concat( + # [ + # tf.multiply(pi_p1[:, 0], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 0], pi_p2[:, 2]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 1], pi_p2[:, 2]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 0]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 1]), + # tf.multiply(pi_p1[:, 2], pi_p2[:, 2]), + # ], + # 1, + # ) + # else: + # raise ValueError(f"n_actions {n_actions}") + # R_1 = tf.placeholder(shape=[4, 1], dtype=tf.float32) + # R_2 = tf.placeholder(shape=[4, 1], dtype=tf.float32) + R_1 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32) + R_2 = tf.placeholder(shape=[n_actions ** 2, 1], dtype=tf.float32) + + # I_m_P = tf.diag([1.0, 1.0, 1.0, 1.0]) - P * gamma + I_m_P = tf.diag([1.0] * (n_actions ** 2)) - P * gamma v_0 = tf.matmul( tf.matmul(tf.matrix_inverse(I_m_P), R_1), s_0, transpose_a=True ) v_1 = tf.matmul( 
tf.matmul(tf.matrix_inverse(I_m_P), R_2), s_0, transpose_a=True ) + print_opts.append(tf.print("s_0", s_0)) + print_opts.append(tf.print("I_m_P", I_m_P)) + print_opts.append(tf.print("R_1", R_1)) + print_opts.append(tf.print("v_0", v_0)) if reg > 0: for indx, _ in enumerate(mainQN[0].parameters): v_0 -= reg * tf.reduce_sum( @@ -145,6 +242,7 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg): tf.reshape(v_0_grad_theta_0_wrong, [param_len, 1]), ) + # with tf.control_dependencies(print_opts): second_order0 = flatgrad(multiply0, mainQN[0].parameters) second_order1 = flatgrad(multiply1, mainQN[1].parameters) @@ -158,11 +256,11 @@ def corrections_func(mainQN, corrections, gamma, pseudo, reg): mainQN[1].delta += tf.multiply(second_order1, mainQN[1].lr_correction) -class LOLAExact(tune.Trainable): +class LOLAExactTrainer(tune.Trainable): def _init_lola( self, - env_name, *, + env_name="IteratedAsymBoS", num_episodes=50, trace_length=200, simple_net=True, @@ -175,10 +273,12 @@ def _init_lola( gamma=0.96, with_linear_LR_decay_to_zero=False, clip_update=None, + re_init_every_n_epi=1, + Q_net_std=1.0, **kwargs, ): - print("args not used:", kwargs) + # print("args not used:", kwargs) self.num_episodes = num_episodes self.trace_length = trace_length @@ -192,6 +292,8 @@ def _init_lola( self.gamma = gamma self.with_linear_LR_decay_to_zero = with_linear_LR_decay_to_zero self.clip_update = clip_update + self.re_init_every_n_epi = re_init_every_n_epi + self.Q_net_std = Q_net_std graph = tf.Graph() @@ -202,11 +304,19 @@ def _init_lola( if env_name == "IPD": self.payout_mat_1 = np.array([[-1.0, 0.0], [-3.0, -2.0]]) self.payout_mat_2 = self.payout_mat_1.T - elif env_name == "AsymmetricIteratedBoS": + elif env_name == "IteratedAsymBoS": self.payout_mat_1 = np.array([[+4.0, 0.0], [0.0, +2.0]]) self.payout_mat_2 = np.array([[+1.0, 0.0], [0.0, +2.0]]) + elif "custom_payoff_matrix" in kwargs.keys(): + custom_matrix = kwargs["custom_payoff_matrix"] + self.payout_mat_1 = custom_matrix[:, :, 0] + self.payout_mat_2 = custom_matrix[:, :, 1] else: raise ValueError(f"exp_name: {env_name}") + self.n_actions = int(self.payout_mat_1.shape[0]) + assert self.n_actions == self.payout_mat_1.shape[1] + self.policy1 = [0.0] * self.n_actions + self.policy2 = [0.0] * self.n_actions # Sanity @@ -219,6 +329,8 @@ def _init_lola( self.num_hidden, self.sess, self.simple_net, + self.n_actions, + self.Q_net_std, ) ) @@ -229,6 +341,7 @@ def _init_lola( self.gamma, self.pseudo, self.reg, + self.n_actions, ) self.results = [] @@ -241,21 +354,22 @@ def _init_lola( # TODO add something to not load and create everything when only evaluating with RLLib def setup(self, config): - print("_init_lola", config) self._init_lola(**config) def step(self): - self.sess.run(self.init) - lr_coor = np.ones(1) * self.lr_correction - log_items = {} log_items["episode"] = self.training_iteration + if self.training_iteration % self.re_init_every_n_epi == 0: + self.sess.run(self.init) + + lr_coor = np.ones(1) * self.lr_correction + res = [] params_time = [] delta_time = [] - input_vals = np.reshape(np.array(range(5)) + 1, [-1]) + input_vals = self._get_input_vals() for i in range(self.trace_length): params0 = self.mainQN[0].getparams() params1 = self.mainQN[1].getparams() @@ -267,14 +381,13 @@ def step(self): self.mainQN[0].policy, self.mainQN[1].policy, ] - # print("input_vals", input_vals) - update1, update2, v1, v2, policy1, policy2 = self.sess.run( + (update1, update2, v1, v2, policy1, policy2) = self.sess.run( outputs, feed_dict={ 
self.mainQN[0].input_place: input_vals, self.mainQN[1].input_place: input_vals, - self.mainQN[0].R1: np.reshape(self.payout_mat_2, [-1, 1]), - self.mainQN[1].R1: np.reshape(self.payout_mat_1, [-1, 1]), + self.mainQN[0].R1: np.reshape(self.payout_mat_1, [-1, 1]), + self.mainQN[1].R1: np.reshape(self.payout_mat_2, [-1, 1]), self.mainQN[0].lr_correction: lr_coor, self.mainQN[1].lr_correction: lr_coor, }, @@ -289,10 +402,21 @@ def step(self): res.append([v1[0][0] / self.norm, v2[0][0] / self.norm]) self.results.append(res) + if self.training_iteration % self.re_init_every_n_epi == ( + self.re_init_every_n_epi - 1 + ): + self.policy1 = policy1 + self.policy2 = policy2 + log_items["episodes_total"] = self.training_iteration + log_items["policy1"] = self.policy1 + log_items["policy2"] = self.policy2 return log_items + def _get_input_vals(self): + return np.reshape(np.array(range(self.n_actions ** 2 + 1)) + 1, [-1]) + def _clip_update(self, update1, update2): if self.clip_update is not None: assert self.clip_update > 0.0 @@ -382,24 +506,28 @@ def _post_process_action(self, action): return action[None, ...] # add batch dim def compute_actions(self, policy_id: str, obs_batch: list): - # because of the LSTM - assert len(obs_batch) == 1 + assert ( + len(obs_batch) == 1 + ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}" for single_obs in obs_batch: agent_to_use = self._get_agent_to_use(policy_id) obs = self._preprocess_obs(single_obs, agent_to_use) - input_vals = np.reshape(np.array(range(5)) + 1, [-1]) + input_vals = self._get_input_vals() policy = self.sess.run( - [self.mainQN[agent_to_use].policy], + [ + self.mainQN[agent_to_use].policy, + ], feed_dict={ self.mainQN[agent_to_use].input_place: input_vals, }, ) - coop_proba = policy[0][obs][0] - if coop_proba > random.random(): - action = np.array(0) - else: - action = np.array(1) + probabilities = policy[0][obs] + probabilities = torch.tensor(probabilities) + policy_for_this_state = torch.distributions.Categorical( + probs=probabilities + ) + action = policy_for_this_state.sample() action = self._post_process_action(action) diff --git a/marltoolbox/algos/lola/train_pg_tune_class_API.py b/marltoolbox/algos/lola/train_pg_tune_class_API.py index e5a7b58..2c26779 100644 --- a/marltoolbox/algos/lola/train_pg_tune_class_API.py +++ b/marltoolbox/algos/lola/train_pg_tune_class_API.py @@ -548,16 +548,17 @@ def _post_process_action(self, action): return action def compute_actions(self, policy_id: str, obs_batch: list): - # because of the LSTM assert len(obs_batch) == 1 for single_obs in obs_batch: agent_to_use = self._get_agent_to_use(policy_id) obs = self._preprocess_obs(single_obs) + obs = np.expand_dims(obs, axis=0) a = self.sess.run( [self.mainQN[agent_to_use].predict], - feed_dict={self.mainQN[agent_to_use].scalarInput: [obs]}, + # feed_dict={self.mainQN[agent_to_use].scalarInput: [obs]}, + feed_dict={self.mainQN[agent_to_use].scalarInput: obs}, ) action = self._post_process_action(a) diff --git a/marltoolbox/algos/ltft/ltft.py b/marltoolbox/algos/ltft/ltft.py index 5a91e15..203f7ad 100644 --- a/marltoolbox/algos/ltft/ltft.py +++ b/marltoolbox/algos/ltft/ltft.py @@ -12,13 +12,16 @@ from ray.rllib.execution.replay_buffer import LocalReplayBuffer from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches -from ray.rllib.execution.train_ops import TrainOneStep -from ray.rllib.execution.train_ops import UpdateTargetNetwork from ray.rllib.utils import merge_dicts from 
ray.rllib.utils.schedules import PiecewiseSchedule from ray.rllib.utils.typing import PolicyID, SampleBatchType from ray.rllib.utils.typing import TrainerConfigDict from ray.util.iter import LocalIterator +from ray.rllib.execution.train_ops import ( + TrainOneStep, + UpdateTargetNetwork, + TrainTFMultiGPU, +) from torch.nn import CrossEntropyLoss from marltoolbox.algos import augmented_dqn, supervised_learning, hierarchical @@ -116,7 +119,7 @@ framework="torch", ), }, - "log_level": "DEBUG", + "batch_mode": "complete_episodes", } ) @@ -133,7 +136,7 @@ PLOT_ASSEMBLAGE_TAGS = [ ("defection",), - ("defection", "chosen_percentile"), + ("defection_", "chosen_percentile"), ("punish",), ("delta_log_likelihood_coop_std", "delta_log_likelihood_coop_mean"), ("likelihood_opponent", "likelihood_approximated_opponent"), @@ -158,7 +161,15 @@ def execution_plan( workers: WorkerSet, config: TrainerConfigDict ) -> LocalIterator[dict]: """ - Modified from the execution plan of the DQNTrainer + Execution plan of the DQN algorithm. Defines the distributed dataflow. + + Args: + workers (WorkerSet): The WorkerSet for training the Polic(y/ies) + of the Trainer. + config (TrainerConfigDict): The trainer's configuration dict. + + Returns: + LocalIterator[dict]: A local iterator over training metrics. """ if config.get("prioritized_replay"): prio_args = { @@ -175,7 +186,9 @@ def execution_plan( buffer_size=config["buffer_size"], replay_batch_size=config["train_batch_size"], replay_mode=config["multiagent"]["replay_mode"], - replay_sequence_length=config["replay_sequence_length"], + replay_sequence_length=config.get("replay_sequence_length", 1), + replay_burn_in=config.get("burn_in", 0), + replay_zero_init_states=config.get("zero_init_states", True), **prio_args, ) @@ -204,10 +217,9 @@ def update_prio(item): td_error = info.get( "td_error", info[LEARNER_STATS_KEY].get("td_error") ) + samples.policy_batches[policy_id].set_get_interceptor(None) prio_dict[policy_id] = ( - samples.policy_batches[policy_id].data.get( - "batch_indexes" - ), + samples.policy_batches[policy_id].get("batch_indexes"), td_error, ) local_replay_buffer.update_priorities(prio_dict) @@ -217,11 +229,25 @@ def update_prio(item): # returned from the LocalReplay() iterator is passed to TrainOneStep to # take a SGD step, and then we decide whether to update the target network. 
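+    # The SGD step itself is performed either by the single-process
+    # TrainOneStep op or, when not using the simple optimizer, by the
+    # multi-GPU train op (see the `simple_optimizer` branch below).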
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) + + if config["simple_optimizer"]: + train_step_op = TrainOneStep(workers) + else: + train_step_op = TrainTFMultiGPU( + workers=workers, + sgd_minibatch_size=config["train_batch_size"], + num_sgd_iter=1, + num_gpus=config["num_gpus"], + shuffle_sequences=True, + _fake_gpus=config["_fake_gpus"], + framework=config.get("framework"), + ) + replay_op_dqn = ( Replay(local_buffer=local_replay_buffer) .for_each(lambda x: post_fn(x, workers, config)) .for_each(LocalTrainablePolicyModifier(workers, train_dqn_only)) - .for_each(TrainOneStep(workers)) + .for_each(train_step_op) .for_each(update_prio) .for_each( UpdateTargetNetwork(workers, config["target_network_update_freq"]) diff --git a/marltoolbox/algos/ltft/ltft_torch_policy.py b/marltoolbox/algos/ltft/ltft_torch_policy.py index 05ce962..6101e5d 100644 --- a/marltoolbox/algos/ltft/ltft_torch_policy.py +++ b/marltoolbox/algos/ltft/ltft_torch_policy.py @@ -8,7 +8,7 @@ import copy import logging from collections import deque -from typing import List, Union, Optional, Dict, Tuple, TYPE_CHECKING +from typing import Optional, Dict, Tuple, TYPE_CHECKING, Union, List import numpy as np import torch @@ -16,9 +16,11 @@ from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import AgentID, PolicyID, TensorType +from ray.rllib.utils.torch_ops import convert_to_torch_tensor if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -45,6 +47,9 @@ }, ) +BEING_PUNISHED = "being_punished" +PUNISHING = "punishing" + class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy): """ @@ -63,12 +68,10 @@ class LTFTTorchPolicy(hierarchical.HierarchicalTorchPolicy): INITIALLY_ACTIVE_ALGO = COOP_POLICY_IDX def __init__(self, observation_space, action_space, config, **kwargs): - super().__init__( observation_space, action_space, config, - # after_init_nested=_init_weights_fn, **kwargs, ) @@ -100,10 +103,8 @@ def __init__(self, observation_space, action_space, config, **kwargs): self.n_steps_since_start = 0 self.last_computed_w = None - # self._episode_starting = True self.opp_previous_obs = None self.opp_new_obs = None - self._first_fake_step_played = False self.observed_n_step_in_current_epi = 0 self.data_queue = deque(maxlen=self.length_of_history) @@ -140,47 +141,39 @@ def __init__(self, observation_space, action_space, config, **kwargs): add_opponent_neg_reward=True, ) + self.view_requirements.update( + { + PUNISHING: ViewRequirement(), + } + ) + def _reset_learner_stats(self): self.learner_stats = {"learner_stats": {}} @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - assert ( - len(obs_batch) == 1 - ), "LE only works with sampling one step at a time" - - actions, state_out, extra_fetches = super().compute_actions( - 
obs_batch, - state_batches, - prev_action_batch, - prev_reward_batch, - info_batch, - episodes, - explore, - timestep, - **kwargs, + def _compute_action_helper( + self, input_dict, state_batches, seq_lens, explore, timestep + ): + obs_batch = input_dict["obs"] + assert len(obs_batch) == 1, ( + "The LTFT policy only works when sampling/infering one step " + "at a time." ) - extra_fetches["punishing"] = [self._is_punishing() for _ in obs_batch] - return actions, [], extra_fetches + actions, state_out, extra_fetches = super()._compute_action_helper( + input_dict, state_batches, seq_lens, explore, timestep + ) + extra_fetches[PUNISHING] = [ + self._is_punishing() for _ in obs_batch.tolist() + ] + + return actions, state_out, extra_fetches def _is_punishing(self): return self.active_algo_idx == self.PUNITIVE_POLICY_IDX @override(hierarchical.HierarchicalTorchPolicy) def _learn_on_batch(self, samples: SampleBatch): - ( policies_idx_to_train, policies_to_train, @@ -237,14 +230,14 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): ) if policy_n == self.COOP_POLICY_IDX: samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="punishing" + samples_copy, remove=True, filter_key=PUNISHING ) elif policy_n == self.COOP_OPP_POLICY_IDX: samples_copy[samples_copy.ACTIONS] = np.array( samples_copy.data[postprocessing.OPPONENT_ACTIONS] ) samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="being_punished" + samples_copy, remove=True, filter_key=BEING_PUNISHED ) else: raise ValueError() @@ -260,7 +253,7 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): # this policy. This is because the opponent is not using its # cooperative policy at that time. samples_copy = filter_sample_batch( - samples_copy, remove=True, filter_key="being_punished" + samples_copy, remove=True, filter_key=BEING_PUNISHED ) else: raise ValueError() @@ -268,30 +261,25 @@ def _modify_batch_for_policy(self, policy_n, samples_copy): return samples_copy def on_episode_step(self, episode, policy_id, policy_ids, *args, **kwargs): - if self._first_fake_step_played: - ( - opp_previous_obs, - opp_a, - being_punished_by_opp, - ) = self._get_information_from_opponent( - episode, policy_id, policy_ids + ( + opp_previous_obs, + opp_a, + being_punished_by_opp, + ) = self._get_information_from_opponent(episode, policy_id, policy_ids) + + self.being_punished_by_opp = being_punished_by_opp + if not self.being_punished_by_opp: + self.n_steps_since_start += 1 + self._put_log_likelihood_in_data_buffer( + opp_previous_obs, opp_a, self.data_queue ) - self.being_punished_by_opp = being_punished_by_opp - if not self.being_punished_by_opp: - self.n_steps_since_start += 1 - self._put_log_likelihood_in_data_buffer( - opp_previous_obs, opp_a, self.data_queue - ) - - if self.remaining_punishing_time > 0: - self.n_punishment_steps_in_current_epi += 1 - else: - self.n_cooperation_steps_in_current_epi += 1 - - self.observed_n_step_in_current_epi += 1 + if self.remaining_punishing_time > 0: + self.n_punishment_steps_in_current_epi += 1 else: - self._first_fake_step_played = True + self.n_cooperation_steps_in_current_epi += 1 + + self.observed_n_step_in_current_epi += 1 def _get_information_from_opponent(self, episode, agent_id, agent_ids): opp_agent_id = [one_id for one_id in agent_ids if one_id != agent_id][ @@ -299,7 +287,7 @@ def _get_information_from_opponent(self, episode, agent_id, agent_ids): ] opp_a = episode.last_action_for(opp_agent_id) being_punished_by_opp = 
episode.last_pi_info_for(opp_agent_id).get( - "punishing", False + PUNISHING, False ) return self.opp_previous_obs, opp_a, being_punished_by_opp @@ -308,8 +296,7 @@ def on_observation_fn(self, opp_new_obs): # observation produced by this action. But we need the # observation that cause the agent to play this action # thus the observation n-1 - if self._first_fake_step_played: - self.opp_previous_obs = self.opp_new_obs + self.opp_previous_obs = self.opp_new_obs self.opp_new_obs = opp_new_obs def on_episode_end(self, base_env, *args, **kwargs): @@ -386,15 +373,11 @@ def postprocess_trajectory( def _put_log_likelihood_in_data_buffer(self, s, a, data_queue, log=True): s = torch.from_numpy(s).unsqueeze(dim=0) - log_likelihood_opponent_cooperating = ( - compute_log_likelihoods_wt_exploration( - self.algorithms[self.COOP_OPP_POLICY_IDX], a, s - ) + log_likelihood_opponent_cooperating = compute_log_likelihoods( + self.algorithms[self.COOP_OPP_POLICY_IDX], [a], s ) - log_likelihood_approximated_opponent = ( - compute_log_likelihoods_wt_exploration( - self.algorithms[self.SPL_OPP_POLICY_IDX], a, s - ) + log_likelihood_approximated_opponent = compute_log_likelihoods( + self.algorithms[self.SPL_OPP_POLICY_IDX], [a], s ) log_likelihood_opponent_cooperating = float( log_likelihood_opponent_cooperating @@ -500,20 +483,20 @@ def on_postprocess_trajectory( opp_policy_id = all_agent_keys[0] opp_is_broadcast_punishment_state = ( - "punishing" in original_batches[opp_policy_id][1].data.keys() + PUNISHING in original_batches[opp_policy_id][1].keys() ) if opp_is_broadcast_punishment_state: - postprocessed_batch.data["being_punished"] = copy.deepcopy( - original_batches[all_agent_keys[0]][1].data["punishing"] + postprocessed_batch.data[BEING_PUNISHED] = copy.deepcopy( + original_batches[all_agent_keys[0]][1][PUNISHING] ) else: - postprocessed_batch.data["being_punished"] = [False] * len( + postprocessed_batch.data[BEING_PUNISHED] = [False] * len( postprocessed_batch[postprocessed_batch.OBS] ) -# Modified from torch_policy_template -def compute_log_likelihoods_wt_exploration( +# Modified from torch_policy +def compute_log_likelihoods( policy, actions: Union[List[TensorType], TensorType], obs_batch: Union[List[TensorType], TensorType], @@ -521,6 +504,7 @@ def compute_log_likelihoods_wt_exploration( prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None, prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None, ) -> TensorType: + if policy.action_sampler_fn and policy.action_distribution_fn is None: raise ValueError( "Cannot compute log-prob/likelihood w/o an " @@ -537,34 +521,63 @@ def compute_log_likelihoods_wt_exploration( if prev_reward_batch is not None: input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch seq_lens = torch.ones(len(obs_batch), dtype=torch.int32) + state_batches = [ + convert_to_torch_tensor(s, policy.device) + for s in (state_batches or []) + ] # Exploration hook before each forward pass. policy.exploration.before_compute_actions(explore=False) # Action dist class and inputs are generated via custom function. if policy.action_distribution_fn: - dist_inputs, dist_class, _ = policy.action_distribution_fn( - policy=policy, - model=policy.model, - obs_batch=input_dict[SampleBatch.CUR_OBS], - explore=False, - is_training=False, - ) + + # Try new action_distribution_fn signature, supporting + # state_batches and seq_lens. 
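+        # If that call raises a TypeError about its positional/keyword
+        # arguments, the except branch below falls back to the old
+        # `obs_batch`-based signature, to stay compatible with policies
+        # written against older RLlib versions.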
+ try: + (dist_inputs, dist_class, _,) = policy.action_distribution_fn( + policy, + policy.model, + input_dict=input_dict, + state_batches=state_batches, + seq_lens=seq_lens, + explore=False, + is_training=False, + ) + # Trying the old way (to stay backward compatible). + # TODO: Remove in future. + except TypeError as e: + if ( + "positional argument" in e.args[0] + or "unexpected keyword argument" in e.args[0] + ): + dist_inputs, dist_class, _ = policy.action_distribution_fn( + policy=policy, + model=policy.model, + obs_batch=input_dict[SampleBatch.CUR_OBS], + explore=False, + is_training=False, + ) + else: + raise e + # Default action-dist inputs calculation. else: dist_class = policy.dist_class dist_inputs, _ = policy.model(input_dict, state_batches, seq_lens) action_dist = dist_class(dist_inputs, policy.model) + # ADDITION 1/1 STARTS if policy.config["explore"]: # Adding that because of a bug in TorchCategorical # which modify dist_inputs through action_dist: - _, _ = policy.exploration.get_exploration_action( + _ = policy.exploration.get_exploration_action( action_distribution=action_dist, timestep=policy.global_timestep, explore=policy.config["explore"], ) action_dist = dist_class(dist_inputs, policy.model) + # ADDITION 1/1 ENDS log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) diff --git a/marltoolbox/algos/population.py b/marltoolbox/algos/population.py index 81cb229..ea5e0b2 100644 --- a/marltoolbox/algos/population.py +++ b/marltoolbox/algos/population.py @@ -1,12 +1,9 @@ import copy import random -from typing import List, Union, Optional, Dict, Tuple -from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils import override from marltoolbox.algos import hierarchical from marltoolbox.utils import miscellaneous, restore @@ -17,7 +14,7 @@ # To configure "policy_checkpoints": [], "policy_id_to_load": None, - "nested_policies": None, + # "nested_policies": None, "freeze_algo": True, "use_random_algo": True, "use_algo_in_order": False, @@ -54,12 +51,10 @@ def set_algo_to_use(self): """ Called by a callback at the start of every episode. 
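+        Every `switch_of_algo_every_n_epi` episodes, the active checkpoint
+        index is re-selected (randomly or in order, depending on the config).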
""" - if self.switch_of_algo_counter == self.switch_of_algo_every_n_epi: self.switch_of_algo_counter = 0 self._set_algo_to_use() - else: - self.switch_of_algo_counter += 1 + self.switch_of_algo_counter += 1 def _set_algo_to_use(self): self._select_algo_idx_to_use() @@ -75,10 +70,17 @@ def _select_algo_idx_to_use(self): else: raise ValueError() + self._to_log[ + "PopulationOfIdenticalAlgo_active_checkpoint_idx" + ] = self.active_checkpoint_idx + def _use_random_algo(self): self.active_checkpoint_idx = random.randint( 0, len(self.policy_checkpoints) - 1 ) + self._to_log[ + "PopulationOfIdenticalAlgo_active_checkpoint_idx" + ] = self.active_checkpoint_idx def _use_next_algo(self): self.active_checkpoint_idx += 1 @@ -110,22 +112,7 @@ def set_weights(self, weights): self.algorithms[self.active_algo_idx].set_weights(weights) @override(hierarchical.HierarchicalTorchPolicy) - def compute_actions( - self, - obs_batch: Union[List[TensorType], TensorType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorType], TensorType] = None, - prev_reward_batch: Union[List[TensorType], TensorType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["MultiAgentEpisode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - return self.algorithms[self.active_algo_idx].compute_actions(obs_batch) - - @override(hierarchical.HierarchicalTorchPolicy) - def learn_on_batch(self, samples: SampleBatch): + def _learn_on_batch(self, samples: SampleBatch): if not self.freeze_algo: # TODO maybe need to call optimizer to update the LR of the nested optimizers learner_stats = {"learner_stats": {}} @@ -165,6 +152,11 @@ def on_episode_start( **kwargs, ): self.set_algo_to_use() + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"): + self.algorithms[self.active_algo_idx].on_episode_start( + *args, + **kwargs, + ) def modify_config_to_use_population( diff --git a/marltoolbox/algos/psro.py b/marltoolbox/algos/psro.py new file mode 100644 index 0000000..39936b6 --- /dev/null +++ b/marltoolbox/algos/psro.py @@ -0,0 +1,346 @@ +# Code modified from: https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/examples/psro_v2_example.py + + +import time + +from absl import app +from absl import flags +import numpy as np + +# pylint: disable=g-bad-import-order +import pyspiel +import tensorflow.compat.v1 as tf + +# pylint: enable=g-bad-import-order + +from open_spiel.python import policy +from open_spiel.python import rl_environment +from open_spiel.python.algorithms import exploitability +from open_spiel.python.algorithms import get_all_states +from open_spiel.python.algorithms import policy_aggregator +from open_spiel.python.algorithms.psro_v2 import best_response_oracle +from open_spiel.python.algorithms.psro_v2 import psro_v2 +from open_spiel.python.algorithms.psro_v2 import rl_oracle +from open_spiel.python.algorithms.psro_v2 import rl_policy +from open_spiel.python.algorithms.psro_v2 import strategy_selectors +from open_spiel.python.games import simple_bargaining, kuhn_poker, tic_tac_toe +from marltoolbox.utils.miscellaneous import get_random_seeds + +from ray import tune + +from open_spiel.python.examples.psro_v2_example import ( + _print_base_policies, + print_policy_analysis, +) + + +class PSROTrainer(tune.Trainable): + def setup(self, config): + self.seed = config["seed"] + self.n_players = config["n_players"] + self.game_name = 
config["game_name"] + self.oracle_type = config["oracle_type"] + self.training_strategy_selector = config["training_strategy_selector"] + self.rectifier = config["rectifier"] + self.sims_per_entry = config["sims_per_entry"] + self.number_policies_selected = config["number_policies_selected"] + self.meta_strategy_method = config["meta_strategy_method"] + self.symmetric_game = config["symmetric_game"] + self.verbose = config["verbose"] + self.loss_str = config["loss_str"] + self.hidden_layer_size = config["hidden_layer_size"] + self.n_hidden_layers = config["n_hidden_layers"] + self.batch_size = config["batch_size"] + self.entropy_cost = config["entropy_cost"] + self.critic_learning_rate = config["critic_learning_rate"] + self.pi_learning_rate = config["pi_learning_rate"] + self.num_q_before_pi = config["num_q_before_pi"] + self.optimizer_str = config["optimizer_str"] + self.number_training_episodes = config["number_training_episodes"] + self.self_play_proportion = config["self_play_proportion"] + self.sigma = config["sigma"] + self.dqn_learning_rate = config["dqn_learning_rate"] + self.update_target_network_every = config[ + "update_target_network_every" + ] + self.learn_every = config["learn_every"] + self.num_iterations = config["num_iterations"] + + np.random.seed(self.seed) + try: + game = pyspiel.load_game_as_turn_based( + self.game_name, + {"players": pyspiel.GameParameter(self.n_players)}, + ) + except pyspiel.SpielError: + game = pyspiel.load_game_as_turn_based(self.game_name) + + self.env = rl_environment.Environment(game) + + # Initialize oracle and agents + with tf.Session() as sess: + if self.oracle_type == "DQN": + oracle, agents = self.init_dqn_responder(sess, self.env) + elif self.oracle_type == "PG": + oracle, agents = self.init_pg_responder(sess, self.env) + elif self.oracle_type == "BR": + oracle, agents = self.init_br_responder(self.env) + sess.run(tf.global_variables_initializer()) + + sample_from_marginals = ( + True # TODO(somidshafiei) set False for alpharank + ) + training_strategy_selector = ( + self.training_strategy_selector or strategy_selectors.probabilistic + ) + + self.g_psro_solver = psro_v2.PSROSolver( + self.env.game, + oracle, + initial_policies=agents, + training_strategy_selector=training_strategy_selector, + rectifier=self.rectifier, + sims_per_entry=self.sims_per_entry, + number_policies_selected=self.number_policies_selected, + meta_strategy_method=self.meta_strategy_method, + prd_iterations=50000, + prd_gamma=1e-10, + sample_from_marginals=sample_from_marginals, + symmetric_game=self.symmetric_game, + ) + + self.start_time = time.time() + + def step(self): + to_report = {"training_iteration": self.training_iteration} + if self.verbose: + print("Iteration : {}".format(self.training_iteration)) + print("Time so far: {}".format(time.time() - self.start_time)) + self.g_psro_solver.iteration() + meta_game = self.g_psro_solver.get_meta_game() + meta_probabilities = self.g_psro_solver.get_meta_strategies() + policies = self.g_psro_solver.get_policies() + + _print_base_policies(policies) + + if self.verbose: + print("Meta game : {}".format(meta_game)) + print("Probabilities : {}".format(meta_probabilities)) + to_report["meta_probabilities"] = meta_probabilities + to_report["meta_game"] = meta_game + # The following lines only work for sequential games for the moment. 
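+        # For sequential games, each player's meta-strategy is aggregated
+        # with its base policies into a single policy (via
+        # `policy_aggregator.PolicyAggregator`), and the NashConv
+        # exploitability of that aggregated profile is reported, both in
+        # total and per player.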
+ if ( + self.env.game.get_type().dynamics + == pyspiel.GameType.Dynamics.SEQUENTIAL + ): + aggregator = policy_aggregator.PolicyAggregator(self.env.game) + aggr_policies = aggregator.aggregate( + range(self.n_players), policies, meta_probabilities + ) + + exploitabilities, expl_per_player = exploitability.nash_conv( + self.env.game, aggr_policies, return_only_nash_conv=False + ) + + _ = print_policy_analysis(policies, self.env.game, self.verbose) + if self.verbose: + print("Exploitabilities : {}".format(exploitabilities)) + print( + "Exploitabilities per player : {}".format(expl_per_player) + ) + to_report["exploitabilities"] = exploitabilities + to_report["expl_per_player"] = expl_per_player + + to_report["finished"] = ( + False if self.training_iteration < self.num_iterations else True + ) + return to_report + + def init_pg_responder(self, sess, env): + """Initializes the Policy Gradient-based responder and agents.""" + info_state_size = env.observation_spec()["info_state"][0] + num_actions = env.action_spec()["num_actions"] + + agent_class = rl_policy.PGPolicy + + agent_kwargs = { + "session": sess, + "info_state_size": info_state_size, + "num_actions": num_actions, + "loss_str": self.loss_str, + "loss_class": False, + "hidden_layers_sizes": [self.hidden_layer_size] + * self.n_hidden_layers, + "batch_size": self.batch_size, + "entropy_cost": self.entropy_cost, + "critic_learning_rate": self.critic_learning_rate, + "pi_learning_rate": self.pi_learning_rate, + "num_critic_before_pi": self.num_q_before_pi, + "optimizer_str": self.optimizer_str, + } + oracle = rl_oracle.RLOracle( + env, + agent_class, + agent_kwargs, + number_training_episodes=self.number_training_episodes, + self_play_proportion=self.self_play_proportion, + sigma=self.sigma, + ) + + agents = [ + agent_class( # pylint: disable=g-complex-comprehension + env, player_id, **agent_kwargs + ) + for player_id in range(self.n_players) + ] + for agent in agents: + agent.freeze() + return oracle, agents + + def init_br_responder(self, env): + """Initializes the tabular best-response based responder and agents.""" + random_policy = policy.TabularPolicy(env.game) + oracle = best_response_oracle.BestResponseOracle( + game=env.game, policy=random_policy + ) + agents = [random_policy.__copy__() for _ in range(self.n_players)] + return oracle, agents + + def init_dqn_responder(self, sess, env): + """Initializes the Policy Gradient-based responder and agents.""" + state_representation_size = env.observation_spec()["info_state"][0] + num_actions = env.action_spec()["num_actions"] + + agent_class = rl_policy.DQNPolicy + agent_kwargs = { + "session": sess, + "state_representation_size": state_representation_size, + "num_actions": num_actions, + "hidden_layers_sizes": [self.hidden_layer_size] + * self.n_hidden_layers, + "batch_size": self.batch_size, + "learning_rate": self.dqn_learning_rate, + "update_target_network_every": self.update_target_network_every, + "learn_every": self.learn_every, + "optimizer_str": self.optimizer_str, + } + oracle = rl_oracle.RLOracle( + env, + agent_class, + agent_kwargs, + number_training_episodes=self.number_training_episodes, + self_play_proportion=self.self_play_proportion, + sigma=self.sigma, + ) + + agents = [ + agent_class( # pylint: disable=g-complex-comprehension + env, player_id, **agent_kwargs + ) + for player_id in range(self.n_players) + ] + for agent in agents: + agent.freeze() + return oracle, agents + + # def save_checkpoint(self, checkpoint_dir): + # path = os.path.join(checkpoint_dir, 
"checkpoint.json") + # tf_checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # tf_checkpoint_dir, tf_checkpoint_filename = os.path.split( + # tf_checkpoint_path + # ) + # checkpoint = { + # "timestep": self.timestep, + # "tf_checkpoint_dir": tf_checkpoint_dir, + # "tf_checkpoint_filename": tf_checkpoint_filename, + # } + # with open(path, "w") as f: + # json.dump(checkpoint, f, sort_keys=True, indent=4) + # + # # TF v1 + # save_path = self.saver.save(self.sess, f"{tf_checkpoint_path}.ckpt") + # + # return path + # + # def load_checkpoint(self, checkpoint_path): + # + # checkpoint_path = os.path.expanduser(checkpoint_path) + # print("Loading Model...", checkpoint_path) + # with open(checkpoint_path, "r") as f: + # checkpoint = json.load(f) + # print("checkpoint", checkpoint) + # + # # Support VM and local (manual) loading + # tf_checkpoint_dir, _ = os.path.split(checkpoint_path) + # print("tf_checkpoint_dir", tf_checkpoint_dir) + # ckpt = tf.train.get_checkpoint_state( + # tf_checkpoint_dir, + # latest_filename=f'{checkpoint["tf_checkpoint_filename"]}', + # ) + # tail, head = os.path.split(ckpt.model_checkpoint_path) + # ckpt.model_checkpoint_path = os.path.join(tf_checkpoint_dir, head) + # self.saver.restore(self.sess, ckpt.model_checkpoint_path) + # + # def cleanup(self): + # self.sess.close() + # super().cleanup() + # + # def compute_actions(self, policy_id: str, obs_batch: list): + # # because of the LSTM + # assert len(obs_batch) == 1 + # + # for single_obs in obs_batch: + # agent_to_use = self._get_agent_to_use(policy_id) + # obs = self._preprocess_obs(single_obs, agent_to_use) + # a, lstm_s = self.sess.run( + # [ + # self.mainPN_step[agent_to_use].predict, + # self.mainPN_step[agent_to_use].lstm_state_output, + # ], + # feed_dict={ + # self.mainPN_step[agent_to_use].state_input: obs, + # self.mainPN_step[agent_to_use].lstm_state: self.lstm_state[ + # agent_to_use + # ], + # self.mainPN_step[agent_to_use].is_training: False, + # }, + # ) + # self.lstm_state[agent_to_use] = lstm_s + # action = self._post_process_action(a) + # + # state_out = [] + # extra_fetches = {} + # return action, state_out, extra_fetches + # + # def _get_agent_to_use(self, policy_id): + # if policy_id == "player_red": + # agent_n = 0 + # elif policy_id == "player_blue": + # agent_n = 1 + # else: + # raise ValueError(f"policy_id {policy_id}") + # return agent_n + # + # def _preprocess_obs(self, single_obs, agent_to_use): + # single_obs = single_obs[None, ...] # add batch dim + # + # # Compensate for the batch norm not in evaluation mode + # while len(self.obs_batch) < self.batch_size: + # self.obs_batch.append(single_obs) + # self.obs_batch.append(single_obs) + # single_obs = np.concatenate(list(self.obs_batch), axis=0) + # return single_obs + # + # def _post_process_action(self, action): + # # Compensate for the batch norm not in evaluation mode + # if isinstance(action, Iterable): + # action = action[-1] + # + # return action[None, ...] 
# add batch dim + # + # def reset_compute_actions_state(self): + # self.lstm_state = [] + # for agent in range(self.n_agents): + # self.lstm_state.append( + # np.zeros((self.batch_size, self.h_size[agent] * 2)) + # ) diff --git a/marltoolbox/algos/psro_hardcoded.py b/marltoolbox/algos/psro_hardcoded.py new file mode 100644 index 0000000..9dd3635 --- /dev/null +++ b/marltoolbox/algos/psro_hardcoded.py @@ -0,0 +1,460 @@ +# Code modified from: https://github.com/deepmind/open_spiel/blob/master/open_spiel/python/examples/psro_v2_example.py + + +import copy +import os +import pickle +import collections + +import numpy as np +import torch +from ray import tune +from ray.rllib.agents.callbacks import DefaultCallbacks +from ray.rllib.agents.pg import PGTorchPolicy, DEFAULT_CONFIG +from ray.rllib.evaluation.sample_batch_builder import ( + MultiAgentSampleBatchBuilder, + SampleBatchBuilder, +) +from ray.rllib.examples.policy.random_policy import RandomPolicy +from ray.rllib.policy.policy import clip_action +from ray.rllib.env.base_env import _DUMMY_AGENT_ID + +from marltoolbox.experiments.tune_class_api.various_algo_meta_game import ( + _compute_policy_wt_alpha_rank, +) + + +class MySampleBatchBuilder(SampleBatchBuilder): + def __init__(self): + # deprecation_warning( + # old="SampleBatchBuilder", + # new="child class of `SampleCollector`", + # error=False) + self.buffers = collections.defaultdict(list) + self.count = 0 + + +class MyMultiAgentSampleBatchBuilder(MultiAgentSampleBatchBuilder): + def __init__(self, policy_map, clip_rewards, callbacks): + # deprecation_warning(old="MultiAgentSampleBatchBuilder", error=False) + self.policy_map = policy_map + self.clip_rewards = clip_rewards + # Build the Policies' SampleBatchBuilders. + self.policy_builders = { + k: MySampleBatchBuilder() for k in policy_map.keys() + } + # Whenever we observe a new agent, add a new SampleBatchBuilder for + # this agent. + self.agent_builders = {} + # Internal agent-to-policy map. + self.agent_to_policy = {} + self.callbacks = callbacks + # Number of "inference" steps taken in the environment. + # Regardless of the number of agents involved in each of these steps. + self.count = 0 + + def add_values(self, agent_id, policy_id, **values) -> None: + """Add the given dictionary (row) of values to this batch. + + Args: + agent_id (obj): Unique id for the agent we are adding values for. + policy_id (obj): Unique id for policy controlling the agent. + values (dict): Row of values to add for this agent. + """ + + if agent_id not in self.agent_builders: + self.agent_builders[agent_id] = MySampleBatchBuilder() + self.agent_to_policy[agent_id] = policy_id + + # Include the current agent id for multi-agent algorithms. 
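+        # (`_DUMMY_AGENT_ID` is the placeholder id RLlib uses when a
+        # single-agent env is wrapped as a BaseEnv, so only real agent ids
+        # are worth recording in the batch.)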
+ if agent_id != _DUMMY_AGENT_ID: + values["agent_id"] = agent_id + + self.agent_builders[agent_id].add_values(**values) + + +class PSROTrainer(tune.Trainable): + def setup(self, config): + self.config = config + self.env = self.config["env_class"](self.config["env_config"]) + self.eval_cell_over_n_epi = self.config["eval_cell_over_n_epi"] + self.batch_size = self.config["oracle_config"]["train_batch_size"] + self.policy_ids = self.config["env_config"]["players_ids"] + self.n_steps_by_epi = self.config["env_config"]["n_steps_by_epi"] + self.train_oracle_n_epi = self.config["train_oracle_n_epi"] + self.num_iterations = self.config["num_iterations"] + self.training = self.config["training"] + self.verbose = self.config["verbose"] + self.oracle_config = self.config["oracle_config"] + self.center_returns = self.config["center_returns"] + self.to_report_base_game = {} + self.meta_policies = { + policy_id: self._init_meta_player() + for policy_id in self.policy_ids + } + self.init_obs = self.env.reset() + if self.training: + self._init_meta_payoff_matrix() + + def _init_meta_player(self): + meta_policy = [1.0] + return { + "meta_policy": self._create_policy_sampler(meta_policy), + "policies": [ + RandomPolicy( + self.env.OBSERVATION_SPACE, self.env.ACTION_SPACE, {} + ) + ], + } + + def _create_policy_sampler(self, meta_policy_proba): + meta_policy_proba = torch.tensor(meta_policy_proba) + return torch.distributions.Categorical(probs=meta_policy_proba) + + def _init_meta_payoff_matrix(self): + self.meta_game_payoff_matrix = np.zeros((1, 1, len(self.policy_ids))) + assert len(self.meta_policies[self.policy_ids[0]]["policies"]) == 1 + assert len(self.meta_policies[self.policy_ids[1]]["policies"]) == 1 + self.meta_game_payoff_matrix[0, 0, :] = self._compute_joint_payoffs( + self.meta_policies[self.policy_ids[0]]["policies"][0], + self.meta_policies[self.policy_ids[1]]["policies"][0], + ) + + def _compute_joint_payoffs(self, policy_pl_1, policy_pl_2): + # self._init_to_report(in_cell_eval=True) + self._reset_total_welfare() + self.players = { + self.policy_ids[0]: policy_pl_1, + self.policy_ids[1]: policy_pl_2, + } + self.multi_agent_batch_builder = self._init_batch_builder() + for epi_n in range(self.eval_cell_over_n_epi): + self._play_one_episode() + + total_r_pl_1, total_r_pl_2 = self.total_welfare + self._reset_total_welfare() + n_steps_playerd = int(self.eval_cell_over_n_epi * self.n_steps_by_epi) + mean_r_pl1 = total_r_pl_1 / n_steps_playerd + mean_r_pl2 = total_r_pl_2 / n_steps_playerd + return [mean_r_pl1, mean_r_pl2] + + def step(self): + self._init_to_report(in_cell_eval=False) + for policy_id in self.policy_ids: + br_policy = self.train_one_br(policy_id) + self.meta_policies[policy_id]["policies"].append(br_policy) + self.compute_new_meta_policies() + + self.to_report["finished"] = ( + False + if self.training_iteration < (self.num_iterations - 1) + else True + ) + self._fill_to_report() + return self.to_report + + def _init_to_report(self, in_cell_eval, in_rllib_eval=False): + self.to_report = { + "training_iteration": self.training_iteration, + "in_cell_eval": in_cell_eval, + "in_rllib_eval": in_rllib_eval, + } + + def _fill_to_report(self): + # Add base policies actions + for player_id, meta_policy in self.meta_policies.items(): + self.to_report[f"{player_id}_meta_policy"] = meta_policy[ + "meta_policy" + ].probs + for pi_idx, base_policy in enumerate(meta_policy["policies"]): + actions = self._helper_compute_action( + base_policy, obs=self.init_obs[player_id] + ) + 
self.to_report[f"{player_id}_pi_idx_{pi_idx}_act"] = actions + + # Add base policies actions + if hasattr(self, "meta_game_payoff_matrix"): + self.to_report[ + f"meta_game_payoff_matrix_pl0" + ] = self.meta_game_payoff_matrix[..., 0] + self.to_report[ + f"meta_game_payoff_matrix_pl1" + ] = self.meta_game_payoff_matrix[..., 1] + + def compute_new_meta_policies(self): + + self._fill_new_meta_payoff_table() + + policy_player_1, policy_player_2 = _compute_policy_wt_alpha_rank( + [ + self.meta_game_payoff_matrix[..., 0], + self.meta_game_payoff_matrix[..., 1], + ] + ) + policy_player_1, policy_player_2 = self._clamp_policies_normalize( + [policy_player_1, policy_player_2] + ) + self.meta_policies[self.policy_ids[0]][ + "meta_policy" + ] = self._create_policy_sampler(policy_player_1) + self.meta_policies[self.policy_ids[1]][ + "meta_policy" + ] = self._create_policy_sampler(policy_player_2) + + def _clamp_policies_normalize(self, policies): + pi_clamped = [] + for pi in policies: + assert not (any(pi > 1.01) or any(pi < -0.01)), f"pi {pi}" + pi = pi / pi.sum() + pi = pi.clamp(min=0.0, max=1.0) + pi_clamped.append(pi) + return pi_clamped + + def _fill_new_meta_payoff_table(self): + prev_mat_shape = self.meta_game_payoff_matrix.shape + new_mat_shape = list(prev_mat_shape) + new_mat_shape[0] += 1 + new_mat_shape[1] += 1 + + new_payoff_mat = np.zeros(new_mat_shape) + new_payoff_mat[:-1, :-1, :] = self.meta_game_payoff_matrix + self.meta_game_payoff_matrix = new_payoff_mat + + # Fill last row + for pl_2_idx in range(new_payoff_mat.shape[1]): + pl_1_idx = new_payoff_mat.shape[0] - 1 + self._fill_meta_game_cell(pl_1_idx, pl_2_idx) + + # Fill last col + for pl_1_idx in range(new_payoff_mat.shape[0] - 1): + pl_2_idx = new_payoff_mat.shape[1] - 1 + self._fill_meta_game_cell(pl_1_idx, pl_2_idx) + + print("meta_game_payoff_matrix") + print("pl1", self.meta_game_payoff_matrix[..., 0]) + print("pl2", self.meta_game_payoff_matrix[..., 1]) + + def _fill_meta_game_cell(self, pl_1_idx, pl_2_idx): + policy_pl_1 = self.meta_policies[self.policy_ids[0]]["policies"][ + pl_1_idx + ] + policy_pl_2 = self.meta_policies[self.policy_ids[1]]["policies"][ + pl_2_idx + ] + self.meta_game_payoff_matrix[ + pl_1_idx, pl_2_idx, : + ] = self._compute_joint_payoffs(policy_pl_1, policy_pl_2) + + def train_one_br(self, policy_id): + policy = self._init_pg_policy() + self._get_base_players(policy, policy_id) + self.multi_agent_batch_builder = self._init_batch_builder() + self.to_report_base_game = {} + self._reset_total_welfare() + self.n_steps_in_batch = 0 + for i in range(self.train_oracle_n_epi): + self._play_one_episode() + self.multi_agent_batch_builder.postprocess_batch_so_far() + self.n_steps_in_batch += 1 + if self.n_steps_in_batch == self.batch_size: + self._optimize_weights(policy, policy_id) + self._get_base_players(policy, policy_id) + self.multi_agent_batch_builder = self._init_batch_builder() + if self.verbose: + print( + policy_id, + "self.total_welfare", + self.total_welfare[0] / self.batch_size, + self.total_welfare[1] / self.batch_size, + ) + self._reset_total_welfare() + if self.verbose: + print( + policy_id, + "self.to_report_base_game", + self.to_report_base_game, + ) + self.to_report_base_game = {} + self.n_steps_in_batch = 0 + + return policy + + def _reset_total_welfare(self): + self.total_welfare = [0.0] * len(self.policy_ids) + + def _init_batch_builder(self): + return MyMultiAgentSampleBatchBuilder( + policy_map={ + player_id: player for player_id, player in self.players.items() + }, + clip_rewards=False, + 
callbacks=DefaultCallbacks(), + ) + + def _init_pg_policy(self): + pg_policy = PGTorchPolicy( + self.env.OBSERVATION_SPACE, + self.env.ACTION_SPACE, + self.oracle_config, + ) + # for v in dir(pg_policy): + # if isinstance(v, torch.nn.Module): + # print(v) + return pg_policy + + def _get_base_players(self, policy, policy_id): + opp_policy = self._sample_opponent_policy(policy_id) + self.players = {} + for one_policy_id in self.policy_ids: + if one_policy_id == policy_id: + self.players[one_policy_id] = policy + else: + self.players[one_policy_id] = opp_policy + + def _sample_opponent_policy(self, policy_id): + self.players = {} + for opp_policy_id in self.policy_ids: + if opp_policy_id != policy_id: + return self._sample_base_policy(opp_policy_id) + + def _sample_base_policy(self, policy_id): + policy_idx = self.meta_policies[policy_id]["meta_policy"].sample() + self.to_report[f"base_pi_idx_{policy_id}"] = policy_idx + return self.meta_policies[policy_id]["policies"][policy_idx] + + def _play_one_episode(self): + obs_before_act = self.env.reset() + done = {"__all__": False} + while not done["__all__"]: + obs_after_act, actions, rewards, done = self._play_one_step( + obs_before_act + ) + self._add_step_in_batch_builder_buffer( + obs_before_act, actions, rewards, done + ) + obs_before_act = obs_after_act + + def _play_one_step(self, obs_before_act): + actions = { + player_id: self._helper_compute_action( + player_policy, obs_before_act[player_id] + ) + for player_id, player_policy in self.players.items() + } + obs_after_act, rewards, done, info = self.env.step(actions) + self.to_report_base_game.update(info) + + return obs_after_act, actions, rewards, done + + def _helper_compute_action(self, player_policy, obs): + return clip_action( + player_policy.compute_actions([obs])[0][0], + player_policy.action_space_struct, + ) + + def _add_step_in_batch_builder_buffer( + self, obs_before_act, actions, rewards, done + ): + for i, policy_id in enumerate(self.policy_ids): + self.total_welfare[i] += rewards[policy_id] + + step_player_values = { + "eps_id": self.training_iteration, + "obs": obs_before_act[policy_id], + "actions": actions[policy_id], + "rewards": rewards[policy_id], + "dones": done[policy_id], + } + self.multi_agent_batch_builder.add_values( + agent_id=policy_id, policy_id=policy_id, **step_player_values + ) + + def _optimize_weights(self, policy, policy_id): + multiagent_batch = self.multi_agent_batch_builder.build_and_reset() + if self.center_returns: + multiagent_batch = self._center_reward(multiagent_batch, policy_id) + stats = policy.learn_on_batch( + multiagent_batch.policy_batches[policy_id] + ) + + def _center_reward(self, multiagent_batch, player_id): + multiagent_batch.policy_batches[player_id]["rewards"] = ( + multiagent_batch.policy_batches[player_id]["rewards"] + - multiagent_batch.policy_batches[player_id]["rewards"].mean() + ) + return multiagent_batch + + def save_checkpoint(self, checkpoint_dir): + all_weights = self.get_weights() + checkpoint_dir = os.path.expanduser(checkpoint_dir) + checkpoint_path = os.path.join(checkpoint_dir, "psro_hardcoded_ckpt.p") + with open(checkpoint_path, "wb") as f: + pickle.dump(all_weights, f, protocol=pickle.HIGHEST_PROTOCOL) + return checkpoint_path + + def load_checkpoint(self, checkpoint_path): + with open(checkpoint_path, "rb") as f: + all_weights = pickle.load(f) + self.set_weights(all_weights) + + def get_weights(self): + save_object = { + k: {"meta_policy": v["meta_policy"]} + for k, v in self.meta_policies.items() + } + + for 
policy_id, meta_policy in save_object.items(): + weights = [] + for base_policy in self.meta_policies[policy_id]["policies"]: + if not isinstance(base_policy, RandomPolicy): + weights.append(base_policy.get_weights()) + else: + weights.append(None) + save_object[policy_id]["policies"] = weights + return save_object + + def set_weights(self, save_object): + meta_policies_save = copy.deepcopy(save_object) + for policy_id, meta_policy in self.meta_policies.items(): + for pi_idx, weights in enumerate( + meta_policies_save[policy_id]["policies"] + ): + if not len(meta_policy["policies"]) > pi_idx: + new_policy = self._init_pg_policy() + meta_policy["policies"].append(new_policy) + base_policy = meta_policy["policies"][pi_idx] + if not isinstance(base_policy, RandomPolicy): + base_policy.set_weights(weights) + self.meta_policies[policy_id]["meta_policy"] = meta_policies_save[ + policy_id + ]["meta_policy"] + + def cleanup(self): + super().cleanup() + + def compute_actions(self, policy_id: str, obs_batch: list): + assert len(obs_batch) == 1 + + for single_obs in obs_batch: + player_policy = self.base_policies[policy_id] + obs = self._preprocess_obs(single_obs) + a = self._helper_compute_action(player_policy, obs) + + action = self._post_process_action(a) + + state_out = [] + extra_fetches = {} + return action, state_out, extra_fetches + + def _preprocess_obs(self, single_obs): + # single_obs = single_obs[None, ...] # add batch dim + return np.array(single_obs) + + def _post_process_action(self, action): + return action[None, ...] # add batch dim + + def on_episode_start(self): + self._init_to_report(in_cell_eval=False, in_rllib_eval=True) + self.base_policies = {} + for policy_id in self.policy_ids: + self.base_policies[policy_id] = self._sample_base_policy(policy_id) diff --git a/marltoolbox/algos/sos.py b/marltoolbox/algos/sos.py new file mode 100644 index 0000000..be470f3 --- /dev/null +++ b/marltoolbox/algos/sos.py @@ -0,0 +1,528 @@ +###### +# Code modified from: +# https://github.com/julianstastny/openspiel-social-dilemmas +###### +import os +import random + +import torch +import numpy as np +from ray import tune + +from marltoolbox.envs.matrix_sequential_social_dilemma import ( + IteratedPrisonersDilemma, + IteratedAsymBoS, +) + + +class SOSTrainer(tune.Trainable): + def setup(self, config: dict): + + self.config = config + self.gamma = self.config.get("gamma") + self.learning_rate = self.config.get("lr") + self.method = self.config.get("method") + self.use_single_weights = False + + self._set_environment() + self.init_weigths(std=self.config.get("inital_weights_std")) + if self.use_single_weights: + self._exact_loss_matrix_game = ( + self._exact_loss_matrix_game_two_by_two_actions + ) + else: + self._exact_loss_matrix_game = self._exact_loss_matrix_game_generic + + def _set_environment(self): + + if self.config.get("env_name") == "IteratedPrisonersDilemma": + env_class = IteratedPrisonersDilemma + payoff_matrix = env_class.PAYOFF_MATRIX + self.players_ids = env_class({}).players_ids + elif self.config.get("env_name") == "IteratedAsymBoS": + env_class = IteratedAsymBoS + payoff_matrix = env_class.PAYOFF_MATRIX + self.players_ids = env_class({}).players_ids + elif "custom_payoff_matrix" in self.config.keys(): + payoff_matrix = self.config["custom_payoff_matrix"] + self.players_ids = IteratedPrisonersDilemma({}).players_ids + else: + raise NotImplementedError() + + self.n_actions_p1 = np.array(payoff_matrix).shape[0] + self.n_actions_p2 = np.array(payoff_matrix).shape[1] + 
self.n_non_init_states = self.n_actions_p1 * self.n_actions_p2 + + if self.use_single_weights: + self.dims = [ + (self.n_non_init_states + 1,), + (self.n_non_init_states + 1,), + ] + else: + self.dims = [ + (self.n_non_init_states + 1, self.n_actions_p1), + (self.n_non_init_states + 1, self.n_actions_p2), + ] + + self.payoff_matrix_player_row = torch.tensor( + payoff_matrix[:, :, 0] + ).float() + self.payoff_matrix_player_col = torch.tensor( + payoff_matrix[:, :, 1] + ).float() + + def step(self): + losses = self.update_th() + mean_reward_player_row = -losses[0] * (1 - self.gamma) + mean_reward_player_col = -losses[1] * (1 - self.gamma) + to_log = { + f"mean_reward_{self.players_ids[0]}": mean_reward_player_row, + f"mean_reward_{self.players_ids[1]}": mean_reward_player_col, + "episodes_total": self.training_iteration, + "policy1": self.policy_player1, + "policy2": self.policy_player2, + } + return to_log + + def _exact_loss_matrix_game_two_by_two_actions(self): + + self.policy_player1 = torch.sigmoid(self.weights_per_players[0]) + self.policy_player2 = torch.sigmoid(self.weights_per_players[1]) + + pi_player_row_init_state = torch.sigmoid( + self.weights_per_players[0][0:1] + ) + pi_player_col_init_state = torch.sigmoid( + self.weights_per_players[1][0:1] + ) + p = torch.cat( + [ + pi_player_row_init_state * pi_player_col_init_state, + pi_player_row_init_state * (1 - pi_player_col_init_state), + (1 - pi_player_row_init_state) * pi_player_col_init_state, + (1 - pi_player_row_init_state) + * (1 - pi_player_col_init_state), + ] + ) + pi_player_row_other_states = torch.reshape( + torch.sigmoid(self.weights_per_players[0][1:5]), (4, 1) + ) + pi_player_col_other_states = torch.reshape( + torch.sigmoid(self.weights_per_players[1][1:5]), (4, 1) + ) + P = torch.cat( + [ + pi_player_row_other_states * pi_player_col_other_states, + pi_player_row_other_states * (1 - pi_player_col_other_states), + (1 - pi_player_row_other_states) * pi_player_col_other_states, + (1 - pi_player_row_other_states) + * (1 - pi_player_col_other_states), + ], + dim=1, + ) + M = -torch.matmul(p, torch.inverse(torch.eye(4) - self.gamma * P)) + L_1 = torch.matmul( + M, torch.reshape(self.payoff_matrix_player_row, (4, 1)) + ) + L_2 = torch.matmul( + M, torch.reshape(self.payoff_matrix_player_col, (4, 1)) + ) + return [L_1, L_2] + + def _exact_loss_matrix_game_generic(self): + pi_player_row = torch.sigmoid(self.weights_per_players[0]) + pi_player_col = torch.sigmoid(self.weights_per_players[1]) + sum_1 = torch.sum(pi_player_row, dim=1) + sum_1 = torch.stack([sum_1 for _ in range(self.n_actions_p1)], dim=1) + sum_2 = torch.sum(pi_player_col, dim=1) + sum_2 = torch.stack([sum_2 for _ in range(self.n_actions_p2)], dim=1) + pi_player_row = pi_player_row / sum_1 + pi_player_col = pi_player_col / sum_2 + self.policy_player1 = pi_player_row + self.policy_player2 = pi_player_col + + pi_player_row_init_state = pi_player_row[:1, :] + pi_player_col_init_state = pi_player_col[:1, :] + all_initial_actions_proba_pairs = [] + for action_p1 in range(self.n_actions_p1): + for action_p2 in range(self.n_actions_p2): + all_initial_actions_proba_pairs.append( + pi_player_row_init_state[:, action_p1] + * pi_player_col_init_state[:, action_p2] + ) + p = torch.cat( + all_initial_actions_proba_pairs, + ) + + pi_player_row_other_states = pi_player_row[1:, :] + pi_player_col_other_states = pi_player_col[1:, :] + all_actions_proba_pairs = [] + for action_p1 in range(self.n_actions_p1): + for action_p2 in range(self.n_actions_p2): + 
all_actions_proba_pairs.append( + pi_player_row_other_states[:, action_p1] + * pi_player_col_other_states[:, action_p2] + ) + P = torch.stack( + all_actions_proba_pairs, + 1, + ) + + M = -torch.matmul( + p, + torch.inverse(torch.eye(self.n_non_init_states) - self.gamma * P), + ) + L_1 = torch.matmul( + M, + torch.reshape( + self.payoff_matrix_player_row, (self.n_non_init_states, 1) + ), + ) + L_2 = torch.matmul( + M, + torch.reshape( + self.payoff_matrix_player_col, (self.n_non_init_states, 1) + ), + ) + return [L_1, L_2] + + def init_weigths(self, std): + self.weights_per_players = [] + self.n_players = len(self.dims) + for i in range(self.n_players): + if std > 0: + init = torch.nn.init.normal_( + torch.empty(self.dims[i], requires_grad=True), std=std + ) + else: + init = torch.zeros(self.dims[i], requires_grad=True) + self.weights_per_players.append(init) + + def update_th( + self, + a=0.5, + b=0.1, + gam=1, + ep=0.1, + lss_lam=0.1, + ): + losses = self._exact_loss_matrix_game() + + grads = self._compute_gradients_wt_selected_method( + a, + b, + ep, + gam, + losses, + lss_lam, + ) + + self._update_weights(grads) + return losses + + def _compute_gradients_wt_selected_method( + self, a, b, ep, gam, losses, lss_lam + ): + n_players = self.n_players + + grad_L = _compute_vanilla_gradients( + losses, n_players, self.weights_per_players + ) + + if self.method == "la": + grads = _get_la_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "lola": + grads = _get_lola_exact_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "sos": + grads = _get_sos_gradients( + a, + self.learning_rate, + b, + grad_L, + n_players, + self.weights_per_players, + ) + elif self.method == "sga": + grads = _get_sga_gradients( + ep, grad_L, n_players, self.weights_per_players + ) + elif self.method == "co": + grads = _get_co_gradients( + gam, grad_L, n_players, self.weights_per_players + ) + elif self.method == "eg": + grads = _get_eg_gradients( + self._exact_loss_matrix_game, + self.learning_rate, + losses, + n_players, + self.weights_per_players, + ) + elif self.method == "cgd": # Slow implementation (matrix inversion) + grads = _get_cgd_gradients( + self.learning_rate, grad_L, n_players, self.weights_per_players + ) + elif self.method == "lss": # Slow implementation (matrix inversion) + grads = _get_lss_gradients( + grad_L, lss_lam, n_players, self.weights_per_players + ) + elif self.method == "naive": # Naive Learning + grads = _get_naive_learning_gradients(grad_L, n_players) + else: + raise ValueError(f"algo: {self.method}") + return grads + + def _update_weights(self, grads): + with torch.no_grad(): + for weight_i, weight_grad in enumerate(grads): + self.weights_per_players[weight_i] -= ( + self.learning_rate * weight_grad + ) + + def save_checkpoint(self, checkpoint_dir): + save_path = os.path.join(checkpoint_dir, "weights.pt") + torch.save(self.weights_per_players, save_path) + return save_path + + def load_checkpoint(self, checkpoint_path): + self.weights_per_players = torch.load(checkpoint_path) + + def _get_agent_to_use(self, policy_id): + if policy_id == "player_row": + agent_n = 0 + elif policy_id == "player_col": + agent_n = 1 + else: + raise ValueError(f"policy_id {policy_id}") + return agent_n + + def _preprocess_obs(self, single_obs, agent_to_use): + single_obs = np.where(single_obs == 1)[0][0] + # because idx 0 is linked to the initial state in the weights + # but this is the last idx which is linked 
to the + # initial state in the environment obs + if single_obs == len(self.weights_per_players[0]) - 1: + single_obs = 0 + else: + single_obs += 1 + return single_obs + + def _post_process_action(self, action): + return action[None, ...] # add batch dim + + def compute_actions(self, policy_id: str, obs_batch: list): + assert ( + len(obs_batch) == 1 + ), f"{len(obs_batch)} == 1. obs_batch: {obs_batch}" + + for single_obs in obs_batch: + agent_to_use = self._get_agent_to_use(policy_id) + obs = self._preprocess_obs(single_obs, agent_to_use) + policy = torch.sigmoid(self.weights_per_players[agent_to_use]) + + if self.use_single_weights: + coop_proba = policy[obs] + if coop_proba > random.random(): + action = np.array(0) + else: + action = np.array(1) + else: + probabilities = policy[obs, :] + probabilities = torch.tensor(probabilities) + policy_for_this_state = torch.distributions.Categorical( + probs=probabilities + ) + action = np.array(policy_for_this_state.sample()) + + action = self._post_process_action(action) + + state_out = [] + extra_fetches = {} + return action, state_out, extra_fetches + + +def _compute_vanilla_gradients(losses, n, th): + grad_L = [ + [get_gradient(losses[j], th[i]) for j in range(n)] for i in range(n) + ] + return grad_L + + +def _get_la_gradients(alpha, grad_L, n, th): + terms = [ + sum( + [ + # torch.dot(grad_L[j][i], grad_L[j][j].detach()) + torch.matmul( + grad_L[j][i], + torch.transpose(grad_L[j][j].detach(), 0, 1), + ) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n) + ] + return grads + + +def _get_lola_exact_gradients(alpha, grad_L, n, th): + terms = [ + sum( + # [torch.dot(grad_L[j][i], grad_L[j][j]) for j in range(n) if j != i] + [ + torch.matmul( + grad_L[j][i], + torch.transpose(grad_L[j][j].unsqueeze(dim=1), 0, 1), + ) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) for i in range(n) + ] + return grads + + +def _get_sos_gradients(a, alpha, b, grad_L, n, th): + xi_0 = _get_la_gradients(alpha, grad_L, n, th) + chi = [ + get_gradient( + sum( + [ + # torch.dot(grad_L[j][i].detach(), grad_L[j][j]) + torch.matmul( + grad_L[j][i].detach(), + torch.transpose(grad_L[j][j], 0, 1), + ) + for j in range(n) + if j != i + ] + ), + th[i], + ) + for i in range(n) + ] + # Compute p + # dot = torch.dot(-alpha * torch.cat(chi), torch.cat(xi_0)) + dot = torch.matmul( + -alpha * torch.cat(chi), + torch.transpose(torch.cat(xi_0), 0, 1), + ).sum() + p1 = 1 if dot >= 0 else min(1, -a * torch.norm(torch.cat(xi_0)) ** 2 / dot) + xi = torch.cat([grad_L[i][i] for i in range(n)]) + xi_norm = torch.norm(xi) + p2 = xi_norm ** 2 if xi_norm < b else 1 + p = min(p1, p2) + grads = [xi_0[i] - p * alpha * chi[i] for i in range(n)] + return grads + + +def _get_sga_gradients(ep, grad_L, n, th): + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + H_t_xi = [get_gradient(ham, th[i]) for i in range(n)] + H_xi = [ + get_gradient( + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + ] + ), + th[i], + ) + for i in range(n) + ] + A_t_xi = [H_t_xi[i] / 2 - H_xi[i] / 2 for i in range(n)] + # Compute lambda (sga with alignment) + dot_xi = torch.dot(xi, torch.cat(H_t_xi)) + dot_A = torch.dot(torch.cat(A_t_xi), torch.cat(H_t_xi)) + d = sum([len(th[i]) for i in range(n)]) + lam = torch.sign(dot_xi * dot_A / d + ep) + grads = [grad_L[i][i] + lam * A_t_xi[i] for 
i in range(n)] + return grads + + +def _get_co_gradients(gam, grad_L, n, th): + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + grads = [grad_L[i][i] + gam * get_gradient(ham, th[i]) for i in range(n)] + return grads + + +def _get_eg_gradients(Ls, alpha, losses, n, th): + th_eg = [th[i] - alpha * get_gradient(losses[i], th[i]) for i in range(n)] + losses_eg = Ls(th_eg) + grads = [get_gradient(losses_eg[i], th_eg[i]) for i in range(n)] + return grads + + +def _get_naive_learning_gradients(grad_L, n): + grads = [grad_L[i][i] for i in range(n)] + return grads + + +def _get_lss_gradients(grad_L, lss_lam, n, th): + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H = get_hessian(th, grad_L) + if torch.det(H) == 0: + inv = torch.inverse( + torch.matmul(H.T, H) + lss_lam * torch.eye(sum(dims)) + ) + H_inv = torch.matmul(inv, H.T) + else: + H_inv = torch.inverse(H) + grad = ( + torch.matmul(torch.eye(sum(dims)) + torch.matmul(H.T, H_inv), xi) / 2 + ) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + return grads + + +def _get_cgd_gradients(alpha, grad_L, n, th): + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H_o = get_hessian(th, grad_L, diag=False) + grad = torch.matmul(torch.inverse(torch.eye(sum(dims)) + alpha * H_o), xi) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + return grads + + +def get_gradient(function, param): + grad = torch.autograd.grad(torch.sum(function), param, create_graph=True)[ + 0 + ] + return grad + + +def get_hessian(th, grad_L, diag=True, off_diag=True): + n = len(th) + H = [] + for i in range(n): + row_block = [] + for j in range(n): + if (i == j and diag) or (i != j and off_diag): + block = [ + torch.unsqueeze( + get_gradient(grad_L[i][i][k], th[j]), + dim=0, + ) + for k in range(len(th[i])) + ] + row_block.append(torch.cat(block, dim=0)) + else: + row_block.append(torch.zeros(len(th[i]), len(th[j]))) + H.append(torch.cat(row_block, dim=1)) + return torch.cat(H, dim=0) diff --git a/marltoolbox/algos/stochastic_population.py b/marltoolbox/algos/stochastic_population.py new file mode 100644 index 0000000..6a76b12 --- /dev/null +++ b/marltoolbox/algos/stochastic_population.py @@ -0,0 +1,33 @@ +import torch +from ray.rllib import SampleBatch +from ray.rllib.utils import override + +from marltoolbox.algos.amTFT import base +from marltoolbox.algos.hierarchical import HierarchicalTorchPolicy + + +class StochasticPopulation(HierarchicalTorchPolicy, base.AmTFTReferenceClass): + def __init__(self, observation_space, action_space, config, **kwargs): + super().__init__(observation_space, action_space, config, **kwargs) + self.stochastic_population_policy = torch.distributions.Categorical( + probs=self.config["sampling_policy_distribution"] + ) + + def on_episode_start(self, *args, **kwargs): + self._select_algo_idx_to_use() + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_start"): + self.algorithms[self.active_algo_idx].on_episode_start( + *args, + **kwargs, + ) + + def _select_algo_idx_to_use(self): + policy_idx_selected = self.stochastic_population_policy.sample() + self.active_algo_idx = policy_idx_selected + self._to_log[ + "StochasticPopulation_active_algo_idx" + ] = self.active_algo_idx + + @override(HierarchicalTorchPolicy) + def _learn_on_batch(self, samples: SampleBatch): + return self.algorithms[self.active_algo_idx]._learn_on_batch(samples) diff --git 
a/marltoolbox/algos/welfare_coordination.py b/marltoolbox/algos/welfare_coordination.py index 559bc83..786a4e2 100644 --- a/marltoolbox/algos/welfare_coordination.py +++ b/marltoolbox/algos/welfare_coordination.py @@ -2,6 +2,7 @@ import logging import random +import torch import numpy as np from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import override @@ -27,6 +28,7 @@ "opp_player_idx": None, "own_default_welfare_fn": None, "opp_default_welfare_fn": None, + "distrib_over_welfare_sets_to_annonce": False, }, ) @@ -60,26 +62,33 @@ def setup_meta_game( self.opp_player_idx = opp_player_idx self.own_default_welfare_fn = own_default_welfare_fn self.opp_default_welfare_fn = opp_default_welfare_fn - self._list_all_welfares_fn() + self.all_welfare_fn = MetaGameSolver.list_all_welfares_fn( + self.all_welfare_pairs_wt_payoffs + ) self._list_all_set_of_welfare_fn() - def _list_all_welfares_fn(self): + @staticmethod + def list_all_welfares_fn(all_welfare_pairs_wt_payoffs): + """List all welfare functions in the given data structure""" all_welfare_fn = [] - for welfare_pairs_key in self.all_welfare_pairs_wt_payoffs.keys(): - for welfare_fn in self._key_to_pair_of_welfare_names( + for welfare_pairs_key in all_welfare_pairs_wt_payoffs.keys(): + for welfare_fn in MetaGameSolver.key_to_pair_of_welfare_names( welfare_pairs_key ): all_welfare_fn.append(welfare_fn) - self.all_welfare_fn = tuple(set(all_welfare_fn)) + return sorted(tuple(set(all_welfare_fn))) @staticmethod - def _key_to_pair_of_welfare_names(key): + def key_to_pair_of_welfare_names(key): + """Convert a key to the two welfare functions composing this key""" return key.split("-") def _list_all_set_of_welfare_fn(self): """ Each conbinaison is a potentiel action in the meta game """ + from ordered_set import OrderedSet + welfare_fn_sets = [] for n_items in range(1, len(self.all_welfare_fn) + 1): combinations_object = itertools.combinations( @@ -87,10 +96,12 @@ def _list_all_set_of_welfare_fn(self): ) combinations_object = list(combinations_object) combinations_set = [ - frozenset(combi) for combi in combinations_object + OrderedSet(combi) for combi in combinations_object ] - welfare_fn_sets.extend(combinations_set) - self.welfare_fn_sets = tuple(set(welfare_fn_sets)) + for el in combinations_set: + if el not in welfare_fn_sets: + welfare_fn_sets.append(el) + self.welfare_fn_sets = tuple(sorted(welfare_fn_sets, key=str)) def solve_meta_game(self, tau): """ @@ -98,7 +109,10 @@ def solve_meta_game(self, tau): :param tau: """ print("================================================") - print(f"Start solving meta game with tau={tau}") + print( + f"Player (idx={self.own_player_idx}) starts solving meta game " + f"with tau={tau}" + ) self.tau = tau self.selected_pure_policy_idx = None self.best_objective = -np.inf @@ -116,6 +130,7 @@ def solve_meta_game(self, tau): ] def _search_for_best_action(self): + self.tau_log = [] for idx, welfare_set_annonced in enumerate(self.welfare_fn_sets): optimization_objective = self._compute_optimization_objective( self.tau, welfare_set_annonced @@ -123,26 +138,40 @@ def _search_for_best_action(self): self._keep_action_if_best(optimization_objective, idx) def _compute_optimization_objective(self, tau, welfare_set_annonced): - print(f"========") - print(f"compute objective for set {welfare_set_annonced}") self._compute_payoffs_for_every_opponent_action(welfare_set_annonced) opp_best_response_idx = self._get_opp_best_response_idx() + + exploitability_term = tau * 
self._get_own_payoff(opp_best_response_idx) + payoffs = self._get_all_possible_payoffs_excluding_provided( + # excluding_idx=opp_best_response_idx + ) + robustness_term = (1 - tau) * sum(payoffs) / len(payoffs) + + optimization_objective = exploitability_term + robustness_term + + print(f"========") + print(f"compute objective for set {welfare_set_annonced}") print( f"opponent best response is {opp_best_response_idx} => " f"{self.welfare_fn_sets[opp_best_response_idx]}" ) - exploitability_term = tau * self._get_own_payoff(opp_best_response_idx) print("exploitability_term", exploitability_term) + print("robustness_term", robustness_term) + print("optimization_objective", optimization_objective) - payoffs = self._get_all_possible_payoffs_excluding_one( - excluding_idx=opp_best_response_idx + self.tau_log.append( + { + "welfare_set_annonced": welfare_set_annonced, + "opp_best_response_idx": opp_best_response_idx, + "opp_best_response_set": self.welfare_fn_sets[ + opp_best_response_idx + ], + "robustness_term": robustness_term, + "optimization_objective": optimization_objective, + } ) - robustness_term = (1 - tau) * sum(payoffs) / len(payoffs) - print("robustness_term", robustness_term) - optimization_objective = exploitability_term + robustness_term - print("optimization_objective", optimization_objective) return optimization_objective def _compute_payoffs_for_every_opponent_action(self, welfare_set_annonced): @@ -158,18 +187,18 @@ def _compute_meta_payoff(self, own_welfare_set, opp_welfare_set): welfare_fn_intersection = own_welfare_set & opp_welfare_set if len(welfare_fn_intersection) == 0: - return self._get_own_payoff_for_default_strategies() + return self._read_payoffs_for_default_strategies() else: - return self._get_own_payoff_averaging_over_welfare_intersection( + return self._read_payoffs_and_average_over_welfare_intersection( welfare_fn_intersection ) - def _get_own_payoff_for_default_strategies(self): + def _read_payoffs_for_default_strategies(self): return self._read_own_payoff_from_data( self.own_default_welfare_fn, self.opp_default_welfare_fn ) - def _get_own_payoff_averaging_over_welfare_intersection( + def _read_payoffs_and_average_over_welfare_intersection( self, welfare_fn_intersection ): payoffs_player_1 = [] @@ -182,16 +211,17 @@ def _get_own_payoff_averaging_over_welfare_intersection( payoffs_player_2.append(payoff_player_2) mean_payoff_p1 = sum(payoffs_player_1) / len(payoffs_player_1) mean_payoff_p2 = sum(payoffs_player_2) / len(payoffs_player_2) - return (mean_payoff_p1, mean_payoff_p2) + return mean_payoff_p1, mean_payoff_p2 def _read_own_payoff_from_data(self, own_welfare, opp_welfare): - welfare_pair_name = self._from_pair_of_welfare_names_to_key( + welfare_pair_name = self.from_pair_of_welfare_names_to_key( own_welfare, opp_welfare ) return self.all_welfare_pairs_wt_payoffs[welfare_pair_name] @staticmethod - def _from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set): + def from_pair_of_welfare_names_to_key(own_welfare_set, opp_welfare_set): + """Convert two welfare functions into a key identifying this pair""" return f"{own_welfare_set}-{opp_welfare_set}" def _get_opp_best_response_idx(self): @@ -201,13 +231,15 @@ def _get_opp_best_response_idx(self): ] return opp_payoffs.index(max(opp_payoffs)) - def _get_all_possible_payoffs_excluding_one(self, excluding_idx): + def _get_all_possible_payoffs_excluding_provided(self, excluding_idx=()): own_payoffs = [] for welfare_set_idx, _ in enumerate(self.welfare_fn_sets): - if welfare_set_idx != excluding_idx: + 
if welfare_set_idx not in excluding_idx: own_payoff = self._get_own_payoff(welfare_set_idx) own_payoffs.append(own_payoff) - assert len(own_payoffs) == len(self.welfare_fn_sets) - 1 + assert len(own_payoffs) == len(self.welfare_fn_sets) - len( + excluding_idx + ) return own_payoffs def _get_own_payoff(self, idx): @@ -253,14 +285,47 @@ def __init__( **kwargs, ) + assert ( + self._use_distribution_over_sets() + or self.config["solve_meta_game_after_init"] + ) if self.config["solve_meta_game_after_init"]: + assert not self.config["distrib_over_welfare_sets_to_annonce"] self._choose_which_welfare_set_to_annonce() + if self._use_distribution_over_sets(): + self._init_distribution_over_sets_to_announce() self._welfare_set_annonced = False self._welfare_set_in_use = None self._welfare_chosen_for_epi = False self._intersection_of_welfare_sets = None + def _init_distribution_over_sets_to_announce(self): + assert not self.config["solve_meta_game_after_init"] + self.setup_meta_game( + all_welfare_pairs_wt_payoffs=self.config[ + "all_welfare_pairs_wt_payoffs" + ], + own_player_idx=self.config["own_player_idx"], + opp_player_idx=self.config["opp_player_idx"], + own_default_welfare_fn=self.config["own_default_welfare_fn"], + opp_default_welfare_fn=self.config["opp_default_welfare_fn"], + ) + self.stochastic_announcement_policy = torch.distributions.Categorical( + probs=self.config["distrib_over_welfare_sets_to_annonce"] + ) + self._stochastic_selection_of_welfare_set_to_announce() + + def _stochastic_selection_of_welfare_set_to_announce(self): + welfare_set_idx_selected = self.stochastic_announcement_policy.sample() + self.welfare_set_to_annonce = self.welfare_fn_sets[ + welfare_set_idx_selected + ] + self._to_log["welfare_set_to_annonce"] = str( + self.welfare_set_to_annonce + ) + print("self.welfare_set_to_annonce", self.welfare_set_to_annonce) + def _choose_which_welfare_set_to_annonce(self): self.setup_meta_game( all_welfare_pairs_wt_payoffs=self.config[ @@ -272,6 +337,9 @@ def _choose_which_welfare_set_to_annonce(self): opp_default_welfare_fn=self.config["opp_default_welfare_fn"], ) self.solve_meta_game(self.config["tau"]) + self._to_log[f"tau_{self.tau}"] = self.tau_log[ + self.selected_pure_policy_idx + ] @property def policy_checkpoints(self): @@ -282,9 +350,9 @@ def policy_checkpoints(self): @policy_checkpoints.setter def policy_checkpoints(self, value): - msg = f"ignoring set self.policy_checkpoints to value {value}" - print(msg) - logger.warning(msg) + logger.warning( + f"ignoring setting self.policy_checkpoints to value {value}" + ) @override(population.PopulationOfIdenticalAlgo) def set_algo_to_use(self): @@ -302,20 +370,28 @@ def on_episode_start( policy, policy_id, policy_ids, + episode, **kwargs, ): if not self._welfare_set_annonced: # Called only by one agent, not by both - self._annonce_welfare_sets(worker) + self._annonce_welfare_sets(worker, episode) if not self._welfare_chosen_for_epi: # Called only by one agent, not by both self._coordinate_on_welfare_to_use_for_epi(worker) - self.set_algo_to_use() + super().on_episode_start( + *args, + worker, + policy, + policy_id, + policy_ids, + **kwargs, + ) + assert self._welfare_set_annonced + assert self._welfare_chosen_for_epi @staticmethod - def _annonce_welfare_sets( - worker, - ): + def _annonce_welfare_sets(worker, episode): intersection_of_welfare_sets = ( WelfareCoordinationTorchPolicy._find_intersection_of_welfare_sets( worker @@ -328,6 +404,18 @@ def _annonce_welfare_sets( policy, intersection_of_welfare_sets ) + 
WelfareCoordinationTorchPolicy._add_coordination_info_into_custom_metrics( + intersection_of_welfare_sets, episode + ) + + @staticmethod + def _add_coordination_info_into_custom_metrics( + intersection_of_welfare_sets, episode + ): + episode.custom_metrics[f"coordination_success"] = float( + len(intersection_of_welfare_sets) > 0 + ) + @staticmethod def _find_intersection_of_welfare_sets(worker): welfare_sets_annonced = [ @@ -357,19 +445,35 @@ def _coordinate_on_welfare_to_use_for_epi(worker): if isinstance(policy, WelfareCoordinationTorchPolicy): if len(policy._intersection_of_welfare_sets) > 0: if welfare_to_play is None: - print( - "policy._welfare_set_in_use", - policy._welfare_set_in_use, - ) welfare_to_play = random.choice( tuple(policy._welfare_set_in_use) ) policy._welfare_in_use = welfare_to_play + msg = ( + "Policies coordinated " + "to play for the next epi: " + f"_welfare_set_in_use={policy._welfare_set_in_use}" + f"and selected: policy._welfare_in_use={policy._welfare_in_use}" + ) + logger.info(msg) + print(msg) else: policy._welfare_in_use = random.choice( tuple(policy._welfare_set_in_use) ) + msg = ( + "Policies did NOT coordinate " + "to play for the next epi: " + f"_welfare_set_in_use={policy._welfare_set_in_use}" + f"and selected: policy._welfare_in_use={policy._welfare_in_use}" + ) + logger.info(msg) + print(msg) policy._welfare_chosen_for_epi = True + policy._to_log["welfare_in_use"] = policy._welfare_in_use + policy._to_log[ + "welfare_set_in_use" + ] = policy._welfare_set_in_use def on_episode_end( self, @@ -377,4 +481,36 @@ def on_episode_end( **kwargs, ): self._welfare_chosen_for_epi = False - self.algorithms[self.active_algo_idx].on_episode_end(*args, **kwargs) + if hasattr(self.algorithms[self.active_algo_idx], "on_episode_end"): + self.algorithms[self.active_algo_idx].on_episode_end( + *args, **kwargs + ) + if self._use_distribution_over_sets(): + self._stochastic_selection_of_welfare_set_to_announce() + self._welfare_set_annonced = False + + def _use_distribution_over_sets(self): + return not self.config["distrib_over_welfare_sets_to_annonce"] is False + + @property + @override(population.PopulationOfIdenticalAlgo) + def to_log(self): + to_log = { + "meta_policy": self._to_log, + "nested_policy": { + f"policy_{algo_idx}": algo.to_log + for algo_idx, algo in enumerate(self.algorithms) + if hasattr(algo, "to_log") + }, + } + return to_log + + @to_log.setter + @override(population.PopulationOfIdenticalAlgo) + def to_log(self, value): + if value == {}: + for algo in self.algorithms: + if hasattr(algo, "to_log"): + algo.to_log = {} + + self._to_log = value diff --git a/marltoolbox/envs/coin_game.py b/marltoolbox/envs/coin_game.py index 8c56f8f..4561c55 100644 --- a/marltoolbox/envs/coin_game.py +++ b/marltoolbox/envs/coin_game.py @@ -14,17 +14,22 @@ logger = logging.getLogger(__name__) -PLOT_KEYS = ["pick_speed", - "pick_own_color", - ] +PLOT_KEYS = [ + "pick_speed", + "pick_own_color", +] PLOT_ASSEMBLAGE_TAGS = [ ("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean"), ("pick_speed_player_red_mean", "pick_speed_player_blue_mean"), ("pick_own_color",), ("pick_speed", "pick_own_color"), - ("pick_own_color_player_red_mean", "pick_own_color_player_blue_mean", - "pick_speed_player_red_mean", "pick_speed_player_blue_mean"), + ( + "pick_own_color_player_red_mean", + "pick_own_color_player_blue_mean", + "pick_speed_player_red_mean", + "pick_speed_player_blue_mean", + ), ] @@ -32,6 +37,7 @@ class CoinGame(InfoAccumulationInterface, MultiAgentEnv, gym.Env): """ 
Coin Game environment. """ + NAME = "CoinGame" NUM_AGENTS = 2 NUM_ACTIONS = 4 @@ -55,7 +61,7 @@ def __init__(self, config: Dict = {}): low=0, high=1, shape=(self.grid_size, self.grid_size, 4), - dtype="uint8" + dtype="uint8", ) self.step_count_in_current_episode = None @@ -69,17 +75,22 @@ def _validate_config(self, config): assert len(config["players_ids"]) == self.NUM_AGENTS def _load_config(self, config): - self.players_ids = \ - config.get("players_ids", ["player_red", "player_blue"]) + self.players_ids = config.get( + "players_ids", ["player_red", "player_blue"] + ) self.max_steps = config.get("max_steps", 20) self.grid_size = config.get("grid_size", 3) - self.output_additional_info = config.get("output_additional_info", - True) + self.output_additional_info = config.get( + "output_additional_info", True + ) self.asymmetric = config.get("asymmetric", False) - self.both_players_can_pick_the_same_coin = \ - config.get("both_players_can_pick_the_same_coin", True) - self.same_obs_for_each_player = \ - config.get("same_obs_for_each_player", True) + self.both_players_can_pick_the_same_coin = config.get( + "both_players_can_pick_the_same_coin", True + ) + self.same_obs_for_each_player = config.get( + "same_obs_for_each_player", True + ) + self.punishment_helped = config.get("punishment_helped", False) @override(gym.Env) def seed(self, seed=None): @@ -89,6 +100,7 @@ def seed(self, seed=None): @override(gym.Env) def reset(self): + # print("reset") self.step_count_in_current_episode = 0 if self.output_additional_info: @@ -98,18 +110,17 @@ def reset(self): self._generate_coin() obs = self._generate_observation() - return { - self.player_red_id: obs[0], - self.player_blue_id: obs[1] - } + return {self.player_red_id: obs[0], self.player_blue_id: obs[1]} def _randomize_color_and_player_positions(self): # Reset coin color and the players and coin positions self.red_coin = self.np_random.randint(low=0, high=2) - self.red_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) - self.blue_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) + self.red_pos = self.np_random.randint( + low=0, high=self.grid_size, size=(2,) + ) + self.blue_pos = self.np_random.randint( + low=0, high=self.grid_size, size=(2,) + ) # self.coin_pos = np.zeros(shape=(2,), dtype=np.int8) self._players_do_not_overlap_at_start() @@ -165,51 +176,62 @@ def _same_pos(self, x, y): return (x == y).all() def _move_players(self, actions): - self.red_pos = \ - (self.red_pos + self.MOVES[actions[0]]) % self.grid_size - self.blue_pos = \ - (self.blue_pos + self.MOVES[actions[1]]) % self.grid_size + self.red_pos = (self.red_pos + self.MOVES[actions[0]]) % self.grid_size + self.blue_pos = ( + self.blue_pos + self.MOVES[actions[1]] + ) % self.grid_size def _compute_reward(self): reward_red = 0.0 reward_blue = 0.0 generate_new_coin = False - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - False, False, False, False + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = ( + False, + False, + False, + False, + ) red_first_if_both = None if not self.both_players_can_pick_the_same_coin: - if self._same_pos(self.red_pos, self.coin_pos) and \ - self._same_pos(self.blue_pos, self.coin_pos): + if self._same_pos(self.red_pos, self.coin_pos) and self._same_pos( + self.blue_pos, self.coin_pos + ): red_first_if_both = bool(self.np_random.randint(low=0, high=2)) if self.red_coin: - if self._same_pos(self.red_pos, self.coin_pos) and \ - (red_first_if_both is None or red_first_if_both): + if 
self._same_pos(self.red_pos, self.coin_pos) and ( + red_first_if_both is None or red_first_if_both + ): generate_new_coin = True reward_red += 1 if self.asymmetric: reward_red += 3 red_pick_any = True red_pick_red = True - if self._same_pos(self.blue_pos, self.coin_pos) and \ - (red_first_if_both is None or not red_first_if_both): + if self._same_pos(self.blue_pos, self.coin_pos) and ( + red_first_if_both is None or not red_first_if_both + ): generate_new_coin = True reward_red += -2 reward_blue += 1 blue_pick_any = True + if self.asymmetric and self.punishment_helped: + reward_red -= 6 else: - if self._same_pos(self.red_pos, self.coin_pos) and \ - (red_first_if_both is None or red_first_if_both): + if self._same_pos(self.red_pos, self.coin_pos) and ( + red_first_if_both is None or red_first_if_both + ): generate_new_coin = True reward_red += 1 reward_blue += -2 if self.asymmetric: reward_red += 3 red_pick_any = True - if self._same_pos(self.blue_pos, self.coin_pos) and \ - (red_first_if_both is None or not red_first_if_both): + if self._same_pos(self.blue_pos, self.coin_pos) and ( + red_first_if_both is None or not red_first_if_both + ): generate_new_coin = True reward_blue += 1 blue_pick_blue = True @@ -217,10 +239,12 @@ def _compute_reward(self): reward_list = [reward_red, reward_blue] if self.output_additional_info: - self._accumulate_info(red_pick_any=red_pick_any, - red_pick_red=red_pick_red, - blue_pick_any=blue_pick_any, - blue_pick_blue=blue_pick_blue) + self._accumulate_info( + red_pick_any=red_pick_any, + red_pick_red=red_pick_red, + blue_pick_any=blue_pick_any, + blue_pick_blue=blue_pick_blue, + ) return reward_list, generate_new_coin @@ -260,11 +284,12 @@ def _to_RLLib_API(self, observations, rewards): self.player_blue_id: rewards[1], } - epi_is_done = (self.step_count_in_current_episode >= self.max_steps) + epi_is_done = self.step_count_in_current_episode >= self.max_steps if self.step_count_in_current_episode > self.max_steps: logger.warning( "step_count_in_current_episode > self.max_steps: " - f"{self.step_count_in_current_episode} > {self.max_steps}") + f"{self.step_count_in_current_episode} > {self.max_steps}" + ) done = { self.player_red_id: epi_is_done, @@ -283,7 +308,7 @@ def _to_RLLib_API(self, observations, rewards): return obs, rewards, done, info @override(InfoAccumulationInterface) - def _get_episode_info(self): + def _get_episode_info(self, n_steps_played=None): """ Output the following information: pick_speed is the fraction of steps during which the player picked a @@ -292,20 +317,25 @@ def _get_episode_info(self): the same color as the player. 
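For example, over a 20-step episode in which the red player picked a coin on 5 steps, 4 of those coins being red, the red player's info is pick_speed == 5 / 20 == 0.25 and pick_own_color == 4 / 5 == 0.8.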
""" player_red_info, player_blue_info = {}, {} + if n_steps_played is None: + n_steps_played = len(self.red_pick) + assert len(self.red_pick) == len(self.blue_pick) if len(self.red_pick) > 0: red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / len(self.red_pick) + player_red_info["pick_speed"] = red_pick / n_steps_played if red_pick > 0: - player_red_info["pick_own_color"] = \ + player_red_info["pick_own_color"] = ( sum(self.red_pick_own) / red_pick + ) if len(self.blue_pick) > 0: blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / len(self.blue_pick) + player_blue_info["pick_speed"] = blue_pick / n_steps_played if blue_pick > 0: - player_blue_info["pick_own_color"] = \ + player_blue_info["pick_own_color"] = ( sum(self.blue_pick_own) / blue_pick + ) return player_red_info, player_blue_info @@ -318,7 +348,8 @@ def _reset_info(self): @override(InfoAccumulationInterface) def _accumulate_info( - self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue): + self, red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ): self.red_pick.append(red_pick_any) self.red_pick_own.append(red_pick_red) diff --git a/marltoolbox/envs/matrix_sequential_social_dilemma.py b/marltoolbox/envs/matrix_sequential_social_dilemma.py index 52fd4f3..c2be273 100644 --- a/marltoolbox/envs/matrix_sequential_social_dilemma.py +++ b/marltoolbox/envs/matrix_sequential_social_dilemma.py @@ -43,7 +43,7 @@ class MatrixSequentialSocialDilemma( """ A multi-agent abstract class for two player matrix games. - PAYOUT_MATRIX: Numpy array. Along the dimension N, the action of the + PAYOFF_MATRIX: Numpy array. Along the dimension N, the action of the Nth player change. The last dimension is used to select the player whose reward you want to know. @@ -56,10 +56,21 @@ class MatrixSequentialSocialDilemma( episode. """ + NUM_AGENTS = 2 + NUM_ACTIONS = None + NUM_STATES = None + ACTION_SPACE = None + OBSERVATION_SPACE = None + PAYOFF_MATRIX = None + NAME = None + def __init__(self, config: Dict = {}): - assert "reward_randomness" not in config.keys() - assert self.PAYOUT_MATRIX is not None + assert self.PAYOFF_MATRIX is not None + assert self.PAYOFF_MATRIX.shape[0] == self.NUM_ACTIONS + assert self.PAYOFF_MATRIX.shape[1] == self.NUM_ACTIONS + assert self.PAYOFF_MATRIX.shape[2] == self.NUM_AGENTS + assert len(self.PAYOFF_MATRIX.shape) == 3 if "players_ids" in config: assert ( isinstance(config["players_ids"], Iterable) @@ -85,7 +96,7 @@ def __init__(self, config: Dict = {}): self._init_info() def seed(self, seed=None): - """Seed the PRNG of this space. """ + """Seed the PRNG of this space.""" self.np_random, seed = seeding.np_random(seed) return [seed] @@ -156,8 +167,8 @@ def _produce_observations_invariant_to_the_player_trained( def _get_players_rewards(self, action_player_0: int, action_player_1: int): return [ - self.PAYOUT_MATRIX[action_player_0][action_player_1][0], - self.PAYOUT_MATRIX[action_player_0][action_player_1][1], + self.PAYOFF_MATRIX[action_player_0][action_player_1][0], + self.PAYOFF_MATRIX[action_player_0][action_player_1][1], ] def _to_RLLib_API( @@ -205,12 +216,11 @@ class IteratedMatchingPennies( A two-agent environment for the Matching Pennies game. 
""" - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]]) + PAYOFF_MATRIX = np.array([[[+1, -1], [-1, +1]], [[-1, +1], [+1, -1]]]) NAME = "IMP" @@ -221,12 +231,11 @@ class IteratedPrisonersDilemma( A two-agent environment for the Prisoner's Dilemma game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) + PAYOFF_MATRIX = np.array([[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) NAME = "IPD" @@ -237,12 +246,11 @@ class IteratedAsymPrisonersDilemma( A two-agent environment for the Asymmetric Prisoner's Dilemma game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) + PAYOFF_MATRIX = np.array([[[+0, -1], [-3, +0]], [[+0, -3], [-2, -2]]]) NAME = "IPD" @@ -253,12 +261,11 @@ class IteratedStagHunt( A two-agent environment for the Stag Hunt game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]]) + PAYOFF_MATRIX = np.array([[[3, 3], [0, 2]], [[2, 0], [1, 1]]]) NAME = "IteratedStagHunt" @@ -269,12 +276,11 @@ class IteratedChicken( A two-agent environment for the Chicken game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+0, +0], [-1.0, +1.0]], [[+1, -1], [-10, -10]]] ) NAME = "IteratedChicken" @@ -287,12 +293,11 @@ class IteratedAsymChicken( A two-agent environment for the Asymmetric Chicken game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+2.0, +0], [-1.0, +1.0]], [[+2.5, -1], [-10, -10]]] ) NAME = "AsymmetricIteratedChicken" @@ -305,12 +310,11 @@ class IteratedBoS( A two-agent environment for the BoS game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+3.0, +2.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +3.0]]] ) NAME = "IteratedBoS" @@ -323,31 +327,31 @@ class IteratedAsymBoS( A two-agent environment for the BoS game. 
""" - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]] ) - NAME = "AsymmetricIteratedBoS" + NAME = "IteratedAsymBoS" def define_greed_fear_matrix_game(greed, fear): class GreedFearGame( TwoPlayersTwoActionsInfoMixin, MatrixSequentialSocialDilemma ): - NUM_AGENTS = 2 NUM_ACTIONS = 2 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = ( + NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 + ) ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) R = 3 P = 1 T = R + greed S = P - fear - PAYOUT_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]]) + PAYOFF_MATRIX = np.array([[[R, R], [S, T]], [[T, S], [P, P]]]) NAME = "IteratedGreedFear" def __str__(self): @@ -363,12 +367,11 @@ class IteratedBoSAndPD( A two-agent environment for the BOTS + PD game. """ - NUM_AGENTS = 2 NUM_ACTIONS = 3 - NUM_STATES = NUM_ACTIONS ** NUM_AGENTS + 1 + NUM_STATES = NUM_ACTIONS ** MatrixSequentialSocialDilemma.NUM_AGENTS + 1 ACTION_SPACE = Discrete(NUM_ACTIONS) OBSERVATION_SPACE = Discrete(NUM_STATES) - PAYOUT_MATRIX = np.array( + PAYOFF_MATRIX = np.array( [ [[3.5, +1], [+0, +0], [-3, +2]], [[+0.0, +0], [+1, +3], [-3, +2]], @@ -376,3 +379,25 @@ class IteratedBoSAndPD( ] ) NAME = "IteratedBoSAndPD" + + +class TwoPlayersCustomizableMatrixGame( + NPlayersNDiscreteActionsInfoMixin, MatrixSequentialSocialDilemma +): + + NAME = "TwoPlayersCustomizableMatrixGame" + + NUM_ACTIONS = None + NUM_STATES = None + ACTION_SPACE = None + OBSERVATION_SPACE = None + PAYOFF_MATRIX = None + + def __init__(self, config: Dict): + self.PAYOFF_MATRIX = config["PAYOFF_MATRIX"] + self.NUM_ACTIONS = config["NUM_ACTIONS"] + self.ACTION_SPACE = Discrete(self.NUM_ACTIONS) + self.NUM_STATES = self.NUM_ACTIONS ** self.NUM_AGENTS + 1 + self.OBSERVATION_SPACE = Discrete(self.NUM_STATES) + + super().__init__(config) diff --git a/marltoolbox/envs/mixed_motive_coin_game.py b/marltoolbox/envs/mixed_motive_coin_game.py deleted file mode 100644 index bfc6fcc..0000000 --- a/marltoolbox/envs/mixed_motive_coin_game.py +++ /dev/null @@ -1,111 +0,0 @@ -import logging - -import numpy as np -from ray.rllib.utils import override - -from marltoolbox.envs import coin_game - -logger = logging.getLogger(__name__) - -PLOT_KEYS = coin_game.PLOT_KEYS - -PLOT_ASSEMBLAGE_TAGS = coin_game.PLOT_ASSEMBLAGE_TAGS - - -class MixedMotiveCoinGame(coin_game.CoinGame): - - @override(coin_game.CoinGame) - def _load_config(self, config): - super()._load_config(config) - assert self.both_players_can_pick_the_same_coin, \ - "both_players_can_pick_the_same_coin option must be True in " \ - "mixed motive coin game." 
- - @override(coin_game.CoinGame) - def _randomize_color_and_player_positions(self): - # Reset coin color and the players and coin positions - self.red_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) - self.blue_pos = \ - self.np_random.randint(low=0, high=self.grid_size, size=(2,)) - self.red_coin_pos = np.zeros(shape=(2,), dtype=np.int8) - self.blue_coin_pos = np.zeros(shape=(2,), dtype=np.int8) - - self._players_do_not_overlap_at_start() - - @override(coin_game.CoinGame) - def _generate_coin(self, color_to_generate="both"): - self._wt_coin_pos_different_from_players_and_other_coin( - color_to_generate) - - def _wt_coin_pos_different_from_players_and_other_coin( - self, color_to_generate): - - if color_to_generate == "both" or color_to_generate == "red": - success = 0 - while success < self.NUM_AGENTS + 1: - self.red_coin_pos = \ - self.np_random.randint(self.grid_size, size=2) - success = 1 - self._same_pos( - self.red_pos, self.red_coin_pos) - success += 1 - self._same_pos( - self.blue_pos, self.red_coin_pos) - success += 1 - self._same_pos( - self.blue_coin_pos, self.red_coin_pos) - if color_to_generate == "both" or color_to_generate == "blue": - success = 0 - while success < self.NUM_AGENTS + 1: - self.blue_coin_pos = \ - self.np_random.randint(self.grid_size, size=2) - success = 1 - self._same_pos( - self.red_pos, self.blue_coin_pos) - success += 1 - self._same_pos( - self.blue_pos, self.blue_coin_pos) - success += 1 - self._same_pos( - self.blue_coin_pos, self.red_coin_pos) - - @override(coin_game.CoinGame) - def _generate_observation(self): - obs = np.zeros((self.grid_size, self.grid_size, 4)) - obs[self.red_pos[0], self.red_pos[1], 0] = 1 - obs[self.blue_pos[0], self.blue_pos[1], 1] = 1 - obs[self.red_coin_pos[0], self.red_coin_pos[1], 2] = 1 - obs[self.blue_coin_pos[0], self.blue_coin_pos[1], 3] = 1 - - obs = self._apply_optional_invariance_to_the_player_trained(obs) - return obs - - @override(coin_game.CoinGame) - def _compute_reward(self): - - reward_red = 0.0 - reward_blue = 0.0 - generate_new_coin = False - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - False, False, False, False - - if self._same_pos(self.red_pos, self.red_coin_pos) and \ - self._same_pos(self.blue_pos, self.red_coin_pos): - generate_new_coin = "red" - reward_red += 2 - reward_blue += 2 - red_pick_any = True - red_pick_red = True - blue_pick_any = True - elif self._same_pos(self.red_pos, self.blue_coin_pos) and \ - self._same_pos(self.blue_pos, self.blue_coin_pos): - generate_new_coin = "blue" - reward_red += 1 - reward_blue += 4 - red_pick_any = True - blue_pick_any = True - blue_pick_blue = True - - reward_list = [reward_red, reward_blue] - if self.output_additional_info: - self._accumulate_info(red_pick_any=red_pick_any, - red_pick_red=red_pick_red, - blue_pick_any=blue_pick_any, - blue_pick_blue=blue_pick_blue) - - return reward_list, generate_new_coin diff --git a/marltoolbox/envs/simple_bargaining.py b/marltoolbox/envs/simple_bargaining.py new file mode 100644 index 0000000..23f445f --- /dev/null +++ b/marltoolbox/envs/simple_bargaining.py @@ -0,0 +1,298 @@ +########## +# Part of the code modified from: +# https://github.com/julianstastny/openspiel-social-dilemmas/blob/jesse-br/lola_bots_ipd.py +########## + +import logging +from abc import ABC +from collections import Iterable +from typing import Dict + +import numpy as np +from gym.spaces import Discrete, Box, MultiDiscrete +from gym.utils import seeding +from ray.rllib.env.multi_agent_env import 
MultiAgentEnv + +from marltoolbox.envs.utils.interfaces import InfoAccumulationInterface +from marltoolbox.envs.utils.mixins import NPlayersNContinuousActionsInfoMixin + +logger = logging.getLogger(__name__) + +PLOT_KEYS = [ + "player0_", + "player1_", + "_mean", + "_std", +] + +PLOT_ASSEMBLAGE_TAGS = [ + ("player0_",), + ("player1_",), + ("player0_", "player1_"), + ("_mean",), + ("_std",), +] + + +class SimpleBargaining( + NPlayersNContinuousActionsInfoMixin, + InfoAccumulationInterface, + MultiAgentEnv, + ABC, +): + NUM_AGENTS = 2 + NUM_ACTIONS = 2 + ACTION_SPACE = Box( + low=0.0, + high=1.0, + shape=(NUM_ACTIONS,), + dtype="float32", + ) + OBSERVATION_SPACE = Box( + low=0.0, + high=1.0, + shape=(NUM_AGENTS, NUM_ACTIONS), + dtype="float32", + ) + NAME = "SimpleBargaining" + INIT_STATE_VALUE = np.ones(shape=OBSERVATION_SPACE.shape) * 0.0 + G = GAINS_FROM_TRADE_FACTOR = 3.0 + MULTIPLIER = 0.2 + PL0_T0, PL0_T1, PL1_T0, PL1_T1 = np.array([3, 9, 7, 2]) * MULTIPLIER + + def __init__(self, config: Dict = {}): + # logger.warning("ENV NOT DEBBUGED, NOT TESTED") + + if "players_ids" in config: + assert ( + isinstance(config["players_ids"], Iterable) + and len(config["players_ids"]) == self.NUM_AGENTS + ) + + self.players_ids = config.get("players_ids", ["player_0", "player_1"]) + self.player_0_id, self.player_1_id = self.players_ids + self.max_steps = config.get("max_steps", 1) + assert self.max_steps == 1 + self.output_additional_info = config.get( + "output_additional_info", True + ) + self.step_count_in_current_episode = None + + # To store info about the fraction of each states + if self.output_additional_info: + self._init_info() + + def seed(self, seed=None): + """Seed the PRNG of this space. """ + self.np_random, seed = seeding.np_random(seed) + return [seed] + + def reset(self): + self.step_count_in_current_episode = 0 + if self.output_additional_info: + self._reset_info() + return { + self.player_0_id: self.INIT_STATE_VALUE, + self.player_1_id: self.INIT_STATE_VALUE, + } + + def step(self, actions: dict): + """ + :param actions: Dict containing both actions for player_1 and player_2 + :return: observations, rewards, done, info + """ + self.step_count_in_current_episode += 1 + # print("actions", actions) + actions_player_0 = actions[self.player_0_id] + actions_player_1 = actions[self.player_1_id] + + if self.output_additional_info: + self._accumulate_info( + pl0_work_on_task0=actions_player_0[0], + pl0_cutoff=actions_player_0[1], + pl1_work_on_task0=actions_player_1[0], + pl1_cutoff=actions_player_1[1], + ) + + observations = self._produce_observations( + actions_player_0, actions_player_1 + ) + rewards = self._get_players_rewards(actions_player_0, actions_player_1) + epi_is_done = self.step_count_in_current_episode >= self.max_steps + if self.step_count_in_current_episode > self.max_steps: + logger.warning( + "self.step_count_in_current_episode >= self.max_steps" + ) + info = self._get_info_for_current_epi(epi_is_done) + return self._to_RLLib_API(observations, rewards, epi_is_done, info) + + def _produce_observations(self, actions_player_row, actions_player_col): + one_player_obs = np.array( + [ + actions_player_row, + actions_player_col, + ] + ) + return [one_player_obs, one_player_obs] + + def _get_players_rewards( + self, action_player_0: list, action_player_1: list + ): + pl0_work_on_task0 = action_player_0[0] + pl1_work_on_task0 = action_player_1[0] + pl0_work_on_task1 = 1 - pl0_work_on_task0 + pl1_work_on_task1 = 1 - pl1_work_on_task0 + cutoff_p0 = action_player_0[1] + cutoff_p1 = 
action_player_1[1] + # print("pl0_work_on_task0 - cutoff_p1", pl0_work_on_task0, cutoff_p1) + # print("pl1_work_on_task1 - cutoff_p0", pl1_work_on_task1, cutoff_p0) + accept_offer_task0 = (pl0_work_on_task0 - cutoff_p1) > 0 + accept_offer_task1 = (pl1_work_on_task1 - cutoff_p0) > 0 + # accept_offer_task1 = (cutoff_p0 - pl1_work_on_task0) > 0 + accept_offer = accept_offer_task0 and accept_offer_task1 + if not accept_offer: + return [0.0, 0.0] + else: + + r_pl0_from_task0 = SimpleBargaining._log_v_plus_one( + np.power(pl0_work_on_task0 + pl1_work_on_task0, self.PL0_T0) + ) + r_pl0_from_task1 = SimpleBargaining._log_v_plus_one( + np.power( + # player 1 is more productive at task 1 as seen by player 0 + pl0_work_on_task1 + self.G * pl1_work_on_task1, + self.PL0_T1, + ) + ) + r_pl1_from_task0 = SimpleBargaining._log_v_plus_one( + np.power( + # player 0 is more productive at task 0 as seen by player 1 + self.G * pl0_work_on_task0 + pl1_work_on_task0, + self.PL1_T0, + ) + ) + r_pl1_from_task1 = SimpleBargaining._log_v_plus_one( + np.power(pl0_work_on_task1 + pl1_work_on_task1, self.PL1_T1) + ) + # print( + # "r_pl0_from_task0", + # r_pl0_from_task0, + # "r_pl0_from_task1", + # r_pl0_from_task1, + # ) + # print( + # "r_pl1_from_task0", + # r_pl1_from_task0, + # "r_pl1_from_task1", + # r_pl1_from_task1, + # ) + reward_player0 = r_pl0_from_task0 + r_pl0_from_task1 + reward_player1 = r_pl1_from_task0 + r_pl1_from_task1 + return [reward_player0, reward_player1] + + @staticmethod + def _log_v_plus_one(v): + assert (v + 1) > 0, f"v: {v}" + return np.log(v + 1) + + def _to_RLLib_API( + self, observations: list, rewards: list, epi_is_done: bool, info: dict + ): + + observations = { + self.player_0_id: observations[0], + self.player_1_id: observations[1], + } + + rewards = { + self.player_0_id: rewards[0], + self.player_1_id: rewards[1], + } + + if info is None: + info = {} + else: + info = {self.player_0_id: info, self.player_1_id: info} + + done = { + self.player_0_id: epi_is_done, + self.player_1_id: epi_is_done, + "__all__": epi_is_done, + } + # print("observations", observations) + return observations, rewards, done, info + + def _get_info_for_current_epi(self, epi_is_done): + if epi_is_done and self.output_additional_info: + info_for_current_epi = self._get_episode_info() + else: + info_for_current_epi = None + return info_for_current_epi + + def __str__(self): + return self.NAME + + +if __name__ == "__main__": + env = SimpleBargaining({}) + v_range = np.arange(0.01, 0.99, 0.1) + v_range = np.round(v_range, 2) + print("v_range", v_range) + for pl_0_w in v_range: + for pl_0_c in v_range: + for pl_1_w in v_range: + for pl_1_c in v_range: + pl_0_w = round(pl_0_w, 2) + pl_0_c = round(pl_0_c, 2) + pl_1_w = round(pl_1_w, 2) + pl_1_c = round(pl_1_c, 2) + pl0_a = [pl_0_w, pl_0_c] + pl1_a = [pl_1_w, pl_1_c] + r_0, r_1 = env._get_players_rewards(pl0_a, pl1_a) + r_0 = round(r_0, 2) + r_1 = round(r_1, 2) + print("act", [pl0_a, pl1_a], "r", [r_0, r_1]) + + all_r = np.zeros((len(v_range), len(v_range), 2)) + print("all_r", all_r.shape) + for i, pl_0_w in enumerate(v_range): + for j, pl_1_w in enumerate(v_range): + pl_0_w = round(pl_0_w, 2) + pl_0_c = 0.0 + pl_1_w = round(pl_1_w, 2) + pl_1_c = 0.0 + pl0_a = [pl_0_w, pl_0_c] + pl1_a = [pl_1_w, pl_1_c] + r_0, r_1 = env._get_players_rewards(pl0_a, pl1_a) + r_0 = round(r_0, 2) + r_1 = round(r_1, 2) + all_r[i, j, :] = [r_0, r_1] + + import matplotlib.pyplot as plt + from marltoolbox.scripts.plot_meta_policies import ( + heatmap, + annotate_heatmap, + ) + + # 
plt.plot(all_r[..., 0]) + # plt.show() + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 12)) + plt.suptitle("SimpleBargaining payoffs") + _, _ = heatmap( + all_r[..., 0], + v_range, + v_range, + ax=ax1, + cmap="YlGn", + cbarlabel="Reward player 1", + ) + _, _ = heatmap( + all_r[..., 1], + v_range, + v_range, + ax=ax2, + cmap="YlGn", + cbarlabel="Reward player 2", + ) + plt.show() diff --git a/marltoolbox/envs/simple_bargaining_open_spiel.py b/marltoolbox/envs/simple_bargaining_open_spiel.py new file mode 100644 index 0000000..b410267 --- /dev/null +++ b/marltoolbox/envs/simple_bargaining_open_spiel.py @@ -0,0 +1,211 @@ +import numpy as np +import pyspiel + +from marltoolbox.envs.simple_bargaining import SimpleBargaining + +RLLIB_SIMPLE_BARGAINING = SimpleBargaining({}) + +_NUM_PLAYERS = 2 +_N_DISCRETE = 11 +_N_ACTIONS = 2 +_GAME_TYPE = pyspiel.GameType( + short_name="python_simple_bargaining", + long_name="Python Simple Bargaining", + dynamics=pyspiel.GameType.Dynamics.SEQUENTIAL, + chance_mode=pyspiel.GameType.ChanceMode.DETERMINISTIC, + information=pyspiel.GameType.Information.ONE_SHOT, + utility=pyspiel.GameType.Utility.GENERAL_SUM, + reward_model=pyspiel.GameType.RewardModel.TERMINAL, + max_num_players=_NUM_PLAYERS, + min_num_players=_NUM_PLAYERS, + provides_information_state_string=True, + provides_information_state_tensor=True, + provides_observation_string=True, + provides_observation_tensor=True, + provides_factored_observation_string=True, +) +_GAME_INFO = pyspiel.GameInfo( + num_distinct_actions=_N_ACTIONS * _N_DISCRETE, + max_chance_outcomes=0, + num_players=_NUM_PLAYERS, + min_utility=0.0, + max_utility=3.0, + utility_sum=0.0, + max_game_length=1, +) + + +class SimpleBargainingGame(pyspiel.Game): + def __init__(self, params=None): + super().__init__(_GAME_TYPE, _GAME_INFO, params or dict()) + + def new_initial_state(self): + """Returns a state corresponding to the start of a game.""" + return SimpleBargainingState(self) + + def make_py_observer(self, iig_obs_type=None, params=None): + """Returns an object used for observing game state.""" + return SimpleBargainingObserver( + iig_obs_type or pyspiel.IIGObservationType(), params + ) + + +class SimpleBargainingState(pyspiel.State): + G = GAINS_FROM_TRADE_FACTOR = 3.0 + MULTIPLIER = 0.2 + PL0_T0, PL0_T1, PL1_T0, PL1_T1 = np.array([3, 9, 7, 2]) * MULTIPLIER + + def __init__(self, game): + """Constructor; should only be called by Game.new_initial_state.""" + super().__init__(game) + self._game_turn = 0 + self._actions_possible = list(range(_N_ACTIONS * _N_DISCRETE)) + self._game_over = False + self._next_player = 0 + self.action_player_0 = [None] * _N_ACTIONS + self.action_player_1 = [None] * _N_ACTIONS + self.rllib_game = RLLIB_SIMPLE_BARGAINING + self.verbose = False + if self.verbose: + print("start epi") + + # OpenSpiel (PySpiel) API functions are below. This is the standard set that + # should be implemented by every sequential-move game with chance. 
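    # A minimal usage sketch (assumes the open_spiel package is installed and
    # this module has been imported, which registers the game under its
    # short_name):
    #
    #     import pyspiel
    #     game = pyspiel.load_game("python_simple_bargaining")
    #     state = game.new_initial_state()
    #     while not state.is_terminal():
    #         # four turns: work share then cutoff, for each player in turn
    #         state.apply_action(state.legal_actions()[0])
    #     print(state.returns())  # [reward_player_0, reward_player_1]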
+ + def current_player(self): + """Returns id of the next player to move, or TERMINAL if game is over.""" + if self._game_over: + return pyspiel.PlayerId.TERMINAL + else: + return self._next_player + + def legal_actions(self, player=None): + """Returns a list of legal actions, sorted in ascending order.""" + legal_actions = None + if self._game_turn == 4 or self._game_over: + legal_actions = [] + elif player is not None and player != self._next_player: + legal_actions = [] + elif self._game_turn == 0 or self._game_turn == 2: + legal_actions = self._actions_possible[:_N_DISCRETE] + elif self._game_turn == 1 or self._game_turn == 3: + legal_actions = self._actions_possible[_N_DISCRETE:] + if self.verbose: + print(self._game_turn, "legal_actions", legal_actions) + return legal_actions + + def chance_outcomes(self): + """Returns the possible chance outcomes and their probabilities.""" + raise NotImplementedError() + + def _apply_action(self, action): + """Applies the specified action to the state.""" + self._assert_turn_action(self._game_turn, action) + if self._game_turn == 0: + self.action_player_0[0] = self._get_float_values(action) + self._next_player = 0 + elif self._game_turn == 1: + self.action_player_0[1] = self._get_float_values(action) + self._next_player = 1 + elif self._game_turn == 2: + self.action_player_1[0] = self._get_float_values(action) + self._next_player = 1 + elif self._game_turn == 3: + self.action_player_1[1] = self._get_float_values(action) + self._game_over = True + self._next_player = 0 + else: + raise ValueError(f"self._game_over {self._game_over}") + if self.verbose: + print(self._game_turn, "_apply_action", action) + self._game_turn += 1 + + @staticmethod + def _assert_turn_action(i, action): + i = i % _N_ACTIONS + assert ( + (_N_DISCRETE * i) <= action <= (_N_DISCRETE * (i + 1)) + ), f"action {action}, _N_DISCRETE * i: {_N_DISCRETE * i}" + + @staticmethod + def _get_float_values(action): + return (action % _N_DISCRETE) / _N_DISCRETE + + def _action_to_string(self, player, action): + """Action -> string.""" + return f"{player}_{action}" + + def is_terminal(self): + """Returns True if the game is over.""" + return self._game_over + + def returns(self): + """Total reward for each player over the course of the game so far.""" + if not self._game_over: + return [0.0, 0.0] + + rewards = self.rllib_game._get_players_rewards( + self.action_player_0, self.action_player_1 + ) + + if self.verbose: + print(self._game_turn, "returns", rewards) + + return rewards + + def __str__(self): + """String for debug purposes. 
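The two players' action lists are joined by an underscore, e.g. something like "[0.5, 0.2]_[0.7, 0.1]".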
No particular semantics are required.""" + return f"{self.action_player_0}_{self.action_player_1}" + + +class SimpleBargainingObserver: + """Observer, conforming to the PyObserver interface (see observation.py).""" + + def __init__(self, iig_obs_type, params): + """Initializes an empty observation tensor.""" + if params: + raise ValueError( + f"Observation parameters not supported; passed {params}" + ) + + shape = (_N_DISCRETE,) + self.tensor = np.zeros(np.prod(shape), np.float32) + self.dict = {"observation": np.reshape(self.tensor, shape)} + + def set_from(self, state, player): + """Updates `tensor` and `dict` to reflect `state` from PoV of `player`.""" + obs = self.dict["observation"] + obs.fill(0) + if state._game_turn == 1: + if player == 0: + pl0_act0 = self._get_action_idx(state.action_player_0[0]) + obs[pl0_act0] = 1 + elif state._game_turn == 3: + if player == 1: + pl1_act0 = self._get_action_idx(state.action_player_1[0]) + obs[pl1_act0] = 1 + elif ( + state._game_turn == 0 + or state._game_turn == 2 + or state._game_turn == 4 + ): + pass + else: + raise ValueError(f"state._game_turn {state._game_turn}") + assert np.all(self.dict["observation"] == obs) + + @staticmethod + def _get_action_idx(value): + return int(value * _N_DISCRETE) + + def string_from(self, state, player): + """Observation of `state` from the PoV of `player`, as a string.""" + pieces = [] + pieces.extend(state.action_player_0) + pieces.extend(state.action_player_1) + return " ".join(str(p) for p in pieces) + + +# Register the game with the OpenSpiel library + +pyspiel.register_game(_GAME_TYPE, SimpleBargainingGame) diff --git a/marltoolbox/envs/ssd_mixed_motive_coin_game.py b/marltoolbox/envs/ssd_mixed_motive_coin_game.py index 0af66f7..03dbf84 100644 --- a/marltoolbox/envs/ssd_mixed_motive_coin_game.py +++ b/marltoolbox/envs/ssd_mixed_motive_coin_game.py @@ -40,6 +40,7 @@ class SSDMixedMotiveCoinGame(coin_game.CoinGame): @override(coin_game.CoinGame) def __init__(self, config: dict = {}): super().__init__(config) + self.punishment_helped = config.get("punishment_helped", False) self.OBSERVATION_SPACE = gym.spaces.Box( low=0, @@ -137,11 +138,11 @@ def _compute_reward(self): if self._same_pos(self.red_pos, self.red_coin_pos): if self.red_coin and self._same_pos( - self.blue_pos, self.red_coin_pos + self.blue_pos, self.red_coin_pos ): # Red coin is a coop coin generate_new_coin = True - reward_red += 1.2 + reward_red += 2.0 red_pick_any = True red_pick_red = True blue_pick_any = True @@ -152,13 +153,18 @@ def _compute_reward(self): reward_red += 1.0 red_pick_any = True red_pick_red = True + if self.punishment_helped and self._same_pos( + self.blue_pos, self.red_coin_pos + ): + reward_red -= 0.75 + elif self._same_pos(self.blue_pos, self.blue_coin_pos): if not self.red_coin and self._same_pos( - self.red_pos, self.blue_coin_pos + self.red_pos, self.blue_coin_pos ): # Blue coin is a coop coin generate_new_coin = True - reward_blue += 2.0 + reward_blue += 3.0 red_pick_any = True blue_pick_any = True blue_pick_blue = True @@ -169,6 +175,10 @@ def _compute_reward(self): reward_blue += 1.0 blue_pick_any = True blue_pick_blue = True + if self.punishment_helped and self._same_pos( + self.red_pos, self.blue_coin_pos + ): + reward_red -= 0.75 reward_list = [reward_red, reward_blue] if self.output_additional_info: @@ -185,37 +195,29 @@ def _compute_reward(self): @override(coin_game.CoinGame) def _init_info(self): - self.red_pick = [] - self.red_pick_own = [] - self.blue_pick = [] - self.blue_pick_own = [] + 
super()._init_info() self.picked_red_coop = [] self.picked_blue_coop = [] @override(coin_game.CoinGame) def _reset_info(self): - self.red_pick.clear() - self.red_pick_own.clear() - self.blue_pick.clear() - self.blue_pick_own.clear() + super()._reset_info() self.picked_red_coop.clear() self.picked_blue_coop.clear() @override(coin_game.CoinGame) def _accumulate_info( - self, - red_pick_any, - red_pick_red, - blue_pick_any, - blue_pick_blue, - picked_red_coop, - picked_blue_coop, + self, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, ): - - self.red_pick.append(red_pick_any) - self.red_pick_own.append(red_pick_red) - self.blue_pick.append(blue_pick_any) - self.blue_pick_own.append(blue_pick_blue) + super()._accumulate_info( + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) self.picked_red_coop.append(picked_red_coop) self.picked_blue_coop.append(picked_blue_coop) @@ -228,38 +230,25 @@ def _get_episode_info(self): pick_own_color is the fraction of coins picked by the player which have the same color as the player. """ - player_red_info, player_blue_info = {}, {} + player_red_info, player_blue_info = super()._get_episode_info() + n_steps_played = len(self.red_pick) assert n_steps_played == len(self.blue_pick) n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop) if len(self.red_pick) > 0: red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = red_pick / n_steps_played - if red_pick > 0: - player_red_info["pick_own_color"] = ( - sum(self.red_pick_own) / red_pick - ) - player_red_info["red_coop_speed"] = ( - sum(self.picked_red_coop) / n_steps_played + sum(self.picked_red_coop) / n_steps_played ) - if red_pick > 0: player_red_info["red_coop_fraction"] = n_coop / red_pick if len(self.blue_pick) > 0: blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = blue_pick / n_steps_played - if blue_pick > 0: - player_blue_info["pick_own_color"] = ( - sum(self.blue_pick_own) / blue_pick - ) - player_blue_info["blue_coop_speed"] = ( - sum(self.picked_blue_coop) / n_steps_played + sum(self.picked_blue_coop) / n_steps_played ) - if blue_pick > 0: player_blue_info["blue_coop_fraction"] = n_coop / blue_pick diff --git a/marltoolbox/envs/utils/mixins.py b/marltoolbox/envs/utils/mixins.py index f19b313..83fd0ed 100644 --- a/marltoolbox/envs/utils/mixins.py +++ b/marltoolbox/envs/utils/mixins.py @@ -1,7 +1,12 @@ +import logging from abc import ABC +import numpy as np + from marltoolbox.envs.utils.interfaces import InfoAccumulationInterface +logger = logging.getLogger(__name__) + class TwoPlayersTwoActionsInfoMixin(InfoAccumulationInterface, ABC): """ @@ -65,3 +70,37 @@ def _accumulate_info(self, *actions): self.info_counters[id] = 0 self.info_counters[id] += 1 self.info_counters["n_steps_accumulated"] += 1 + + +class NPlayersNContinuousActionsInfoMixin(InfoAccumulationInterface, ABC): + """ + Mixin class to add logging capability in N player games with continuous + actions. + Logs the mean and std of action profiles used + (action profile: the set of actions used during one step by all players). 
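For example, if a keyword such as pl0_cutoff is logged with the values 0.4 and 0.6 during an episode, the episode info will contain pl0_cutoff_mean == 0.5 and pl0_cutoff_std == 0.1 (numpy's default population standard deviation).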
+ """ + + logger.warning( + "MIXING NPlayersNContinuousActionsInfoMixin NOT DEBBUGED, NOT TESTED" + ) + + def _init_info(self): + self.data_accumulated = {} + + def _reset_info(self): + self.data_accumulated = {} + + def _get_episode_info(self): + info = {} + for k, v in self.data_accumulated.items(): + array = np.array(v) + info[f"{k}_mean"] = array.mean() + info[f"{k}_std"] = array.std() + + return info + + def _accumulate_info(self, **kwargs_actions): + for k, v in kwargs_actions.items(): + if k not in self.data_accumulated.keys(): + self.data_accumulated[k] = [] + self.data_accumulated[k].append(v) diff --git a/marltoolbox/envs/vectorized_coin_game.py b/marltoolbox/envs/vectorized_coin_game.py index 5f9f4f6..0fed500 100644 --- a/marltoolbox/envs/vectorized_coin_game.py +++ b/marltoolbox/envs/vectorized_coin_game.py @@ -28,17 +28,21 @@ def __init__(self, config={}): self.batch_size = config.get("batch_size", 1) self.force_vectorized = config.get("force_vectorize", False) - assert self.grid_size == 3, \ - "hardcoded in the generate_state numba function" + self.punishment_helped = config.get("punishment_helped", False) + assert ( + self.grid_size == 3 + ), "hardcoded in the generate_state numba function" @override(coin_game.CoinGame) def _randomize_color_and_player_positions(self): # Reset coin color and the players and coin positions - self.red_coin = np.random.randint(2, size=self.batch_size) + self.red_coin = np.random.randint(low=0, high=2, size=self.batch_size) self.red_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.blue_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) + self.grid_size, size=(self.batch_size, 2) + ) self.coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) self._players_do_not_overlap_at_start() @@ -52,18 +56,30 @@ def _players_do_not_overlap_at_start(self): @override(coin_game.CoinGame) def _generate_coin(self): generate = np.ones(self.batch_size, dtype=bool) - self.coin_pos = generate_coin_wt_numba_optimization( - self.batch_size, generate, self.red_coin, self.red_pos, - self.blue_pos, self.coin_pos, self.grid_size) + self.coin_pos, self.red_coin = generate_coin_wt_numba_optimization( + self.batch_size, + generate, + self.red_coin, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.grid_size, + ) @override(coin_game.CoinGame) def _generate_observation(self): obs = generate_observations_wt_numba_optimization( - self.batch_size, self.red_pos, self.blue_pos, self.coin_pos, - self.red_coin, self.grid_size) + self.batch_size, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.red_coin, + self.grid_size, + ) obs = self._apply_optional_invariance_to_the_player_trained(obs) obs, _ = self._optional_unvectorize(obs) + # print("env nonzero obs", np.nonzero(obs[0])) return obs def _optional_unvectorize(self, obs, rewards=None): @@ -75,49 +91,65 @@ def _optional_unvectorize(self, obs, rewards=None): @override(coin_game.CoinGame) def step(self, actions: Iterable): - + # print("step") + # print("env self.red_coin", self.red_coin) + # print("env self.red_pos", self.red_pos) + # print("env self.blue_pos", self.blue_pos) + # print("env self.coin_pos", self.coin_pos) actions = self._from_RLLib_API_to_list(actions) self.step_count_in_current_episode += 1 - (self.red_pos, self.blue_pos, rewards, self.coin_pos, observation, - self.red_coin, red_pick_any, red_pick_red, blue_pick_any, - blue_pick_blue) = vectorized_step_wt_numba_optimization( - actions, self.batch_size, 
self.red_pos, self.blue_pos, - self.coin_pos, self.red_coin, self.grid_size, self.asymmetric, - self.max_steps, self.both_players_can_pick_the_same_coin) + ( + self.red_pos, + self.blue_pos, + rewards, + self.coin_pos, + observation, + self.red_coin, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = vectorized_step_wt_numba_optimization( + actions, + self.batch_size, + self.red_pos, + self.blue_pos, + self.coin_pos, + self.red_coin, + self.grid_size, + self.asymmetric, + self.max_steps, + self.both_players_can_pick_the_same_coin, + self.punishment_helped, + ) if self.output_additional_info: self._accumulate_info( - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue) + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) obs = self._apply_optional_invariance_to_the_player_trained( - observation) + observation + ) obs, rewards = self._optional_unvectorize(obs, rewards) - + # print("env actions", actions) + # print("env rewards", rewards) + # print("env self.red_coin", self.red_coin) + # print("env self.red_pos", self.red_pos) + # print("env self.blue_pos", self.blue_pos) + # print("env self.coin_pos", self.coin_pos) + # print("env nonzero obs", np.nonzero(obs[0])) return self._to_RLLib_API(obs, rewards) @override(coin_game.CoinGame) - def _get_episode_info(self): - - player_red_info, player_blue_info = {}, {} - - if len(self.red_pick) > 0: - red_pick = sum(self.red_pick) - player_red_info["pick_speed"] = \ - red_pick / (len(self.red_pick) * self.batch_size) - if red_pick > 0: - player_red_info["pick_own_color"] = \ - sum(self.red_pick_own) / red_pick - - if len(self.blue_pick) > 0: - blue_pick = sum(self.blue_pick) - player_blue_info["pick_speed"] = \ - blue_pick / (len(self.blue_pick) * self.batch_size) - if blue_pick > 0: - player_blue_info["pick_own_color"] = \ - sum(self.blue_pick_own) / blue_pick - - return player_red_info, player_blue_info + def _get_episode_info(self, n_steps_played=None): + n_steps_played = ( + n_steps_played + if n_steps_played is not None + else len(self.red_pick) * self.batch_size + ) + return super()._get_episode_info(n_steps_played=n_steps_played) @override(coin_game.CoinGame) def _from_RLLib_API_to_list(self, actions): @@ -140,15 +172,13 @@ def _save_env(self): "grid_size": self.grid_size, "asymmetric": self.asymmetric, "batch_size": self.batch_size, - "step_count_in_current_episode": - self.step_count_in_current_episode, + "step_count_in_current_episode": self.step_count_in_current_episode, "max_steps": self.max_steps, "red_pick": self.red_pick, "red_pick_own": self.red_pick_own, "blue_pick": self.blue_pick, "blue_pick_own": self.blue_pick_own, - "both_players_can_pick_the_same_coin": - self.both_players_can_pick_the_same_coin, + "both_players_can_pick_the_same_coin": self.both_players_can_pick_the_same_coin, } return copy.deepcopy(env_save_state) @@ -170,94 +200,149 @@ def __init__(self, config={}): @jit(nopython=True) def vectorized_step_wt_numba_optimization( - actions, batch_size, red_pos, blue_pos, coin_pos, red_coin, - grid_size: int, asymmetric: bool, max_steps: int, - both_players_can_pick_the_same_coin: bool): + actions, + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + grid_size: int, + asymmetric: bool, + max_steps: int, + both_players_can_pick_the_same_coin: bool, + punishment_helped: bool, +): red_pos, blue_pos = move_players( - batch_size, actions, red_pos, blue_pos, grid_size) + batch_size, actions, red_pos, blue_pos, grid_size + ) - reward, generate, red_pick_any, red_pick_red, \ - 
blue_pick_any, blue_pick_blue = compute_reward( - batch_size, red_pos, blue_pos, coin_pos, red_coin, - asymmetric, both_players_can_pick_the_same_coin) + ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) = compute_reward( + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + asymmetric, + both_players_can_pick_the_same_coin, + punishment_helped, + ) - coin_pos = generate_coin_wt_numba_optimization( - batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size) + coin_pos, red_coin = generate_coin_wt_numba_optimization( + batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size + ) obs = generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size) + batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size + ) - return red_pos, blue_pos, reward, coin_pos, obs, red_coin, red_pick_any, \ - red_pick_red, blue_pick_any, blue_pick_blue + return ( + red_pos, + blue_pos, + reward, + coin_pos, + obs, + red_coin, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) def move_players(batch_size, actions, red_pos, blue_pos, grid_size): - moves = List([ - np.array([0, 1]), - np.array([0, -1]), - np.array([1, 0]), - np.array([-1, 0]), - ]) + moves = List( + [ + np.array([0, 1]), + np.array([0, -1]), + np.array([1, 0]), + np.array([-1, 0]), + ] + ) for j in prange(batch_size): - red_pos[j] = \ - (red_pos[j] + moves[actions[j, 0]]) % grid_size - blue_pos[j] = \ - (blue_pos[j] + moves[actions[j, 1]]) % grid_size + red_pos[j] = (red_pos[j] + moves[actions[j, 0]]) % grid_size + blue_pos[j] = (blue_pos[j] + moves[actions[j, 1]]) % grid_size return red_pos, blue_pos @jit(nopython=True) -def compute_reward(batch_size, red_pos, blue_pos, coin_pos, red_coin, - asymmetric, both_players_can_pick_the_same_coin): +def compute_reward( + batch_size, + red_pos, + blue_pos, + coin_pos, + red_coin, + asymmetric, + both_players_can_pick_the_same_coin, + punishment_helped, +): reward_red = np.zeros(batch_size) reward_blue = np.zeros(batch_size) generate = np.zeros(batch_size, dtype=np.bool_) - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - 0, 0, 0, 0 + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 for i in prange(batch_size): red_first_if_both = None if not both_players_can_pick_the_same_coin: - if _same_pos(red_pos[i], coin_pos[i]) and \ - _same_pos(blue_pos[i], coin_pos[i]): + if _same_pos(red_pos[i], coin_pos[i]) and _same_pos( + blue_pos[i], coin_pos[i] + ): red_first_if_both = bool(np.random.randint(low=0, high=2)) if red_coin[i]: - if _same_pos(red_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or red_first_if_both): + if _same_pos(red_pos[i], coin_pos[i]) and ( + red_first_if_both is None or red_first_if_both + ): generate[i] = True reward_red[i] += 1 if asymmetric: reward_red[i] += 3 red_pick_any += 1 red_pick_red += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or not red_first_if_both): + if _same_pos(blue_pos[i], coin_pos[i]) and ( + red_first_if_both is None or not red_first_if_both + ): generate[i] = True reward_red[i] += -2 reward_blue[i] += 1 blue_pick_any += 1 + if asymmetric and punishment_helped: + reward_red[i] -= 6 else: - if _same_pos(red_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or red_first_if_both): + if _same_pos(red_pos[i], coin_pos[i]) and ( + red_first_if_both is None or red_first_if_both + ): generate[i] = True reward_red[i] += 1 
reward_blue[i] += -2 if asymmetric: reward_red[i] += 3 red_pick_any += 1 - if _same_pos(blue_pos[i], coin_pos[i]) and \ - (red_first_if_both is None or not red_first_if_both): + if _same_pos(blue_pos[i], coin_pos[i]) and ( + red_first_if_both is None or not red_first_if_both + ): generate[i] = True reward_blue[i] += 1 blue_pick_any += 1 blue_pick_blue += 1 reward = [reward_red, reward_blue] - return reward, generate, \ - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + return ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + ) @jit(nopython=True) @@ -267,13 +352,13 @@ def _same_pos(x, y): @jit(nopython=True) def generate_coin_wt_numba_optimization( - batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, - grid_size): + batch_size, generate, red_coin, red_pos, blue_pos, coin_pos, grid_size +): red_coin[generate] = 1 - red_coin[generate] for i in prange(batch_size): if generate[i]: coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], grid_size) - return coin_pos + return coin_pos, red_coin @jit(nopython=True) @@ -303,8 +388,9 @@ def _unflatten_index(pos, grid_size): @jit(nopython=True) -def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos, - coin_pos, red_coin, grid_size): +def generate_observations_wt_numba_optimization( + batch_size, red_pos, blue_pos, coin_pos, red_coin, grid_size +): obs = np.zeros((batch_size, grid_size, grid_size, 4)) for i in prange(batch_size): obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 diff --git a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py b/marltoolbox/envs/vectorized_mixed_motive_coin_game.py deleted file mode 100644 index ced01bd..0000000 --- a/marltoolbox/envs/vectorized_mixed_motive_coin_game.py +++ /dev/null @@ -1,200 +0,0 @@ -import copy -import logging -from collections import Iterable - -import numpy as np -from numba import jit, prange -from ray.rllib.utils import override - -from marltoolbox.envs import vectorized_coin_game -from marltoolbox.envs.vectorized_coin_game import \ - _flatten_index, _unflatten_index, _same_pos, move_players - -logger = logging.getLogger(__name__) - -PLOT_KEYS = vectorized_coin_game.PLOT_KEYS - -PLOT_ASSEMBLAGE_TAGS = vectorized_coin_game.PLOT_ASSEMBLAGE_TAGS - - -class VectMixedMotiveCG(vectorized_coin_game.VectorizedCoinGame): - - @override(vectorized_coin_game.VectorizedCoinGame) - def _load_config(self, config): - super()._load_config(config) - assert self.both_players_can_pick_the_same_coin, \ - "both_players_can_pick_the_same_coin option must be True in " \ - "mixed motive coin game." 
- - @override(vectorized_coin_game.VectorizedCoinGame) - def _randomize_color_and_player_positions(self): - # Reset coin color and the players and coin positions - self.red_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) - self.blue_pos = np.random.randint( - self.grid_size, size=(self.batch_size, 2)) - self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) - self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8) - - self._players_do_not_overlap_at_start() - - @override(vectorized_coin_game.VectorizedCoinGame) - def _generate_coin(self): - generate = np.ones(self.batch_size, dtype=np.int_) * 3 - self.red_coin_pos, self.blue_coin_pos = \ - generate_coin_wt_numba_optimization( - self.batch_size, generate, self.red_coin_pos, - self.blue_coin_pos, self.red_pos, self.blue_pos, - self.grid_size) - - @override(vectorized_coin_game.VectorizedCoinGame) - def _generate_observation(self): - obs = generate_observations_wt_numba_optimization( - self.batch_size, self.red_pos, self.blue_pos, self.red_coin_pos, - self.blue_coin_pos, self.grid_size) - - obs = self._apply_optional_invariance_to_the_player_trained(obs) - obs, _ = self._optional_unvectorize(obs) - return obs - - @override(vectorized_coin_game.VectorizedCoinGame) - def step(self, actions: Iterable): - actions = self._from_RLLib_API_to_list(actions) - self.step_count_in_current_episode += 1 - - (self.red_pos, self.blue_pos, rewards, self.red_coin_pos, - self.blue_coin_pos, observation, red_pick_any, red_pick_red, - blue_pick_any, blue_pick_blue) = \ - vectorized_step_wt_numba_optimization( - actions, self.batch_size, self.red_pos, self.blue_pos, - self.red_coin_pos, self.blue_coin_pos, self.grid_size) - - if self.output_additional_info: - self._accumulate_info( - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue) - - obs = self._apply_optional_invariance_to_the_player_trained( - observation) - obs, rewards = self._optional_unvectorize(obs, rewards) - - return self._to_RLLib_API(obs, rewards) - - @override(vectorized_coin_game.VectorizedCoinGame) - def _save_env(self): - env_save_state = { - "red_pos": self.red_pos, - "blue_pos": self.blue_pos, - "red_coin_pos": self.red_coin_pos, - "blue_coin_pos": self.blue_coin_pos, - "grid_size": self.grid_size, - "batch_size": self.batch_size, - "step_count_in_current_episode": - self.step_count_in_current_episode, - "max_steps": self.max_steps, - "red_pick": self.red_pick, - "red_pick_own": self.red_pick_own, - "blue_pick": self.blue_pick, - "blue_pick_own": self.blue_pick_own, - } - return copy.deepcopy(env_save_state) - - -@jit(nopython=True) -def vectorized_step_wt_numba_optimization( - actions, batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, - grid_size: int): - red_pos, blue_pos = move_players( - batch_size, actions, red_pos, blue_pos, grid_size) - - reward, generate, red_pick_any, red_pick_red, \ - blue_pick_any, blue_pick_blue = compute_reward( - batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos) - - red_coin_pos, blue_coin_pos = generate_coin_wt_numba_optimization( - batch_size, generate, red_coin_pos, blue_coin_pos, - red_pos, blue_pos, grid_size) - - obs = generate_observations_wt_numba_optimization( - batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos, grid_size) - - return red_pos, blue_pos, reward, red_coin_pos, blue_coin_pos, obs, \ - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue - - -@jit(nopython=True) -def compute_reward(batch_size, red_pos, blue_pos, red_coin_pos, blue_coin_pos): - reward_red = 
np.zeros(batch_size) - reward_blue = np.zeros(batch_size) - generate = np.zeros(batch_size, dtype=np.int_) - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = \ - 0, 0, 0, 0 - - for i in prange(batch_size): - if _same_pos(red_pos[i], red_coin_pos[i]) and \ - _same_pos(blue_pos[i], red_coin_pos[i]): - generate[i] = 1 - reward_red[i] += 2 - reward_blue[i] += 2 - red_pick_any += 1 - red_pick_red += 1 - blue_pick_any += 1 - elif _same_pos(red_pos[i], blue_coin_pos[i]) and \ - _same_pos(blue_pos[i], blue_coin_pos[i]): - generate[i] = 2 - reward_red[i] += 1 - reward_blue[i] += 4 - red_pick_any += 1 - blue_pick_any += 1 - blue_pick_blue += 1 - - reward = [reward_red, reward_blue] - - return reward, generate, \ - red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue - - -@jit(nopython=True) -def generate_coin_wt_numba_optimization( - batch_size, generate, red_coin_pos, blue_coin_pos, red_pos, blue_pos, - grid_size): - for i in prange(batch_size): - # generate:0 => no coin generation - # generate:1 => red coin generation - # generate:2 => blue coin generation - # generate:0 => red & blue coin generation - - if generate[i] == 3 or generate[i] == 1: - red_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], - grid_size, blue_coin_pos[i]) - if generate[i] == 3 or generate[i] == 2: - blue_coin_pos[i] = _place_coin(red_pos[i], blue_pos[i], - grid_size, red_coin_pos[i]) - return red_coin_pos, blue_coin_pos - - -@jit(nopython=True) -def _place_coin(red_pos_i, blue_pos_i, grid_size, other_coin_pos_i): - red_pos_flat = _flatten_index(red_pos_i, grid_size) - blue_pos_flat = _flatten_index(blue_pos_i, grid_size) - other_coin_pos_flat = _flatten_index(other_coin_pos_i, grid_size) - possible_coin_pos = np.array( - [x for x in range(9) - if ((x != blue_pos_flat) and - (x != red_pos_flat) and - (x != other_coin_pos_flat))] - ) - flat_coin_pos = np.random.choice(possible_coin_pos) - return _unflatten_index(flat_coin_pos, grid_size) - - -@jit(nopython=True) -def generate_observations_wt_numba_optimization(batch_size, red_pos, blue_pos, - red_coin_pos, blue_coin_pos, - grid_size): - obs = np.zeros((batch_size, grid_size, grid_size, 4)) - for i in prange(batch_size): - obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 - obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1 - obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 2] = 1 - obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 3] = 1 - return obs diff --git a/marltoolbox/envs/vectorized_ssd_mm_coin_game.py b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py new file mode 100644 index 0000000..40b82fe --- /dev/null +++ b/marltoolbox/envs/vectorized_ssd_mm_coin_game.py @@ -0,0 +1,455 @@ +import copy +import logging +from collections import Iterable + +import gym +import numpy as np +from numba import jit, prange +from ray.rllib.utils import override + +from marltoolbox.envs import vectorized_coin_game +from marltoolbox.envs.vectorized_coin_game import ( + _flatten_index, + _unflatten_index, + _same_pos, + move_players, +) + +logger = logging.getLogger(__name__) + +PLOT_KEYS = vectorized_coin_game.PLOT_KEYS + +PLOT_ASSEMBLAGE_TAGS = vectorized_coin_game.PLOT_ASSEMBLAGE_TAGS + + +class VectSSDMixedMotiveCG( + vectorized_coin_game.VectorizedCoinGame, +): + def __init__(self, config: dict = {}): + super().__init__(config) + self.punishment_helped = config.get("punishment_helped", False) + + self.OBSERVATION_SPACE = gym.spaces.Box( + low=0, + high=1, + shape=(self.grid_size, self.grid_size, 6), + dtype="uint8", + ) + + @override(vectorized_coin_game.VectorizedCoinGame) + def 
_load_config(self, config):
+        super()._load_config(config)
+        assert self.both_players_can_pick_the_same_coin, (
+            "both_players_can_pick_the_same_coin option must be True in "
+            "ssd mixed motive coin game."
+        )
+        assert self.same_obs_for_each_player, (
+            "same_obs_for_each_player option must be True in "
+            "ssd mixed motive coin game."
+        )
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def _randomize_color_and_player_positions(self):
+        # Reset coin color and the players and coin positions
+        self.red_pos = np.random.randint(
+            self.grid_size, size=(self.batch_size, 2)
+        )
+        self.blue_pos = np.random.randint(
+            self.grid_size, size=(self.batch_size, 2)
+        )
+        self.red_coin = np.random.randint(
+            2, size=self.batch_size, dtype=np.int8
+        )
+        self.red_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
+        self.blue_coin_pos = np.zeros((self.batch_size, 2), dtype=np.int8)
+
+        self._players_do_not_overlap_at_start()
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def _generate_coin(self):
+        generate = np.ones(self.batch_size, dtype=np.int_)
+        (
+            self.red_coin_pos,
+            self.blue_coin_pos,
+            self.red_coin,
+        ) = generate_coin_wt_numba_optimization(
+            self.batch_size,
+            generate,
+            self.red_coin_pos,
+            self.blue_coin_pos,
+            self.red_pos,
+            self.blue_pos,
+            self.grid_size,
+            self.red_coin,
+        )
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def _generate_observation(self):
+        obs = generate_observations_wt_numba_optimization(
+            self.batch_size,
+            self.red_pos,
+            self.blue_pos,
+            self.red_coin_pos,
+            self.blue_coin_pos,
+            self.grid_size,
+            self.red_coin,
+        )
+
+        obs = self._apply_optional_invariance_to_the_player_trained(obs)
+        obs, _ = self._optional_unvectorize(obs)
+        return obs
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def step(self, actions: Iterable):
+        actions = self._from_RLLib_API_to_list(actions)
+        self.step_count_in_current_episode += 1
+
+        (
+            self.red_pos,
+            self.blue_pos,
+            rewards,
+            self.red_coin_pos,
+            self.blue_coin_pos,
+            observation,
+            red_pick_any,
+            red_pick_red,
+            blue_pick_any,
+            blue_pick_blue,
+            picked_red_coop,
+            picked_blue_coop,
+            self.red_coin,
+        ) = vectorized_step_wt_numba_optimization(
+            actions,
+            self.batch_size,
+            self.red_pos,
+            self.blue_pos,
+            self.red_coin_pos,
+            self.blue_coin_pos,
+            self.grid_size,
+            self.red_coin,
+            self.punishment_helped,
+        )
+
+        if self.output_additional_info:
+            self._accumulate_info(
+                red_pick_any,
+                red_pick_red,
+                blue_pick_any,
+                blue_pick_blue,
+                picked_red_coop,
+                picked_blue_coop,
+            )
+
+        obs = self._apply_optional_invariance_to_the_player_trained(
+            observation
+        )
+        obs, rewards = self._optional_unvectorize(obs, rewards)
+
+        return self._to_RLLib_API(obs, rewards)
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def _save_env(self):
+        env_save_state = {
+            "red_pos": self.red_pos,
+            "blue_pos": self.blue_pos,
+            "red_coin_pos": self.red_coin_pos,
+            "blue_coin_pos": self.blue_coin_pos,
+            "red_coin": self.red_coin,
+            "grid_size": self.grid_size,
+            "batch_size": self.batch_size,
+            "step_count_in_current_episode": self.step_count_in_current_episode,
+            "max_steps": self.max_steps,
+            "red_pick": self.red_pick,
+            "red_pick_own": self.red_pick_own,
+            "blue_pick": self.blue_pick,
+            "blue_pick_own": self.blue_pick_own,
+        }
+        return copy.deepcopy(env_save_state)
+
+    @override(vectorized_coin_game.VectorizedCoinGame)
+    def _init_info(self):
+        super()._init_info()
+        self.picked_red_coop = 
[] + self.picked_blue_coop = [] + + @override(vectorized_coin_game.VectorizedCoinGame) + def _reset_info(self): + super()._reset_info() + self.picked_red_coop.clear() + self.picked_blue_coop.clear() + + @override(vectorized_coin_game.VectorizedCoinGame) + def _accumulate_info( + self, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ): + super()._accumulate_info( + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue + ) + self.picked_red_coop.append(picked_red_coop) + self.picked_blue_coop.append(picked_blue_coop) + + @override(vectorized_coin_game.VectorizedCoinGame) + def _get_episode_info(self): + """ + Output the following information: + pick_speed is the fraction of steps during which the player picked a + coin. + pick_own_color is the fraction of coins picked by the player which have + the same color as the player. + """ + n_steps_played = len(self.red_pick) * self.batch_size + player_red_info, player_blue_info = super()._get_episode_info( + n_steps_played=n_steps_played + ) + n_coop = sum(self.picked_blue_coop) + sum(self.picked_red_coop) + + if len(self.red_pick) > 0: + red_pick = sum(self.red_pick) + player_red_info["red_coop_speed"] = ( + sum(self.picked_red_coop) / n_steps_played + ) + if red_pick > 0: + player_red_info["red_coop_fraction"] = n_coop / red_pick + + if len(self.blue_pick) > 0: + blue_pick = sum(self.blue_pick) + player_blue_info["blue_coop_speed"] = ( + sum(self.picked_blue_coop) / n_steps_played + ) + if blue_pick > 0: + player_blue_info["blue_coop_fraction"] = n_coop / blue_pick + + return player_red_info, player_blue_info + + def _save_env(self): + raise NotImplementedError() + + def _load_env(self, env_state): + raise NotImplementedError() + + +@jit(nopython=True) +def _place_coin(red_pos_i, blue_pos_i, grid_size, other_coin_pos_i): + red_pos_flat = _flatten_index(red_pos_i, grid_size) + blue_pos_flat = _flatten_index(blue_pos_i, grid_size) + other_coin_pos_flat = _flatten_index(other_coin_pos_i, grid_size) + possible_coin_pos = np.array( + [ + x + for x in range(9) + if ( + (x != blue_pos_flat) + and (x != red_pos_flat) + and (x != other_coin_pos_flat) + ) + ] + ) + flat_coin_pos = np.random.choice(possible_coin_pos) + return _unflatten_index(flat_coin_pos, grid_size) + + +@jit(nopython=True) +def vectorized_step_wt_numba_optimization( + actions, + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size: int, + red_coin, + punishment_helped, +): + red_pos, blue_pos = move_players( + batch_size, actions, red_pos, blue_pos, grid_size + ) + + ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ) = compute_reward( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + red_coin, + punishment_helped, + ) + + ( + red_coin_pos, + blue_coin_pos, + red_coin, + ) = generate_coin_wt_numba_optimization( + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, + red_coin, + ) + + obs = generate_observations_wt_numba_optimization( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size, + red_coin, + ) + + return ( + red_pos, + blue_pos, + reward, + red_coin_pos, + blue_coin_pos, + obs, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + red_coin, + ) + + +@jit(nopython=True) +def compute_reward( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + 
red_coin, + punishment_helped, +): + reward_red = np.zeros(batch_size) + reward_blue = np.zeros(batch_size) + generate = np.zeros(batch_size, dtype=np.int_) + red_pick_any, red_pick_red, blue_pick_any, blue_pick_blue = 0, 0, 0, 0 + picked_red_coop, picked_blue_coop = 0, 0 + + for i in prange(batch_size): + if _same_pos(red_pos[i], red_coin_pos[i]): + if red_coin[i] and _same_pos(blue_pos[i], red_coin_pos[i]): + # Red coin is a coop coin + generate[i] = True + reward_red[i] += 2.0 + red_pick_any += 1 + red_pick_red += 1 + blue_pick_any += 1 + picked_red_coop += 1 + elif not red_coin[i]: + # Red coin is a selfish coin + generate[i] = True + reward_red[i] += 1.0 + red_pick_any += 1 + red_pick_red += 1 + if punishment_helped and _same_pos( + blue_pos[i], red_coin_pos[i] + ): + reward_red[i] -= 0.75 + elif _same_pos(blue_pos[i], blue_coin_pos[i]): + if not red_coin[i] and _same_pos(red_pos[i], blue_coin_pos[i]): + # Blue coin is a coop coin + generate[i] = True + reward_blue[i] += 3.0 + red_pick_any += 1 + blue_pick_any += 1 + blue_pick_blue += 1 + picked_blue_coop += 1 + elif red_coin[i]: + # Blue coin is a selfish coin + generate[i] = True + reward_blue[i] += 1.0 + blue_pick_any += 1 + blue_pick_blue += 1 + if punishment_helped and _same_pos( + red_pos[i], blue_coin_pos[i] + ): + reward_blue[i] -= 0.75 + + reward = [reward_red, reward_blue] + + return ( + reward, + generate, + red_pick_any, + red_pick_red, + blue_pick_any, + blue_pick_blue, + picked_red_coop, + picked_blue_coop, + ) + + +@jit(nopython=True) +def generate_observations_wt_numba_optimization( + batch_size, + red_pos, + blue_pos, + red_coin_pos, + blue_coin_pos, + grid_size, + red_coin, +): + obs = np.zeros((batch_size, grid_size, grid_size, 6)) + for i in prange(batch_size): + obs[i, red_pos[i][0], red_pos[i][1], 0] = 1 + obs[i, blue_pos[i][0], blue_pos[i][1], 1] = 1 + if red_coin[i]: + # Feature 4th is for the red cooperative coin + obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 4] = 1 + # Feature 3th is for the blue selfish coin + obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 3] = 1 + else: + # Feature 2th is for the red selfish coin + obs[i, red_coin_pos[i][0], red_coin_pos[i][1], 2] = 1 + # Feature 5th is for the blue cooperative coin + obs[i, blue_coin_pos[i][0], blue_coin_pos[i][1], 5] = 1 + return obs + + +@jit(nopython=True) +def generate_coin_wt_numba_optimization( + batch_size, + generate, + red_coin_pos, + blue_coin_pos, + red_pos, + blue_pos, + grid_size, + red_coin, +): + for i in prange(batch_size): + if generate[i]: + red_coin[i] = 1 - red_coin[i] + red_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, blue_coin_pos[i] + ) + blue_coin_pos[i] = _place_coin( + red_pos[i], blue_pos[i], grid_size, red_coin_pos[i] + ) + return red_coin_pos, blue_coin_pos, red_coin diff --git a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb index bb7bda6..c2a13e5 100644 --- a/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb +++ b/marltoolbox/examples/Tutorial_Basics_How_to_use_the_toolbox.ipynb @@ -75,8 +75,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:39.129878Z", - "start_time": "2021-04-14T09:58:39.126721Z" + "end_time": "2021-04-15T12:33:11.973283Z", + "start_time": "2021-04-15T12:33:11.970475Z" }, "id": "DsrzADRv4Itm" }, @@ -172,8 +172,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.814223Z", - "start_time": 
"2021-04-14T09:58:39.133920Z" + "end_time": "2021-04-15T12:33:14.647926Z", + "start_time": "2021-04-15T12:33:11.976742Z" }, "id": "BdtPEwKt-b47" }, @@ -191,6 +191,7 @@ "\n", "from marltoolbox.envs.matrix_sequential_social_dilemma import IteratedPrisonersDilemma\n", "from marltoolbox.utils.miscellaneous import check_learning_achieved\n", + "from marltoolbox.utils import log\n", "\n", "from IPython.core.display import display, HTML\n", "display(HTML(\"\"))" @@ -224,8 +225,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.821285Z", - "start_time": "2021-04-14T09:58:41.816146Z" + "end_time": "2021-04-15T12:33:14.653543Z", + "start_time": "2021-04-15T12:33:14.650069Z" }, "id": "MHNIE5wp-nV8" }, @@ -256,8 +257,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:41.907870Z", - "start_time": "2021-04-14T09:58:41.824356Z" + "end_time": "2021-04-15T12:33:14.742628Z", + "start_time": "2021-04-15T12:33:14.655685Z" }, "id": "nEC3Fbp9trEZ" }, @@ -284,8 +285,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.005535Z", - "start_time": "2021-04-14T09:58:41.910259Z" + "end_time": "2021-04-15T12:33:14.857873Z", + "start_time": "2021-04-15T12:33:14.745371Z" }, "id": "6ZyWFTfy-KSd" }, @@ -329,8 +330,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.124090Z", - "start_time": "2021-04-14T09:58:42.007028Z" + "end_time": "2021-04-15T12:33:14.926104Z", + "start_time": "2021-04-15T12:33:14.859461Z" }, "id": "DrBiGbnSBFns" }, @@ -389,8 +390,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.211967Z", - "start_time": "2021-04-14T09:58:42.127092Z" + "end_time": "2021-04-15T12:33:15.020084Z", + "start_time": "2021-04-15T12:33:14.928681Z" }, "id": "tB0ig1cuCkl7" }, @@ -480,8 +481,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:42.349958Z", - "start_time": "2021-04-14T09:58:42.216488Z" + "end_time": "2021-04-15T12:33:15.114115Z", + "start_time": "2021-04-15T12:33:15.022186Z" }, "id": "s4G0pHVJCpGU" }, @@ -586,8 +587,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:58:53.956824Z", - "start_time": "2021-04-14T09:58:42.352177Z" + "end_time": "2021-04-15T12:33:26.563552Z", + "start_time": "2021-04-15T12:33:15.119380Z" }, "id": "ST5qb68DCMI7", "scrolled": true @@ -602,6 +603,7 @@ "\n", "# stop the training after N Trainable.step (here this is equivalent to N episodes and N updates)\n", "stop_config = {\"training_iteration\": 200} \n", + "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n", "\n", "# Restart Ray defensively in case the ray connection is lost.\n", "ray.shutdown() \n", @@ -611,7 +613,7 @@ " Trainable,\n", " stop=stop_config,\n", " config=tune_config,\n", - " name=\"PG_IPD\",\n", + " name=exp_name,\n", " )\n", "\n", "ray.shutdown()\n", @@ -651,8 +653,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.311720Z", - "start_time": "2021-04-14T09:58:53.963206Z" + "end_time": "2021-04-15T12:33:39.209663Z", + "start_time": "2021-04-15T12:33:26.565556Z" }, "id": "vMbfhunq_ben" }, @@ -669,6 +671,7 @@ "# but here varying \"training_iteration\" is interesting.\n", "stop_config = {\"training_iteration\": tune.grid_search([2, 4, 8, 16, 32, 64, 128])} \n", "#####\n", + "exp_name, _ = log.log_in_current_day_dir(\"PG_IPD_notebook_exp\")\n", "\n", "ray.shutdown() \n", 
"ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", @@ -676,7 +679,7 @@ " Trainable,\n", " stop=stop_config,\n", " config=tune_config,\n", - " name=\"PG_IPD\",\n", + " name=exp_name,\n", " )\n", "ray.shutdown()\n", "\n", @@ -718,8 +721,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:00:01.360129Z", - "start_time": "2021-04-14T10:00:01.355038Z" + "end_time": "2021-04-15T12:33:39.226528Z", + "start_time": "2021-04-15T12:33:39.211410Z" }, "id": "9uHr3nuI-Lap" }, @@ -780,8 +783,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:02:10.235678Z", - "start_time": "2021-04-14T10:02:10.224537Z" + "end_time": "2021-04-15T12:33:39.355656Z", + "start_time": "2021-04-15T12:33:39.228620Z" }, "id": "mbxWsThD-vEt" }, @@ -855,8 +858,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.631534Z", - "start_time": "2021-04-14T09:59:06.473496Z" + "end_time": "2021-04-15T12:33:39.486196Z", + "start_time": "2021-04-15T12:33:39.360619Z" }, "id": "4-VPfgGP9-XH" }, @@ -911,8 +914,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.714373Z", - "start_time": "2021-04-14T09:59:06.632964Z" + "end_time": "2021-04-15T12:33:39.578807Z", + "start_time": "2021-04-15T12:33:39.487619Z" }, "id": "gej3suhHAo0S" }, @@ -984,8 +987,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.824496Z", - "start_time": "2021-04-14T09:59:06.716292Z" + "end_time": "2021-04-15T12:33:39.713752Z", + "start_time": "2021-04-15T12:33:39.584035Z" }, "id": "aJIHVPLjbfiC" }, @@ -1047,8 +1050,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:06.916296Z", - "start_time": "2021-04-14T09:59:06.826239Z" + "end_time": "2021-04-15T12:33:39.901063Z", + "start_time": "2021-04-15T12:33:39.715122Z" }, "id": "qqlYU7ZiALxl" }, @@ -1091,6 +1094,8 @@ " \"framework\": \"torch\",\n", " \"batch_mode\": \"complete_episodes\",\n", " # LTFT supports only 1 worker only otherwise it would be mixing several opponents trajectories\n", + " # Number of rollout worker actors to create for parallel sampling. 
Setting\n", + " # this to 0 will force rollouts to be done in the trainer actor.\n", " \"num_workers\": 0,\n", " # LTFT supports only 1 env per worker only otherwise several episodes would be played at the same time\n", " \"num_envs_per_worker\": 1,\n", @@ -1116,10 +1121,11 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:04:28.482320Z", - "start_time": "2021-04-14T10:02:15.269072Z" + "end_time": "2021-04-15T12:35:51.624256Z", + "start_time": "2021-04-15T12:33:39.902523Z" }, - "id": "bfEaO43v3UBB" + "id": "bfEaO43v3UBB", + "scrolled": true }, "outputs": [], "source": [ @@ -1139,14 +1145,14 @@ " \"debug\": False,\n", "}\n", "\n", - "\n", + "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n", "rllib_config = get_rllib_config(ltft_hparameters)\n", "stop_config = get_stop_config(ltft_hparameters)\n", "ray.shutdown()\n", "ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", "tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, config=rllib_config,\n", " checkpoint_freq=0, stop=stop_config, \n", - " checkpoint_at_end=False, name=\"LTFT_exp\")\n", + " checkpoint_at_end=False, name=exp_name)\n", "ray.shutdown()\n", "\n", "check_learning_achieved(tune_results=tune_analysis_self_play, \n", @@ -1192,8 +1198,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:04:28.491013Z", - "start_time": "2021-04-14T10:04:28.484650Z" + "end_time": "2021-04-15T12:35:51.630130Z", + "start_time": "2021-04-15T12:35:51.626154Z" }, "id": "yCZzbx3Ld0vz" }, @@ -1224,8 +1230,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:05:15.895877Z", - "start_time": "2021-04-14T10:04:28.493528Z" + "end_time": "2021-04-15T12:36:53.910972Z", + "start_time": "2021-04-15T12:35:51.631535Z" }, "id": "zETjK-Y5ZKLA" }, @@ -1249,14 +1255,14 @@ " \"debug\": False,\n", "}\n", "\n", - "\n", + "exp_name, _ = log.log_in_current_day_dir(\"LTFT_notebook_exp\")\n", "rllib_config = get_rllib_config(ltft_hparameters)\n", "stop_config = get_stop_config(ltft_hparameters)\n", "ray.shutdown()\n", "ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", "tune_analysis_self_play = ray.tune.run(ltft.LTFTTrainer, config=rllib_config,\n", " checkpoint_freq=0, stop=stop_config, \n", - " checkpoint_at_end=False, name=\"LTFT_exp\")\n", + " checkpoint_at_end=False, name=exp_name)\n", "ray.shutdown()\n", "\n", "check_learning_achieved(tune_results=tune_analysis_self_play,\n", @@ -1299,8 +1305,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:10.306079Z", - "start_time": "2021-04-14T09:58:39.178Z" + "end_time": "2021-04-15T12:36:53.923288Z", + "start_time": "2021-04-15T12:36:53.915782Z" }, "id": "u00VE8oB4Itx" }, @@ -1314,8 +1320,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T09:59:10.306786Z", - "start_time": "2021-04-14T09:58:39.180Z" + "end_time": "2021-04-15T12:36:54.052499Z", + "start_time": "2021-04-15T12:36:53.925743Z" }, "id": "NpeJySSk4Itx", "scrolled": true diff --git a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb index 5989126..8811697 100644 --- a/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb +++ b/marltoolbox/examples/Tutorial_Evaluations_Level_1_best_response_and_self_play_and_cross_play.ipynb @@ -32,8 +32,8 @@ "execution_count": null, 
"metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:53.433046Z", - "start_time": "2021-04-14T10:06:53.428041Z" + "end_time": "2021-04-15T12:55:08.510232Z", + "start_time": "2021-04-15T12:55:08.501235Z" }, "id": "KMXC7aH5CD8p" }, @@ -62,8 +62,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:53.517255Z", - "start_time": "2021-04-14T10:06:53.435182Z" + "end_time": "2021-04-15T12:55:08.594695Z", + "start_time": "2021-04-15T12:55:08.512706Z" }, "id": "DsrzADRv4Itm" }, @@ -141,8 +141,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.606917Z", - "start_time": "2021-04-14T10:06:53.518906Z" + "end_time": "2021-04-15T12:55:11.544021Z", + "start_time": "2021-04-15T12:55:08.597116Z" }, "id": "BdtPEwKt-b47" }, @@ -188,8 +188,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.614247Z", - "start_time": "2021-04-14T10:06:56.609516Z" + "end_time": "2021-04-15T12:55:11.550532Z", + "start_time": "2021-04-15T12:55:11.546175Z" }, "id": "9fZtrKHovNvV" }, @@ -216,8 +216,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:06:56.722108Z", - "start_time": "2021-04-14T10:06:56.616211Z" + "end_time": "2021-04-15T13:08:41.276000Z", + "start_time": "2021-04-15T13:08:41.268429Z" }, "id": "MHNIE5wp-nV8" }, @@ -229,10 +229,10 @@ "\n", " # This modification to the policy will allow us to load each policies from different checkpoints \n", " # This will be used during evaluation.\n", - " def merged_after_init(*args, **kwargs):\n", + " def merged_before_loss_init(*args, **kwargs):\n", " setup_mixins(*args, **kwargs)\n", - " restore.after_init_load_policy_checkpoint(*args, **kwargs)\n", - " MyPPOPolicy = PPOTorchPolicy.with_updates(after_init=merged_after_init)\n", + " restore.before_loss_init_load_policy_checkpoint(*args, **kwargs)\n", + " MyPPOPolicy = PPOTorchPolicy.with_updates(before_loss_init=merged_before_loss_init)\n", "\n", " stop_config = {\n", " \"episodes_total\": hp[\"episodes_total\"],\n", @@ -311,18 +311,19 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:07:47.492122Z", - "start_time": "2021-04-14T10:06:56.726437Z" + "end_time": "2021-04-15T12:55:59.020889Z", + "start_time": "2021-04-15T12:55:11.652388Z" }, "id": "XijF3Q7O0FxF" }, "outputs": [], "source": [ + "exp_name, _ = log.log_in_current_day_dir(\"PPO_BoS_notebook_exp\")\n", "hyperparameters = {\n", " \"steps_per_epi\": 20,\n", " \"train_n_replicates\": 8,\n", " \"episodes_total\": 200,\n", - " \"exp_name\": \"PPO_BoS\",\n", + " \"exp_name\": exp_name,\n", " \"base_lr\": 5e-1,\n", "}\n", "\n", @@ -350,8 +351,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:07:47.498013Z", - "start_time": "2021-04-14T10:07:47.494191Z" + "end_time": "2021-04-15T12:55:59.045103Z", + "start_time": "2021-04-15T12:55:59.024866Z" }, "id": "e4qziRvL0o_E" }, @@ -375,8 +376,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:18.230859Z", - "start_time": "2021-04-14T10:07:47.499394Z" + "end_time": "2021-04-15T13:11:25.364166Z", + "start_time": "2021-04-15T13:10:45.951397Z" }, "id": "6ZyWFTfy-KSd" }, @@ -455,8 +456,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:10:12.067609Z", - "start_time": "2021-04-14T10:10:12.062918Z" + "end_time": "2021-04-15T12:56:29.454882Z", + "start_time": "2021-04-15T12:56:28.770969Z" }, "id": "wmvRIbPqUAwy" }, @@ 
-503,8 +504,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:19.676226Z", - "start_time": "2021-04-14T10:08:19.671968Z" + "end_time": "2021-04-15T12:56:29.461390Z", + "start_time": "2021-04-15T12:56:29.456601Z" }, "id": "bVqSBL72wHrd" }, @@ -531,8 +532,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:08:19.788760Z", - "start_time": "2021-04-14T10:08:19.678381Z" + "end_time": "2021-04-15T12:56:29.561534Z", + "start_time": "2021-04-15T12:56:29.463140Z" }, "id": "vMbfhunq_ben" }, @@ -540,11 +541,12 @@ "source": [ "def train_lvl0_agents(lvl0_hparameters):\n", "\n", + " exp_name, _ = log.log_in_current_day_dir(\"Lvl0_LOLAExact_notebook_exp\")\n", " tune_config = get_tune_config(lvl0_hparameters)\n", " stop_config = get_stop_config(lvl0_hparameters)\n", " ray.shutdown()\n", " ray.init(num_cpus=os.cpu_count(), num_gpus=0) \n", - " tune_analysis_lvl0 = tune.run(LOLAExact, name=\"Lvl0_LOLAExact\", config=tune_config,\n", + " tune_analysis_lvl0 = tune.run(LOLAExact, name=exp_name, config=tune_config,\n", " checkpoint_at_end=True, stop=stop_config, \n", " metric=lvl0_hparameters[\"metric\"], mode=\"max\")\n", " ray.shutdown()\n", @@ -583,8 +585,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:09:38.553623Z", - "start_time": "2021-04-14T10:08:19.792698Z" + "end_time": "2021-04-15T12:57:43.481511Z", + "start_time": "2021-04-15T12:56:29.563639Z" }, "id": "Im5nFVs7YUtU" }, @@ -635,8 +637,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:11:56.338029Z", - "start_time": "2021-04-14T10:11:56.328179Z" + "end_time": "2021-04-15T12:57:43.492585Z", + "start_time": "2021-04-15T12:57:43.483312Z" }, "id": "P2loJhuzVho6" }, @@ -644,6 +646,7 @@ "source": [ "def train_lvl1_agents(hp_lvl1, tune_analysis_lvl0):\n", "\n", + " exp_name, _ = log.log_in_current_day_dir(\"Lvl1_PG_notebook_exp\")\n", " rllib_config_lvl1, trainable_class, env_config = get_rllib_config(hp_lvl1)\n", " stop_config = get_stop_config(hp_lvl1)\n", " \n", @@ -659,7 +662,7 @@ " stop=stop_config,\n", " checkpoint_at_end=True,\n", " metric=\"episode_reward_mean\", mode=\"max\",\n", - " name=\"Lvl1_PG\")\n", + " name=exp_name)\n", " ray.shutdown()\n", " return tune_analysis_lvl1\n", "\n", @@ -746,8 +749,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:30.455824Z", - "start_time": "2021-04-14T10:11:58.126061Z" + "end_time": "2021-04-15T12:58:19.055673Z", + "start_time": "2021-04-15T12:57:43.493968Z" }, "id": "b2kmotZMYZBX", "scrolled": true @@ -785,8 +788,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:34.152484Z", - "start_time": "2021-04-14T10:12:33.899223Z" + "end_time": "2021-04-15T12:58:19.100252Z", + "start_time": "2021-04-15T12:58:19.056994Z" } }, "outputs": [], @@ -799,8 +802,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:12:36.039089Z", - "start_time": "2021-04-14T10:12:35.975270Z" + "end_time": "2021-04-15T12:58:19.256941Z", + "start_time": "2021-04-15T12:58:19.101888Z" }, "id": "M1lSyEf1D08q" }, @@ -828,8 +831,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:13:04.766311Z", - "start_time": "2021-04-14T10:13:04.716165Z" + "end_time": "2021-04-15T12:58:19.317574Z", + "start_time": "2021-04-15T12:58:19.258240Z" }, "id": "v67qWvcDD08q" }, @@ -871,8 +874,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - 
"end_time": "2021-04-14T10:09:39.229687Z", - "start_time": "2021-04-14T10:06:53.463Z" + "end_time": "2021-04-15T12:58:19.401613Z", + "start_time": "2021-04-15T12:58:19.319051Z" }, "id": "u00VE8oB4Itx" }, @@ -886,8 +889,8 @@ "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2021-04-14T10:09:39.230757Z", - "start_time": "2021-04-14T10:06:53.465Z" + "end_time": "2021-04-15T12:58:19.501562Z", + "start_time": "2021-04-15T12:58:19.403275Z" }, "id": "NpeJySSk4Itx", "scrolled": true diff --git a/marltoolbox/examples/rllib_api/dqn_coin_game.py b/marltoolbox/examples/rllib_api/dqn_coin_game.py index ba32255..140139e 100644 --- a/marltoolbox/examples/rllib_api/dqn_coin_game.py +++ b/marltoolbox/examples/rllib_api/dqn_coin_game.py @@ -18,38 +18,36 @@ def main(debug): seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("DQN_CG") - hparams = _get_hyperparameters(seeds, debug, exp_name) + hparams = get_hyperparameters(seeds, debug, exp_name) - rllib_config, stop_config = _get_rllib_configs(hparams) + rllib_config, stop_config = get_rllib_configs(hparams) - tune_analysis = _train_dqn_and_plot_logs( - hparams, rllib_config, stop_config) + tune_analysis = train_dqn_and_plot_logs(hparams, rllib_config, stop_config) return tune_analysis -def _get_hyperparameters(seeds, debug, exp_name): +def get_hyperparameters(seeds, debug, exp_name): + """Get hyperparameters for the Coin Game env and DQN agents""" + hparams = { "seeds": seeds, "debug": debug, "exp_name": exp_name, "n_steps_per_epi": 100, - "n_epi": 2000, - "buf_frac": 0.125, - "last_exploration_temp_value": 0.1, - "bs_epi_mul": 1, - - "plot_keys": - coin_game.PLOT_KEYS + - aggregate_and_plot_tensorboard_data.PLOT_KEYS, - "plot_assemblage_tags": - coin_game.PLOT_ASSEMBLAGE_TAGS + - aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS, + "n_epi": 4000, + "buf_frac": 0.5, + "last_exploration_temp_value": 0.003, + "bs_epi_mul": 4, + "plot_keys": coin_game.PLOT_KEYS + + aggregate_and_plot_tensorboard_data.PLOT_KEYS, + "plot_assemblage_tags": coin_game.PLOT_ASSEMBLAGE_TAGS + + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS, } return hparams -def _get_rllib_configs(hp, env_class=None): +def get_rllib_configs(hp, env_class=None): stop_config = { "episodes_total": 2 if hp["debug"] else hp["n_epi"], } @@ -59,58 +57,70 @@ def _get_rllib_configs(hp, env_class=None): "max_steps": hp["n_steps_per_epi"], "grid_size": 3, "get_additional_info": True, + "buf_frac": hp["buf_frac"], + "bs_epi_mul": hp["bs_epi_mul"], } env_class = coin_game.CoinGame if env_class is None else env_class rllib_config = { "env": env_class, "env_config": env_config, - "multiagent": { "policies": { env_config["players_ids"][0]: ( augmented_dqn.MyDQNTorchPolicy, env_class(env_config).OBSERVATION_SPACE, env_class.ACTION_SPACE, - {}), + {}, + ), env_config["players_ids"][1]: ( augmented_dqn.MyDQNTorchPolicy, env_class(env_config).OBSERVATION_SPACE, env_class.ACTION_SPACE, - {}), + {}, + ), }, "policy_mapping_fn": lambda agent_id: agent_id, }, - # === DQN Models === - # Update the target network every `target_network_update_freq` steps. "target_network_update_freq": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * 30)), + lambda spec: int(spec.config["env_config"]["max_steps"] * 30) + ), # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then # each worker will have a replay buffer of this size. 
"buffer_size": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * hp["buf_frac"])), + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["buf_frac"] + ) + ), # Whether to use dueling dqn - "dueling": False, + "dueling": True, # Whether to use double dqn "double_q": True, # If True prioritized replay buffer will be used. "prioritized_replay": False, - "rollout_fragment_length": tune.sample_from( - lambda spec: spec.config["env_config"]["max_steps"]), - "training_intensity": 10, + lambda spec: spec.config["env_config"]["max_steps"] + ), + "training_intensity": tune.sample_from( + lambda spec: spec.config["num_envs_per_worker"] + * max(spec.config["num_workers"], 1) + * 40 + ), # Size of a batch sampled from replay buffer for training. Note that # if async_updates is set, then each worker returns gradients for a # batch of this size. "train_batch_size": tune.sample_from( - lambda spec: int(spec.config["env_config"]["max_steps"] * - hp["bs_epi_mul"])), + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.config["env_config"]["bs_epi_mul"] + ) + ), "batch_mode": "complete_episodes", - # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). @@ -130,61 +140,85 @@ def _get_rllib_configs(hp, env_class=None): "temperature_schedule": tune.sample_from( lambda spec: PiecewiseSchedule( endpoints=[ - (0, - 2.0), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.20), - 0.5), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.60), - hp["last_exploration_temp_value"])], + (0, 1.0), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.20 + ), + 0.45, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.9 + ), + hp["last_exploration_temp_value"], + ), + ], outside_value=hp["last_exploration_temp_value"], - framework="torch")), + framework="torch", + ) + ), }, - # Size of batches collected from each worker. 
"model": { "dim": env_config["grid_size"], # [Channel, [Kernel, Kernel], Stride]] - "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]] + "conv_filters": [[64, [3, 3], 1], [64, [3, 3], 1]], + "fcnet_hiddens": [64, 64], }, + "hiddens": [32], "gamma": 0.96, - "optimizer": {"sgd_momentum": 0.9, }, + "optimizer": { + "sgd_momentum": 0.9, + }, "lr": 0.1, "lr_schedule": tune.sample_from( lambda spec: [ (0, 0.0), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"] * 0.05), - spec.config.lr), - (int(spec.config["env_config"]["max_steps"] * - spec.stop["episodes_total"]), - spec.config.lr / 1e9) + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.05 + ), + spec.config.lr, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + ), + spec.config.lr / 1e9, + ), ] ), - "seed": tune.grid_search(hp["seeds"]), "callbacks": log.get_logging_callbacks_class(), "framework": "torch", - "logger_config": { "wandb": { "project": "DQN_CG", "group": hp["exp_name"], - "api_key_file": - os.path.join(os.path.dirname(__file__), - "../../../api_key_wandb"), - "log_config": True + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + "log_config": True, }, }, - + "num_envs_per_worker": 16, + "num_workers": 0, + # "log_level": "INFO", } return rllib_config, stop_config -def _train_dqn_and_plot_logs(hp, rllib_config, stop_config): - ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=hp["debug"]) +def train_dqn_and_plot_logs(hp, rllib_config, stop_config): + ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"]) tune_analysis = tune.run( DQNTrainer, config=rllib_config, diff --git a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py index 10a7083..601d20c 100644 --- a/marltoolbox/examples/rllib_api/dqn_wt_welfare.py +++ b/marltoolbox/examples/rllib_api/dqn_wt_welfare.py @@ -10,29 +10,35 @@ def main(debug, welfare=postprocessing.WELFARE_UTILITARIAN): seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("DQN_welfare_CG") - hparams = dqn_coin_game._get_hyperparameters(seeds, debug, exp_name) - rllib_config, stop_config = dqn_coin_game._get_rllib_configs(hparams) - rllib_config = _modify_policy_to_use_welfare(rllib_config, welfare) + hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name) + rllib_config, stop_config = dqn_coin_game.get_rllib_configs(hparams) + rllib_config = modify_dqn_rllib_config_to_use_welfare( + rllib_config, welfare + ) - tune_analysis = dqn_coin_game._train_dqn_and_plot_logs( - hparams, rllib_config, stop_config) + tune_analysis = dqn_coin_game.train_dqn_and_plot_logs( + hparams, rllib_config, stop_config + ) return tune_analysis -def _modify_policy_to_use_welfare(rllib_config, welfare): - MyCoopDQNTorchPolicy = augmented_dqn.MyDQNTorchPolicy.with_updates( - postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( - postprocessing.welfares_postprocessing_fn(), - postprocess_nstep_and_prio, - ) +def modify_dqn_rllib_config_to_use_welfare(rllib_config, welfare): + DQNTorchPolicyWtWelfare = _get_policy_class_wt_welfare_preprocessing() + rllib_config = modify_rllib_config_to_use_welfare( + rllib_config, welfare, policy_class_wt_welfare=DQNTorchPolicyWtWelfare ) + return rllib_config + +def modify_rllib_config_to_use_welfare( + rllib_config, welfare, policy_class_wt_welfare, overwrite_reward=True +): policies = rllib_config["multiagent"]["policies"] 
new_policies = {} for policies_id, policy_tuple in policies.items(): new_policies[policies_id] = list(policy_tuple) - new_policies[policies_id][0] = MyCoopDQNTorchPolicy + new_policies[policies_id][0] = policy_class_wt_welfare if welfare == postprocessing.WELFARE_UTILITARIAN: new_policies[policies_id][3].update( {postprocessing.ADD_UTILITARIAN_WELFARE: True} @@ -40,7 +46,7 @@ def _modify_policy_to_use_welfare(rllib_config, welfare): elif welfare == postprocessing.WELFARE_INEQUITY_AVERSION: add_ia_w = True ia_alpha = 0.0 - ia_beta = 0.5 + ia_beta = 0.5 / 2 ia_gamma = 0.96 ia_lambda = 0.96 inequity_aversion_parameters = ( @@ -51,18 +57,30 @@ def _modify_policy_to_use_welfare(rllib_config, welfare): ia_lambda, ) new_policies[policies_id][3].update( - {postprocessing.ADD_INEQUITY_AVERSION_WELFARE: - inequity_aversion_parameters} + { + postprocessing.ADD_INEQUITY_AVERSION_WELFARE: inequity_aversion_parameters + } ) rllib_config["multiagent"]["policies"] = new_policies - rllib_config["callbacks"] = callbacks.merge_callbacks( - log.get_logging_callbacks_class(), - postprocessing.OverwriteRewardWtWelfareCallback, - ) + if overwrite_reward: + rllib_config["callbacks"] = callbacks.merge_callbacks( + log.get_logging_callbacks_class(), + postprocessing.OverwriteRewardWtWelfareCallback, + ) return rllib_config +def _get_policy_class_wt_welfare_preprocessing(): + DQNTorchPolicyWtWelfare = augmented_dqn.MyDQNTorchPolicy.with_updates( + postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( + postprocessing.welfares_postprocessing_fn(), + postprocess_nstep_and_prio, + ) + ) + return DQNTorchPolicyWtWelfare + + if __name__ == "__main__": debug_mode = True main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/inequity_aversion.py b/marltoolbox/examples/rllib_api/inequity_aversion.py index c730ea2..bd0e070 100644 --- a/marltoolbox/examples/rllib_api/inequity_aversion.py +++ b/marltoolbox/examples/rllib_api/inequity_aversion.py @@ -21,39 +21,35 @@ def main(debug): } policies = { - env_config["players_ids"][0]: - ( - None, - IteratedBoSAndPD.OBSERVATION_SPACE, - IteratedBoSAndPD.ACTION_SPACE, - {} - ), - env_config["players_ids"][1]: - ( - None, - IteratedBoSAndPD.OBSERVATION_SPACE, - IteratedBoSAndPD.ACTION_SPACE, - {} - )} + env_config["players_ids"][0]: ( + None, + IteratedBoSAndPD.OBSERVATION_SPACE, + IteratedBoSAndPD.ACTION_SPACE, + {}, + ), + env_config["players_ids"][1]: ( + None, + IteratedBoSAndPD.OBSERVATION_SPACE, + IteratedBoSAndPD.ACTION_SPACE, + {}, + ), + } rllib_config = { "env": IteratedBoSAndPD, "env_config": env_config, - "num_gpus": 0, "num_workers": 1, - "multiagent": { "policies": policies, "policy_mapping_fn": (lambda agent_id: agent_id), }, "framework": "torch", "gamma": 0.5, - "callbacks": callbacks.merge_callbacks( log.get_logging_callbacks_class(), - postprocessing.OverwriteRewardWtWelfareCallback), - + postprocessing.OverwriteRewardWtWelfareCallback, + ), } MyPGTorchPolicy = PGTorchPolicy.with_updates( @@ -63,20 +59,23 @@ def main(debug): inequity_aversion_beta=1.0, inequity_aversion_alpha=0.0, inequity_aversion_gamma=1.0, - inequity_aversion_lambda=0.5 + inequity_aversion_lambda=0.5, ), - pg_torch_policy.post_process_advantages + pg_torch_policy.post_process_advantages, ) ) - MyPGTrainer = PGTrainer.with_updates(default_policy=MyPGTorchPolicy, - get_policy_class=None) - tune_analysis = tune.run(MyPGTrainer, - stop=stop, - checkpoint_freq=10, - config=rllib_config, - name=exp_name) + MyPGTrainer = PGTrainer.with_updates( + default_policy=MyPGTorchPolicy, 
get_policy_class=None + ) + experiment_analysis = tune.run( + MyPGTrainer, + stop=stop, + checkpoint_freq=10, + config=rllib_config, + name=exp_name, + ) ray.shutdown() - return tune_analysis + return experiment_analysis if __name__ == "__main__": diff --git a/marltoolbox/examples/rllib_api/pg_ipd.py b/marltoolbox/examples/rllib_api/pg_ipd.py index 769843d..5851e8e 100644 --- a/marltoolbox/examples/rllib_api/pg_ipd.py +++ b/marltoolbox/examples/rllib_api/pg_ipd.py @@ -2,7 +2,8 @@ import ray from ray import tune -from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy +from ray.rllib.agents.pg.pg_torch_policy import pg_loss_stats from marltoolbox.envs.matrix_sequential_social_dilemma import ( IteratedPrisonersDilemma, @@ -10,15 +11,15 @@ from marltoolbox.utils import log, miscellaneous -def main(debug, stop_iters=300, tf=False): +def main(debug): train_n_replicates = 1 if debug else 1 seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("PG_IPD") ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) - rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf) - tune_analysis = tune.run( + rllib_config, stop_config = get_rllib_config(seeds, debug) + experiment_analysis = tune.run( PGTrainer, config=rllib_config, stop=stop_config, @@ -27,17 +28,19 @@ def main(debug, stop_iters=300, tf=False): log_to_file=True, ) ray.shutdown() - return tune_analysis + return experiment_analysis -def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False): +def get_rllib_config(seeds, debug=False): stop_config = { - "training_iteration": 2 if debug else stop_iters, + "episodes_total": 2 if debug else 400, } + n_steps_in_epi = 20 + env_config = { "players_ids": ["player_row", "player_col"], - "max_steps": 20, + "max_steps": n_steps_in_epi, "get_additional_info": True, } @@ -63,8 +66,9 @@ def get_rllib_config(seeds, debug=False, stop_iters=300, tf=False): }, "seed": tune.grid_search(seeds), "callbacks": log.get_logging_callbacks_class(log_full_epi=True), - "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), - "framework": "tf" if tf else "torch", + "framework": "torch", + "rollout_fragment_length": n_steps_in_epi, + "train_batch_size": n_steps_in_epi, } return rllib_config, stop_config diff --git a/marltoolbox/examples/rllib_api/ppo_coin_game.py b/marltoolbox/examples/rllib_api/ppo_coin_game.py index 78c10a1..3243b73 100644 --- a/marltoolbox/examples/rllib_api/ppo_coin_game.py +++ b/marltoolbox/examples/rllib_api/ppo_coin_game.py @@ -8,12 +8,8 @@ from marltoolbox.envs.coin_game import AsymCoinGame from marltoolbox.utils import log, miscellaneous -parser = argparse.ArgumentParser() -parser.add_argument("--tf", action="store_true") -parser.add_argument("--stop-iters", type=int, default=2000) - -def main(debug, stop_iters=2000, tf=False): +def main(debug): train_n_replicates = 1 if debug else 1 seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("PPO_AsymCG") @@ -21,7 +17,7 @@ def main(debug, stop_iters=2000, tf=False): ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) stop = { - "training_iteration": 2 if debug else stop_iters, + "training_iteration": 2 if debug else 2000, } env_class = AsymCoinGame @@ -66,7 +62,7 @@ def main(debug, stop_iters=2000, tf=False): "seed": tune.grid_search(seeds), "callbacks": log.get_logging_callbacks_class(), "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), - "framework": "tf" if tf else 
"torch", + "framework": "torch", "num_workers": 0, } @@ -83,6 +79,5 @@ def main(debug, stop_iters=2000, tf=False): if __name__ == "__main__": - args = parser.parse_args() debug_mode = True - main(debug_mode, args.stop_iters, args.tf) + main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/r2d2_coin_game.py b/marltoolbox/examples/rllib_api/r2d2_coin_game.py new file mode 100644 index 0000000..454f32d --- /dev/null +++ b/marltoolbox/examples/rllib_api/r2d2_coin_game.py @@ -0,0 +1,249 @@ +import os + +import ray +from ray import tune +from ray.rllib.agents import dqn +from ray.tune.integration.wandb import WandbLogger +from ray.tune.logger import DEFAULT_LOGGERS +from ray.rllib.agents.dqn.r2d2_torch_policy import postprocess_nstep_and_prio +from ray.rllib.utils.schedules import PiecewiseSchedule + +from marltoolbox.examples.rllib_api import dqn_coin_game, dqn_wt_welfare +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import log, miscellaneous, postprocessing +from marltoolbox.algos import augmented_r2d2 +from marltoolbox.envs import coin_game, ssd_mixed_motive_coin_game + + +def main(debug): + """Train R2D2 agent in the Coin Game environment""" + + # env = "CoinGame" + env = "SSDMixedMotiveCoinGame" + # welfare_to_use = None + welfare_to_use = postprocessing.WELFARE_UTILITARIAN + # welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION + + rllib_config, stop_config, hparams = _get_config_and_hp_for_training( + debug, env, welfare_to_use + ) + + tune_analysis = _train_dqn(hparams, rllib_config, stop_config) + + _plot_log_aggregates(hparams) + + return tune_analysis + + +def _get_config_and_hp_for_training(debug, env, welfare_to_use): + train_n_replicates = 1 if debug else 1 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("R2D2_CG") + if "SSDMixedMotiveCoinGame" in env: + env_class = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame + else: + env_class = coin_game.CoinGame + + hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name) + rllib_config, stop_config = dqn_coin_game.get_rllib_configs( + hparams, env_class=env_class + ) + rllib_config, stop_config = _adapt_configs_for_r2d2( + rllib_config, stop_config, hparams + ) + if welfare_to_use is not None: + rllib_config = modify_r2d2_rllib_config_to_use_welfare( + rllib_config, welfare_to_use + ) + return rllib_config, stop_config, hparams + + +def modify_r2d2_rllib_config_to_use_welfare(rllib_config, welfare_to_use): + r2d2_torch_policy_class_wt_welfare = ( + _get_r2d2_policy_class_wt_welfare_preprocessing() + ) + rllib_config = dqn_wt_welfare.modify_rllib_config_to_use_welfare( + rllib_config, + welfare_to_use, + policy_class_wt_welfare=r2d2_torch_policy_class_wt_welfare, + ) + return rllib_config + + +def _get_r2d2_policy_class_wt_welfare_preprocessing(): + r2d2_torch_policy_class_wt_welfare = ( + augmented_r2d2.MyR2D2TorchPolicy.with_updates( + postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( + postprocessing.welfares_postprocessing_fn(), + postprocess_nstep_and_prio, + ) + ) + ) + return r2d2_torch_policy_class_wt_welfare + + +def _adapt_configs_for_r2d2(rllib_config, stop_config, hp): + rllib_config["logger_config"]["wandb"]["project"] = "R2D2_CG" + rllib_config["model"]["use_lstm"] = True + rllib_config["burn_in"] = 0 + rllib_config["zero_init_states"] = False + rllib_config["use_h_function"] = False + + rllib_config = _replace_class_of_policies_by( + augmented_r2d2.MyR2D2TorchPolicy, + rllib_config, + ) + 
+ if not hp["debug"]: + # rllib_config["env_config"]["training_intensity"] = 40 + # rllib_config["lr"] = 0.1 + # rllib_config["training_intensity"] = tune.sample_from( + # lambda spec: spec.config["num_envs_per_worker"] + # * spec.config["env_config"]["training_intensity"] + # * max(1, spec.config["num_workers"]) + # ) + stop_config["episodes_total"] = 8000 + rllib_config["model"]["lstm_cell_size"] = 16 + rllib_config["model"]["max_seq_len"] = 20 + # rllib_config["env_config"]["buf_frac"] = 0.5 + # rllib_config["hiddens"] = [32] + # rllib_config["model"]["fcnet_hiddens"] = [64, 64] + # rllib_config["model"]["conv_filters"] = tune.grid_search( + # [ + # [[32, [3, 3], 1], [32, [3, 3], 1]], + # [[64, [3, 3], 1], [64, [3, 3], 1]], + # ] + # ) + + rllib_config["env_config"]["temp_ratio"] = ( + 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0]) + ) + rllib_config["env_config"]["interm_temp_ratio"] = ( + 1.0 if hp["debug"] else tune.grid_search([0.75, 1.0]) + ) + rllib_config["env_config"]["last_exploration_t"] = ( + 0.6 if hp["debug"] else tune.grid_search([0.9, 0.99]) + ) + rllib_config["env_config"]["last_exploration_temp_value"] = ( + 1.0 if hp["debug"] else 0.003 + ) + rllib_config["env_config"]["interm_exploration_t"] = ( + 0.2 if hp["debug"] else tune.grid_search([0.2, 0.6]) + ) + rllib_config["exploration_config"][ + "temperature_schedule" + ] = tune.sample_from( + lambda spec: PiecewiseSchedule( + endpoints=[ + (0, 1.0 * spec.config["env_config"]["temp_ratio"]), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["interm_exploration_t"] + ), + 0.6 + * spec.config["env_config"]["temp_ratio"] + * spec.config["env_config"]["interm_temp_ratio"], + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["last_exploration_t"] + ), + spec.config["env_config"][ + "last_exploration_temp_value" + ], + ), + ], + outside_value=spec.config["env_config"][ + "last_exploration_temp_value" + ], + framework="torch", + ) + ) + + # rllib_config["env_config"]["bs_epi_mul"] = ( + # 4 if hp["debug"] else tune.grid_search([4, 8, 16]) + # ) + rllib_config["env_config"]["interm_lr_ratio"] = ( + 0.5 + if hp["debug"] + else tune.grid_search([0.5 * 3, 0.5, 0.5 / 3, 0.5 / 9]) + ) + rllib_config["env_config"]["interm_lr_t"] = ( + 0.5 if hp["debug"] else tune.grid_search([0.25, 0.5, 0.75]) + ) + rllib_config["lr"] = 0.1 if hp["debug"] else 0.1 + rllib_config["lr_schedule"] = tune.sample_from( + lambda spec: [ + (0, 0.0), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.05 + ), + spec.config.lr, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * spec.config["env_config"]["interm_lr_t"] + ), + spec.config.lr + * spec.config["env_config"]["interm_lr_ratio"], + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + ), + spec.config.lr / 1e9, + ), + ] + ) + else: + rllib_config["model"]["max_seq_len"] = 2 + rllib_config["model"]["lstm_cell_size"] = 8 + + return rllib_config, stop_config + + +def _replace_class_of_policies_by(new_policy_class, rllib_config): + policies = rllib_config["multiagent"]["policies"] + for policy_id in policies.keys(): + policy = list(policies[policy_id]) + policy[0] = new_policy_class + policies[policy_id] = tuple(policy) + return rllib_config + + +def _train_dqn(hp, rllib_config, stop_config): + ray.init(num_cpus=os.cpu_count(), num_gpus=0, 
local_mode=hp["debug"]) + tune_analysis = tune.run( + dqn.r2d2.R2D2Trainer, + config=rllib_config, + stop=stop_config, + name=hp["exp_name"], + log_to_file=not hp["debug"], + loggers=None if hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,), + ) + ray.shutdown() + return tune_analysis + + +def _plot_log_aggregates(hp): + if not hp["debug"]: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", hp["exp_name"]), + plot_keys=hp["plot_keys"], + plot_assemble_tags_in_one_plot=hp["plot_assemblage_tags"], + ) + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/examples/rllib_api/r2d2_ipd.py b/marltoolbox/examples/rllib_api/r2d2_ipd.py new file mode 100644 index 0000000..5d6d0bd --- /dev/null +++ b/marltoolbox/examples/rllib_api/r2d2_ipd.py @@ -0,0 +1,68 @@ +import os + +import ray +from ray import tune +from ray.rllib.agents.dqn import R2D2Trainer + +from marltoolbox.algos import augmented_dqn +from marltoolbox.envs import matrix_sequential_social_dilemma +from marltoolbox.examples.rllib_api import pg_ipd +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import log, miscellaneous + + +def main(debug): + train_n_replicates = 1 if debug else 1 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("R2D2_IPD") + + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) + + rllib_config, stop_config = pg_ipd.get_rllib_config(seeds, debug) + rllib_config, stop_config = _adapt_configs_for_r2d2( + rllib_config, stop_config, debug + ) + + tune_analysis = tune.run( + R2D2Trainer, + config=rllib_config, + stop=stop_config, + checkpoint_at_end=True, + name=exp_name, + log_to_file=True, + ) + + if not debug: + _plot_log_aggregates(exp_name) + + ray.shutdown() + return tune_analysis + + +def _adapt_configs_for_r2d2(rllib_config, stop_config, debug): + rllib_config["model"] = {"use_lstm": True} + stop_config["episodes_total"] = 10 if debug else 600 + + return rllib_config, stop_config + + +def _plot_log_aggregates(exp_name): + plot_keys = ( + aggregate_and_plot_tensorboard_data.PLOT_KEYS + + matrix_sequential_social_dilemma.PLOT_KEYS + + augmented_dqn.PL + ) + plot_assemble_tags_in_one_plot = ( + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + + matrix_sequential_social_dilemma.PLOT_ASSEMBLAGE_TAGS + ) + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", exp_name), + plot_keys=plot_keys, + plot_assemble_tags_in_one_plot=plot_assemble_tags_in_one_plot, + ) + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/experiments/rllib_api/amtft_meta_game.py b/marltoolbox/experiments/rllib_api/amtft_meta_game.py index 225aba3..0d351ce 100644 --- a/marltoolbox/experiments/rllib_api/amtft_meta_game.py +++ b/marltoolbox/experiments/rllib_api/amtft_meta_game.py @@ -1,8 +1,9 @@ +import collections import copy import json import logging import os -import random +from typing import List, Dict import numpy as np import pandas as pd @@ -11,91 +12,121 @@ from ray.rllib.agents import dqn from ray.rllib.utils import merge_dicts +from marltoolbox import utils from marltoolbox.algos import welfare_coordination from marltoolbox.experiments.rllib_api import amtft_various_env from marltoolbox.utils import ( - self_and_cross_perf, postprocessing, - miscellaneous, plot, + path, + cross_play, + exp_analysis, + restore, ) +from marltoolbox.utils.path 
import get_exp_dir_from_exp_name logger = logging.getLogger(__name__) def main(debug): - if debug: - test_meta_solver(debug) - - hp = _get_hyperparameters(debug) + # _extract_stats_on_welfare_announced( + # exp_dir="/home/maxime/dev-maxime/CLR/vm-data/instance-60-cpu-3-preemtible/amTFT/2021_04_23/12_58_42", + # players_ids=["player_row", "player_col"], + # ) + # + # if debug: + # test_meta_solver(debug) + use_r2d2 = True + hp = get_hyperparameters(debug, use_r2d2=use_r2d2) results = [] + ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"]) for tau in hp["tau_range"]: + hp["tau"] = tau ( all_rllib_config, hp_eval, env_config, stop_config, - ) = _produce_rllib_config_for_each_replicates(tau, hp) + ) = _produce_rllib_config_for_each_replicates(hp) mixed_rllib_configs = _mix_rllib_config(all_rllib_config, hp_eval) - - tune_analysis = ray.tune.run( - dqn.DQNTrainer, + experiment_analysis = ray.tune.run( + hp["trainer"], config=mixed_rllib_configs, verbose=1, stop=stop_config, - checkpoint_at_end=True, name=hp_eval["exp_name"], - log_to_file=not hp_eval["debug"], - # loggers=None - # if hp_eval["debug"] - # else DEFAULT_LOGGERS + (WandbLogger,), - ) - mean_player_1_payoffs, mean_player_2_payoffs = _extract_metrics( - tune_analysis, hp_eval + # log_to_file=not hp_eval["debug"], ) + ( + mean_player_1_payoffs, + mean_player_2_payoffs, + player_1_payoffs, + player_2_payoffs, + ) = extract_metrics(experiment_analysis, hp_eval) results.append( ( tau, (mean_player_1_payoffs, mean_player_2_payoffs), + (player_1_payoffs, player_2_payoffs), ) ) - _save_to_json(exp_name=hp["exp_name"], object=results) - _plot_results(exp_name=hp["exp_name"], results=results, hp_eval=hp_eval) + save_to_json(exp_name=hp["exp_name"], object=results) + plot_results( + exp_name=hp["exp_name"], + results=results, + hp_eval=hp_eval, + format_fn=format_result_for_plotting, + ) + extract_stats_on_welfare_announced( + exp_name=hp["exp_name"], players_ids=env_config["players_ids"] + ) + +def get_hyperparameters(debug, use_r2d2=False): + """Get hyperparameters for meta game with amTFT policies in base game""" -def _get_hyperparameters(debug): # env = "IteratedPrisonersDilemma" env = "IteratedAsymBoS" # env = "CoinGame" hp = amtft_various_env.get_hyperparameters( - debug, train_n_replicates=1, filter_utilitarian=False, env=env + debug, + train_n_replicates=1, + filter_utilitarian=False, + env=env, + use_r2d2=use_r2d2, ) hp.update( { - "n_replicates_over_full_exp": 2, - "final_base_game_eval_over_n_epi": 2, + "n_replicates_over_full_exp": 2 if debug else 5, + "final_base_game_eval_over_n_epi": 1 if debug else 20, "tau_range": np.arange(0.0, 1.1, 0.5) if hp["debug"] else np.arange(0.0, 1.1, 0.1), "n_self_play_in_final_meta_game": 0, - "n_cross_play_in_final_meta_game": 1, + "n_cross_play_in_final_meta_game": 1 if debug else 4, + "max_n_replicates_in_base_game": 10, } ) + + if use_r2d2: + hp["trainer"] = dqn.r2d2.R2D2Trainer + else: + hp["trainer"] = dqn.DQNTrainer + return hp -def _produce_rllib_config_for_each_replicates(tau, hp): +def _produce_rllib_config_for_each_replicates(hp): all_rllib_config = [] for replicate_i in range(hp["n_replicates_over_full_exp"]): hp_eval = _load_base_game_results( copy.deepcopy(hp), load_base_replicate_i=replicate_i ) - hp_eval["tau"] = tau ( rllib_config, @@ -110,29 +141,55 @@ def _produce_rllib_config_for_each_replicates(tau, hp): rllib_config, env_config, hp_eval ) all_rllib_config.append(rllib_config) + return all_rllib_config, hp_eval, env_config, stop_config def _load_base_game_results(hp, 
load_base_replicate_i): - prefix = "/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-2/" - # prefix = "/ray_results/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/" + prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-2-preemtible/" + # prefix = "~/ray_results/" + prefix = os.path.expanduser(prefix) if "CoinGame" in hp["env_name"]: hp["data_dir"] = (prefix + "amTFT/2021_04_10/19_37_20",)[ load_base_replicate_i ] elif "IteratedAsymBoS" in hp["env_name"]: hp["data_dir"] = ( - prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates - prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates - prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates - prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates - prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates + # Before fix + without R2D2 + # prefix + "amTFT/2021_04_13/11_56_23", # 5 replicates + # prefix + "amTFT/2021_04_13/13_40_03", # 5 replicates + # prefix + "amTFT/2021_04_13/13_40_34", # 5 replicates + # prefix + "amTFT/2021_04_13/18_06_48", # 10 replicates + # prefix + "amTFT/2021_04_13/18_07_05", # 10 replicates + # After env fix + with R2D2 + # prefix + "amTFT/2021_05_03/13_44_48", # 10 replicates + # prefix + "amTFT/2021_05_03/13_45_09", # 10 replicates + # prefix + "amTFT/2021_05_03/13_47_00", # 10 replicates + # prefix + "amTFT/2021_05_03/13_48_03", # 10 replicates + # prefix + "amTFT/2021_05_03/13_49_28", # 10 replicates + # prefix + "amTFT/2021_05_03/13_49_47", # 10 replicates + # After fix env + short debit rollout + punish instead of + # selfish + 20 rllout length insead of 6 + prefix + "amTFT/2021_05_03/13_52_07", # 10 replicates + prefix + "amTFT/2021_05_03/13_52_27", # 10 replicates + prefix + "amTFT/2021_05_03/13_53_37", # 10 replicates + ### prefix + "amTFT/2021_05_03/13_54_08", # 10 replicates + prefix + "amTFT/2021_05_03/13_54_55", # 10 replicates + prefix + "amTFT/2021_05_03/13_56_35", # 10 replicates )[load_base_replicate_i] elif "IteratedPrisonersDilemma" in hp["env_name"]: hp["data_dir"] = ( "/home/maxime/dev-maxime/CLR/vm-data/" "instance-10-cpu-4/amTFT/2021_04_13/12_12_56", )[load_base_replicate_i] + else: + raise ValueError() + + assert os.path.exists( + hp["data_dir"] + ), "Path doesn't exist. 
Probably that the prefix need to be changed to fit the current machine used" hp["json_file"] = _get_results_json_file_path_in_dir(hp["data_dir"]) hp["ckpt_per_welfare"] = _get_checkpoints_for_each_welfare_in_dir( @@ -150,17 +207,17 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi): stop_config, env_config, rllib_config = amtft_various_env.get_rllib_config( hp_eval, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True ) - rllib_config = amtft_various_env.modify_config_for_evaluation( - rllib_config, hp_eval, env_config + rllib_config, stop_config = amtft_various_env.modify_config_for_evaluation( + rllib_config, hp_eval, env_config, stop_config ) hp_eval["n_self_play_per_checkpoint"] = None hp_eval["n_cross_play_per_checkpoint"] = None hp_eval[ "x_axis_metric" - ] = f"policy_reward_mean.{env_config['players_ids'][0]}" + ] = f"policy_reward_mean/{env_config['players_ids'][0]}" hp_eval[ "y_axis_metric" - ] = f"policy_reward_mean.{env_config['players_ids'][1]}" + ] = f"policy_reward_mean/{env_config['players_ids'][1]}" return rllib_config, hp_eval, env_config, stop_config @@ -168,8 +225,10 @@ def _get_vanilla_amTFT_eval_config(hp, final_eval_over_n_epi): def _modify_config_to_use_welfare_coordinators( rllib_config, env_config, hp_eval ): - all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs( - hp_eval, env_config["players_ids"] + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_cross_play_payoffs( + hp_eval, env_config["players_ids"] + ) ) rllib_config["multiagent"]["policies_to_train"] = ["None"] @@ -192,17 +251,11 @@ def _modify_config_to_use_welfare_coordinators( "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs, "own_player_idx": policy_idx, "opp_player_idx": opp_policy_idx, - # "own_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION - # if policy_idx - # else postprocessing.WELFARE_UTILITARIAN, - # "opp_default_welfare_fn": postprocessing.WELFARE_INEQUITY_AVERSION - # if opp_policy_idx - # else postprocessing.WELFARE_UTILITARIAN, "own_default_welfare_fn": "inequity aversion" - if policy_idx + if policy_idx == 1 else "utilitarian", "opp_default_welfare_fn": "inequity aversion" - if opp_policy_idx + if opp_policy_idx == 1 else "utilitarian", "policy_id_to_load": policy_id, "policy_checkpoints": hp_eval["ckpt_per_welfare"], @@ -227,7 +280,7 @@ def _mix_rllib_config(all_rllib_configs, hp_eval): hp_eval["n_self_play_in_final_meta_game"] == 0 and hp_eval["n_cross_play_in_final_meta_game"] != 0 ): - master_config = _mix_configs_policies( + master_config = cross_play.utils.mix_policies_in_given_rllib_configs( all_rllib_configs, n_mix_per_config=hp_eval["n_cross_play_in_final_meta_game"], ) @@ -236,101 +289,16 @@ def _mix_rllib_config(all_rllib_configs, hp_eval): raise ValueError() -def _mix_configs_policies(all_rllib_configs, n_mix_per_config): - assert n_mix_per_config <= len(all_rllib_configs) - 1 - policy_ids = all_rllib_configs[0]["multiagent"]["policies"].keys() - assert len(policy_ids) == 2 - _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids) - - policy_config_variants = _gather_policy_variant_per_policy_id( - all_rllib_configs, policy_ids - ) - - master_config = _create_one_master_config( - all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config - ) - return master_config - - -def _create_one_master_config( - all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config -): - all_policy_mix = [] - player_1, player_2 = policy_ids - for config_idx, p1_policy_config in 
enumerate( - policy_config_variants[player_1] - ): - policies_mixes = _produce_n_mix_with_player_2_policies( - policy_config_variants, - player_2, - config_idx, - n_mix_per_config, - player_1, - p1_policy_config, - ) - all_policy_mix.extend(policies_mixes) - - master_config = copy.deepcopy(all_rllib_configs[0]) - master_config["multiagent"]["policies"] = tune.grid_search(all_policy_mix) - return master_config - - -def _produce_n_mix_with_player_2_policies( - policy_config_variants, - player_2, - config_idx, - n_mix_per_config, - player_1, - p1_policy_config, -): - p2_policy_configs_sampled = _get_p2_policies_samples_excluding_self( - policy_config_variants, player_2, config_idx, n_mix_per_config +def extract_metrics(experiment_analysis, hp_eval): + # TODO PROBLEM using last result but there are several episodes with + # different welfare functions !! => BUT plot is good + # Metric from the last result means the metric average over the last + # training iteration and we have only one training iteration here. + player_1_payoffs = utils.exp_analysis.extract_metrics_for_each_trials( + experiment_analysis, metric=hp_eval["x_axis_metric"] ) - policies_mixes = [] - for p2_policy_config in p2_policy_configs_sampled: - policy_mix = { - player_1: p1_policy_config, - player_2: p2_policy_config, - } - policies_mixes.append(policy_mix) - return policies_mixes - - -def _get_p2_policies_samples_excluding_self( - policy_config_variants, player_2, config_idx, n_mix_per_config -): - p2_policy_config_variants = copy.deepcopy(policy_config_variants[player_2]) - p2_policy_config_variants.pop(config_idx) - p2_policy_configs_sampled = random.sample( - p2_policy_config_variants, n_mix_per_config - ) - return p2_policy_configs_sampled - - -def _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids): - for rllib_config in all_rllib_configs: - assert rllib_config["multiagent"]["policies"].keys() == policy_ids - - -def _gather_policy_variant_per_policy_id(all_rllib_configs, policy_ids): - policy_config_variants = {} - for policy_id in policy_ids: - policy_config_variants[policy_id] = [] - for rllib_config in all_rllib_configs: - policy_config_variants[policy_id].append( - copy.deepcopy( - rllib_config["multiagent"]["policies"][policy_id] - ) - ) - return policy_config_variants - - -def _extract_metrics(tune_analysis, hp_eval): - player_1_payoffs = miscellaneous.extract_metric_values_per_trials( - tune_analysis, metric=hp_eval["x_axis_metric"] - ) - player_2_payoffs = miscellaneous.extract_metric_values_per_trials( - tune_analysis, metric=hp_eval["y_axis_metric"] + player_2_payoffs = utils.exp_analysis.extract_metrics_for_each_trials( + experiment_analysis, metric=hp_eval["y_axis_metric"] ) mean_player_1_payoffs = sum(player_1_payoffs) / len(player_1_payoffs) mean_player_2_payoffs = sum(player_2_payoffs) / len(player_2_payoffs) @@ -339,33 +307,53 @@ def _extract_metrics(tune_analysis, hp_eval): mean_player_1_payoffs, mean_player_2_payoffs, ) - return mean_player_1_payoffs, mean_player_2_payoffs + return ( + mean_player_1_payoffs, + mean_player_2_payoffs, + player_1_payoffs, + player_2_payoffs, + ) + +class NumpyEncoder(json.JSONEncoder): + """ Special json encoder for numpy types """ -def _save_to_json(exp_name, object): - exp_dir = os.path.join("~/ray_results", exp_name) - exp_dir = os.path.expanduser(exp_dir) - json_file = os.path.join(exp_dir, "final_eval_in_base_game.json") + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return 
float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) + + +def save_to_json(exp_name, object, filename="final_eval_in_base_game.json"): + exp_dir = get_exp_dir_from_exp_name(exp_name) + json_file = os.path.join(exp_dir, filename) + if not os.path.exists(exp_dir): + os.makedirs(exp_dir) with open(json_file, "w") as outfile: json.dump(object, outfile) -def _plot_results(exp_name, results, hp_eval): - exp_dir = os.path.join("~/ray_results", exp_name) - exp_dir = os.path.expanduser(exp_dir) +def plot_results( + exp_name, results, hp_eval, format_fn, jitter=0.0, title=None +): + exp_dir = get_exp_dir_from_exp_name(exp_name) + data_groups_per_mode = format_fn(results) - data_groups_per_mode = _format_result_for_plotting(results) + background_area_coord = None + if "CoinGame" not in hp_eval["env_name"] and "env_class" in hp_eval.keys(): + background_area_coord = hp_eval["env_class"].PAYOFF_MATRIX - if "CoinGame" in hp_eval["env_name"]: - background_area_coord = None - else: - background_area_coord = hp_eval["env_class"].PAYOUT_MATRIX plot_config = plot.PlotConfig( + title=title, save_dir_path=exp_dir, xlim=hp_eval["x_limits"], ylim=hp_eval["y_limits"], markersize=5, - jitter=hp_eval["jitter"], + jitter=jitter, xlabel="player 1 payoffs", ylabel="player 2 payoffs", x_scale_multiplier=hp_eval["plot_axis_scale_multipliers"][0], @@ -377,12 +365,9 @@ def _plot_results(exp_name, results, hp_eval): plot_helper.plot_dots(data_groups_per_mode) -def _format_result_for_plotting(results): +def format_result_for_plotting(results): data_groups_per_mode = {} - for ( - tau, - (mean_player_1_payoffs, mean_player_2_payoffs), - ) in results: + for (tau, (mean_player_1_payoffs, mean_player_2_payoffs), _) in results: df_row_dict = {} df_row_dict[f"mean"] = ( mean_player_1_payoffs, @@ -394,7 +379,7 @@ def _format_result_for_plotting(results): def test_meta_solver(debug): - hp = _get_hyperparameters(debug) + hp = get_hyperparameters(debug) hp = _load_base_game_results(hp, load_base_replicate_i=0) stop, env_config, rllib_config = amtft_various_env.get_rllib_config( hp, welfare_fn=postprocessing.WELFARE_INEQUITY_AVERSION, eval=True @@ -422,8 +407,10 @@ def test_meta_solver(debug): ) ) - all_welfare_pairs_wt_payoffs = _get_all_welfare_pairs_wt_payoffs( - hp, env_config["players_ids"] + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_cross_play_payoffs( + hp, env_config["players_ids"] + ) ) print("all_welfare_pairs_wt_payoffs", all_welfare_pairs_wt_payoffs) player_meta_policy.setup_meta_game( @@ -469,120 +456,96 @@ def _get_results_json_file_path_in_dir(data_dir): def _eval_dir(data_dir): eval_dir_parents_2 = os.path.join(data_dir, "eval") - eval_dir_parents_1 = _get_unique_child_dir(eval_dir_parents_2) - eval_dir = _get_unique_child_dir(eval_dir_parents_1) + eval_dir_parents_1 = path.get_unique_child_dir(eval_dir_parents_2) + eval_dir = path.get_unique_child_dir(eval_dir_parents_1) return eval_dir def _get_json_results_path(eval_dir): - all_files_filtered = _get_children_paths_wt_selecting_filter( - eval_dir, _filter=self_and_cross_perf.RESULTS_SUMMARY_FILENAME_PREFIX + all_files_filtered = path.get_children_paths_wt_selecting_filter( + eval_dir, + _filter=cross_play.evaluator.RESULTS_SUMMARY_FILENAME_PREFIX, ) all_files_filtered = [ file_path for file_path in all_files_filtered if ".json" in file_path ] + all_files_filtered = [ + file_path + for file_path in all_files_filtered + if os.path.split(file_path)[-1].startswith("plotself") + or not 
os.path.split(file_path)[-1].startswith("plot") + ] assert len(all_files_filtered) == 1, f"{all_files_filtered}" json_result_path = os.path.join(eval_dir, all_files_filtered[0]) return json_result_path -def _get_unique_child_dir(_dir): - list_dir = os.listdir(_dir) - assert len(list_dir) == 1, f"{list_dir}" - unique_child_dir = os.path.join(_dir, list_dir[0]) - return unique_child_dir - - def _get_checkpoints_for_each_welfare_in_dir(data_dir, hp): + """Get the checkpoints of the base game policies in self-play""" ckpt_per_welfare = {} for welfare_fn, welfare_name in hp["welfare_functions"]: welfare_training_save_dir = os.path.join(data_dir, welfare_fn, "coop") - all_replicates_save_dir = _get_dir_of_each_replicate( - welfare_training_save_dir + all_replicates_save_dir = _get_replicates_dir_for_all_possible_config( + hp, welfare_training_save_dir ) - ckpts = _get_checkpoint_for_each_replicates(all_replicates_save_dir) + all_replicates_save_dir = _filter_checkpoints_if_utilitarians( + all_replicates_save_dir, hp + ) + ckpts = restore.get_checkpoint_for_each_replicates( + all_replicates_save_dir + ) + print("len(ckpts)", len(ckpts)) ckpt_per_welfare[welfare_name.replace("_", " ")] = ckpts return ckpt_per_welfare -def _get_dir_of_each_replicate(welfare_training_save_dir): - return _get_children_paths_wt_selecting_filter( - welfare_training_save_dir, _filter="DQN_" - ) - - -def _get_checkpoint_for_each_replicates(all_replicates_save_dir): - ckpt_dir_per_replicate = [] - for replicate_dir_path in all_replicates_save_dir: - ckpt_dir_path = _get_ckpt_dir_for_one_replicate(replicate_dir_path) - ckpt_path = _get_ckpt_fro_ckpt_dir(ckpt_dir_path) - ckpt_dir_per_replicate.append(ckpt_path) - return ckpt_dir_per_replicate - - -def _get_ckpt_dir_for_one_replicate(replicate_dir_path): - partialy_filtered_ckpt_dir = _get_children_paths_wt_selecting_filter( - replicate_dir_path, _filter="checkpoint_" - ) - ckpt_dir = [ - file_path - for file_path in partialy_filtered_ckpt_dir - if ".is_checkpoint" not in file_path - ] - assert len(ckpt_dir) == 1, f"{ckpt_dir}" - return ckpt_dir[0] +def _get_replicates_dir_for_all_possible_config(hp, welfare_training_save_dir): + if "use_r2d2" in hp and hp["use_r2d2"]: + all_replicates_save_dir = get_dir_of_each_replicate( + welfare_training_save_dir, str_in_dir="R2D2_" + ) + else: + all_replicates_save_dir = get_dir_of_each_replicate( + welfare_training_save_dir, str_in_dir="DQN_" + ) + return all_replicates_save_dir -def _get_ckpt_fro_ckpt_dir(ckpt_dir_path): - partialy_filtered_ckpt_path = _get_children_paths_wt_discarding_filter( - ckpt_dir_path, _filter="tune_metadata" - ) - ckpt_path = [ - file_path - for file_path in partialy_filtered_ckpt_path - if ".is_checkpoint" not in file_path +def _filter_checkpoints_if_utilitarians(all_replicates_save_dir, hp): + util_or_inequity_aversion = [ + path.split("/")[-3] for path in all_replicates_save_dir ] - assert len(ckpt_path) == 1, f"{ckpt_path}" - return ckpt_path[0] - - -def _get_children_paths_wt_selecting_filter(parent_dir_path, _filter): - return _get_children_paths_filters( - parent_dir_path, selecting_filter=_filter - ) + print("util_or_inequity_aversion", util_or_inequity_aversion) + if all( + welfare == "utilitarian_welfare" + for welfare in util_or_inequity_aversion + ): + all_replicates_save_dir = ( + utils.path.filter_list_of_replicates_by_results( + all_replicates_save_dir, + filter_key="episode_reward_mean", + filter_mode=tune_analysis.ABOVE, + filter_threshold=95, + ) + ) + if len(all_replicates_save_dir) > 
hp["max_n_replicates_in_base_game"]: + all_replicates_save_dir = all_replicates_save_dir[ + : hp["max_n_replicates_in_base_game"] + ] + else: + print( + "not filtering ckpts. all_replicates_save_dir[0]", + all_replicates_save_dir[0], + ) + return all_replicates_save_dir -def _get_children_paths_wt_discarding_filter(parent_dir_path, _filter): - return _get_children_paths_filters( - parent_dir_path, discarding_filter=_filter +def get_dir_of_each_replicate(welfare_training_save_dir, str_in_dir="DQN_"): + return path.get_children_paths_wt_selecting_filter( + welfare_training_save_dir, _filter=str_in_dir ) -def _get_children_paths_filters( - parent_dir_path: str, - selecting_filter: str = None, - discarding_filter: str = None, -): - filtered_children = os.listdir(parent_dir_path) - if selecting_filter is not None: - filtered_children = [ - filename - for filename in filtered_children - if selecting_filter in filename - ] - if discarding_filter is not None: - filtered_children = [ - filename - for filename in filtered_children - if discarding_filter not in filename - ] - filtered_children_path = [ - os.path.join(parent_dir_path, filename) - for filename in filtered_children - ] - return filtered_children_path - - def _convert_policy_config_to_meta_policy_config( hp, policy_id, policy_config, policy_class ): @@ -601,18 +564,18 @@ def _convert_policy_config_to_meta_policy_config( return meta_policy_config -def _get_all_welfare_pairs_wt_payoffs(hp, player_ids): +def _get_all_welfare_pairs_wt_cross_play_payoffs(hp, player_ids): with open(hp["json_file"]) as json_file: json_data = json.load(json_file) - cross_play_data = _keep_only_cross_play_values(json_data) + cross_play_data = keep_only_cross_play_values(json_data) cross_play_means = _keep_only_mean_values(cross_play_data) all_welfare_pairs_wt_payoffs = _order_players(cross_play_means, player_ids) return all_welfare_pairs_wt_payoffs -def _keep_only_cross_play_values(json_data): +def keep_only_cross_play_values(json_data): return { _format_eval_mode(eval_mode): v for eval_mode, v in json_data.items() @@ -623,7 +586,7 @@ def _keep_only_cross_play_values(json_data): def _format_eval_mode(eval_mode): k_wtout_kind_of_play = eval_mode.split(":")[-1].strip() both_welfare_fn = k_wtout_kind_of_play.split(" vs ") - return welfare_coordination.WelfareCoordinationTorchPolicy._from_pair_of_welfare_names_to_key( + return welfare_coordination.WelfareCoordinationTorchPolicy.from_pair_of_welfare_names_to_key( *both_welfare_fn ) @@ -648,6 +611,112 @@ def _order_players(cross_play_means, player_ids): } +def extract_stats_on_welfare_announced( + players_ids, exp_name=None, exp_dir=None, nested_info=False +): + if exp_dir is None: + exp_dir = get_exp_dir_from_exp_name(exp_name) + all_in_exp_dir = utils.path.get_children_paths_wt_discarding_filter( + exp_dir, _filter=None + ) + all_dirs = utils.path.keep_dirs_only(all_in_exp_dir) + dir_name_by_tau = _group_dir_name_by_tau(all_dirs, nested_info) + dir_name_by_tau = _order_by_tau(dir_name_by_tau) + _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir) + + +def _group_dir_name_by_tau(all_dirs: List[str], nested_info) -> Dict: + dirs_by_tau = {} + for trial_dir in all_dirs: + tau, welfare_set_annonced = _get_tau_value(trial_dir, nested_info) + if tau is None: + continue + if tau not in dirs_by_tau.keys(): + dirs_by_tau[tau] = [] + dirs_by_tau[tau].append((trial_dir, welfare_set_annonced)) + return dirs_by_tau + + +def _get_tau_value(trial_dir_path, nested_info=False): + full_epi_file_path = os.path.join( + 
trial_dir_path, "full_episodes_logs.json" + ) + if os.path.exists(full_epi_file_path): + full_epi_logs = utils.path._read_all_lines_of_file(full_epi_file_path) + first_epi_first_step = json.loads(full_epi_logs[1]) + tau = [] + welfare_set_annonced = {} + for policy_id, policy_info in first_epi_first_step.items(): + print('policy_info["info"]', policy_info["info"]) + if nested_info: + meta_policy_info = policy_info["info"][ + f"meta_policy/{policy_id}" + ] + else: + meta_policy_info = policy_info["info"][f"meta_policy"] + for k, v in meta_policy_info.items(): + if k.startswith("tau_"): + tau_value = float(k.split("_")[-1]) + tau.append(tau_value) + welfare_set_annonced[policy_id] = v["welfare_set_annonced"] + tau = set(tau) + assert len(tau) == 1, f"tau {tau}" + tau = list(tau)[0] + assert len(welfare_set_annonced.keys()) == 2 + return tau, welfare_set_annonced + else: + return None, None + + +def _order_by_tau(dir_name_by_tau): + return collections.OrderedDict(sorted(dir_name_by_tau.items())) + + +def _get_stats_for_each_tau(dir_name_by_tau, players_ids, exp_dir): + file_path = os.path.join(exp_dir, "welfare_announced_by_tau.txt") + with open(file_path, "w") as f: + for tau, dirs_data in dir_name_by_tau.items(): + all_welfares_player_1 = [] + all_welfares_player_2 = [] + for dir, welfare_announced in dirs_data: + all_welfares_player_1.append(welfare_announced[players_ids[0]]) + all_welfares_player_2.append(welfare_announced[players_ids[1]]) + all_welfares_player_1 = _format_in_same_order( + all_welfares_player_1 + ) + all_welfares_player_2 = _format_in_same_order( + all_welfares_player_2 + ) + count_announced_p1 = collections.Counter(all_welfares_player_1) + count_announced_p2 = collections.Counter(all_welfares_player_2) + msg = ( + f"===== Welfare sets announced with tau = {tau} =====\n" + f"Player 1: {count_announced_p1}\n" + f"Player 2: {count_announced_p2}\n" + ) + print(msg) + f.write(msg) + + +def _format_in_same_order(all_welfares_player_n): + formatted_welfare_announced = [] + for welfare_announced in all_welfares_player_n: + formated_name = "" + if "utilitarian" in welfare_announced: + formated_name += "utilitarian + " + if ( + "inequity" in welfare_announced + or "egalitarian" in welfare_announced + ): + formated_name += "egalitarian + " + if "mixed" in welfare_announced: + formated_name += "mixed" + if formated_name.endswith(" + "): + formated_name = formated_name[:-3] + formatted_welfare_announced.append(formated_name) + return formatted_welfare_announced + + if __name__ == "__main__": - debug_mode = False + debug_mode = True main(debug_mode) diff --git a/marltoolbox/experiments/rllib_api/amtft_various_env.py b/marltoolbox/experiments/rllib_api/amtft_various_env.py index f6900c8..5c58b5c 100644 --- a/marltoolbox/experiments/rllib_api/amtft_various_env.py +++ b/marltoolbox/experiments/rllib_api/amtft_various_env.py @@ -1,21 +1,27 @@ import copy import logging import os +import argparse import ray from ray import tune from ray.rllib.agents import dqn -from ray.rllib.agents.dqn.dqn_torch_policy import postprocess_nstep_and_prio +from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio from ray.rllib.utils import merge_dicts from ray.rllib.utils.schedules import PiecewiseSchedule +from ray.tune.integration.wandb import WandbLoggerCallback +from ray.rllib.agents.ppo import PPOTrainer +from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch +from ray.rllib.utils import merge_dicts +from ray.rllib.utils.exploration import StochasticSampling from 
ray.tune.integration.wandb import WandbLogger from ray.tune.logger import DEFAULT_LOGGERS -from marltoolbox.algos import amTFT +from marltoolbox.algos import amTFT, augmented_r2d2 +from marltoolbox.algos.augmented_ppo import MyPPOTorchPolicy from marltoolbox.envs import ( matrix_sequential_social_dilemma, - vectorized_coin_game, - vectorized_mixed_motive_coin_game, + coin_game, ssd_mixed_motive_coin_game, ) from marltoolbox.envs.utils.wrappers import ( @@ -28,46 +34,64 @@ postprocessing, miscellaneous, plot, - self_and_cross_perf, callbacks, + cross_play, + config_helper, + exp_analysis, ) logger = logging.getLogger(__name__) -def main(debug, train_n_replicates=None, filter_utilitarian=None, env=None): +def main( + debug, + train_n_replicates=None, + filter_utilitarian=None, + env=None, + use_r2d2=False, + use_policy_gratient=False, + hyperparameter_search=False, +): hparams = get_hyperparameters( - debug, train_n_replicates, filter_utilitarian, env + debug, + train_n_replicates, + filter_utilitarian, + env, + use_r2d2=use_r2d2, + use_policy_gratient=use_policy_gratient, + hyperparameter_search=hyperparameter_search, ) if hparams["load_plot_data"] is None: ray.init( - num_cpus=os.cpu_count(), num_gpus=0, local_mode=hparams["debug"] + num_gpus=0, + num_cpus=os.cpu_count(), + local_mode=hparams["debug"], ) # Train if hparams["load_policy_data"] is None: - tune_analysis_per_welfare = train_for_each_welfare_function( + experiment_analysis_per_welfare = train_for_each_welfare_function( hparams ) else: - tune_analysis_per_welfare = load_tune_analysis( + experiment_analysis_per_welfare = load_experiment_analysis( hparams["load_policy_data"] ) # Eval & Plot analysis_metrics_per_mode = config_and_evaluate_cross_play( - tune_analysis_per_welfare, hparams + experiment_analysis_per_welfare, hparams ) ray.shutdown() else: - tune_analysis_per_welfare = None + experiment_analysis_per_welfare = None # Plot analysis_metrics_per_mode = config_and_evaluate_cross_play( - tune_analysis_per_welfare, hparams + experiment_analysis_per_welfare, hparams ) - return tune_analysis_per_welfare, analysis_metrics_per_mode + return experiment_analysis_per_welfare, analysis_metrics_per_mode def get_hyperparameters( @@ -76,14 +100,43 @@ def get_hyperparameters( filter_utilitarian=None, env=None, reward_uncertainty=0.0, + use_r2d2=False, + use_policy_gratient=False, + hyperparameter_search=False, ): - if debug: + if not debug: + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + choices=[ + "IteratedPrisonersDilemma", + "IteratedAsymBoS", + "IteratedAsymBoSandPD", + "CoinGame", + "ABCoinGame", + ], + help="Env to use.", + ) + parser.add_argument("--train_n_replicates", type=int) + args = parser.parse_args() + args = args.__dict__ + if "env" in args.keys(): + env = args["env"] + if "train_n_replicates" in args.keys(): + train_n_replicates = args["train_n_replicates"] + print("env", env) + if hyperparameter_search: + if train_n_replicates is None: + train_n_replicates = 1 + n_times_more_utilitarians_seeds = 1 + elif debug: train_n_replicates = 2 n_times_more_utilitarians_seeds = 1 - elif train_n_replicates is None: - n_times_more_utilitarians_seeds = 4 - train_n_replicates = 4 else: + if train_n_replicates is None: + train_n_replicates = 40 + # train_n_replicates = 2 n_times_more_utilitarians_seeds = 4 n_seeds_to_prepare = train_n_replicates * ( @@ -93,6 +146,7 @@ def get_hyperparameters( exp_name, _ = log.log_in_current_day_dir("amTFT") hparams = { "debug": debug, + "use_r2d2": use_r2d2, 
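+        # NOTE: when use_r2d2 is True, train_for_each_welfare_function picks
+        # dqn.r2d2.R2D2Trainer instead of dqn.DQNTrainer, _select_base_policy
+        # swaps the nested policies for the augmented_r2d2 classes, and
+        # _modify_config_for_r2d2 (further down) enables the recurrent model
+        # ("use_lstm": True, lstm_cell_size=16, burn_in=0).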
"filter_utilitarian": filter_utilitarian if filter_utilitarian is not None else not debug, @@ -101,115 +155,69 @@ def get_hyperparameters( "n_times_more_utilitarians_seeds": n_times_more_utilitarians_seeds, "exp_name": exp_name, "log_n_points": 250, + "num_envs_per_worker": 16, "load_plot_data": None, - # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p", "load_policy_data": None, - # "load_policy_data": { - # "Util": [ - # ".../IBP/amTFT/trials/" - # "DQN_AsymCoinGame_...", - # ".../IBP/amTFT/trials/" - # "DQN_AsymCoinGame_..."], - # 'IA':[ - # ".../temp/IBP/amTFT/trials/" - # "DQN_AsymCoinGame_...", - # ".../IBP/amTFT/trials/" - # "DQN_AsymCoinGame_..."], - # }, - # "load_policy_data": { - # "Util": [ - # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT" - # "/2021_03_28/19_38_55/utilitarian_welfare/coop" - # "/DQN_VectMixedMotiveCG_06231_00000_0_seed=1616960338_2021-03-29_00-52-23/checkpoint_250/checkpoint-250", - # # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT" - # # "/2021_03_24/18_22_47/utilitarian_welfare/coop" - # # "/DQN_VectMixedMotiveCG_e1de7_00001_1_seed=1616610171_2021-03-25_00-27-29/checkpoint_250/checkpoint-250", - # # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible/amTFT" - # # "/2021_03_24/18_22_47/utilitarian_welfare/coop" - # # "/DQN_VectMixedMotiveCG_e1de7_00002_2_seed=1616610172_2021-03-25_00-27-29/checkpoint_250/checkpoint-250", - # ], - # 'IA':[ - # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible" - # "/amTFT/2021_03_28/19_38_55/inequity_aversion_welfare/coop" - # "/DQN_VectMixedMotiveCG_d5a2a_00000_0_seed=1616960335_2021-03-28_21-23-26/checkpoint_250/checkpoint-250", - # # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible" - # # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop" - # # "/DQN_VectMixedMotiveCG_9cfe6_00001_1_seed=1616610168_2021-03-24_20-22-11/checkpoint_250/checkpoint-250", - # # "~/dev-maxime/CLR/vm-data/instance-60-cpu-1-preemtible" - # # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop" - # # "/DQN_VectMixedMotiveCG_9cfe6_00002_2_seed=1616610169_2021-03-24_20-22-11/checkpoint_250/checkpoint-250", - # ], - # }, - # "load_policy_data": { - # "Util": [ - # "~/ray_results/amTFT" - # "/2021_03_24/18_22_47/utilitarian_welfare/coop" - # "/DQN_VectMixedMotiveCG_e1de7_00000_0_seed=1616610170_2021-03-25_00-27-29/checkpoint_250/checkpoint-250", - # "~/ray_results/amTFT" - # "/2021_03_24/18_22_47/utilitarian_welfare/coop" - # "/DQN_VectMixedMotiveCG_e1de7_00001_1_seed=1616610171_2021-03-25_00-27-29/checkpoint_250/checkpoint-250", - # "~/ray_results/amTFT" - # "/2021_03_24/18_22_47/utilitarian_welfare/coop" - # "/DQN_VectMixedMotiveCG_e1de7_00002_2_seed=1616610172_2021-03-25_00-27-29/checkpoint_250/checkpoint-250", - # ], - # 'IA': [ - # "~/ray_results" - # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop" - # "/DQN_VectMixedMotiveCG_9cfe6_00000_0_seed=1616610167_2021-03-24_20-22-10/checkpoint_250/checkpoint-250", - # "~/ray_results" - # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop" - # "/DQN_VectMixedMotiveCG_9cfe6_00001_1_seed=1616610168_2021-03-24_20-22-11/checkpoint_250/checkpoint-250", - # "~/ray_results" - # "/amTFT/2021_03_24/18_22_47/inequity_aversion_welfare/coop" - # "/DQN_VectMixedMotiveCG_9cfe6_00002_2_seed=1616610169_2021-03-24_20-22-11/checkpoint_250/checkpoint-250", - # ], - # }, "amTFTPolicy": amTFT.AmTFTRolloutsTorchPolicy, "welfare_functions": [ (postprocessing.WELFARE_INEQUITY_AVERSION, "inequity_aversion"), (postprocessing.WELFARE_UTILITARIAN, 
"utilitarian"), ], + # "amTFT_punish_instead_of_selfish": False, + # "use_short_debit_rollout": False, + # "punishment_helped": False, + "amTFT_punish_instead_of_selfish": True, + "use_short_debit_rollout": True, + "punishment_helped": True, "jitter": 0.05, "hiddens": [64], "gamma": 0.96, # If not in self play then amTFT # will be evaluated against a naive selfish policy or an exploiter "self_play": True, - # "self_play": False, # Not tested - "env_name": "IteratedPrisonersDilemma" if env is None else env, + # "env_name": "IteratedPrisonersDilemma" if env is None else env, + # "env_name": "IteratedBoS" if env is None else env, # "env_name": "IteratedAsymBoS" if env is None else env, - # "env_name": "CoinGame" if env is None else env, - # "env_name": "AsymCoinGame" if env is None else env, - # "env_name": "MixedMotiveCoinGame" if env is None else env, - # "env_name": "SSDMixedMotiveCoinGame" if env is None else env, + # "env_name": "IteratedAsymBoSandPD" if env is None else env, + "env_name": "CoinGame" if env is None else env, + # "env_name": "ABCoinGame" if env is None else env, "overwrite_reward": True, "explore_during_evaluation": True, "reward_uncertainty": reward_uncertainty, + "use_other_play": False, + # "use_other_play": True, + "use_policy_gratient": use_policy_gratient, + "use_MSE_in_r2d2": True, + "hyperparameter_search": hyperparameter_search, + "using_wandb": use_policy_gratient, } + if hparams["load_policy_data"] is not None: + hparams["train_n_replicates"] = len( + hparams["load_policy_data"]["Util"] + ) + hparams = modify_hyperparams_for_the_selected_env(hparams) - hparams["plot_keys"] = amTFT.PLOT_KEYS + hparams["plot_keys"] - hparams["plot_assemblage_tags"] = ( - amTFT.PLOT_ASSEMBLAGE_TAGS + hparams["plot_assemblage_tags"] - ) return hparams -def load_tune_analysis(grouped_checkpoints_paths: dict): - tune_analysis = {} - msg = "start load_tune_analysis" +def load_experiment_analysis(grouped_checkpoints_paths: dict): + experiment_analysis = {} + msg = "start load_experiment_analysis" print(msg) logger.info(msg) for group_name, checkpoints_paths in grouped_checkpoints_paths.items(): - one_tune_analysis = miscellaneous.load_one_tune_analysis( - checkpoints_paths, n_dir_level_between_ckpt_and_exp_state=3 + one_experiment_analysis = ( + exp_analysis.load_experiment_analysis_wt_ckpt_only( + checkpoints_paths, n_dir_level_between_ckpt_and_exp_state=3 + ) ) - tune_analysis[group_name] = one_tune_analysis - msg = "end load_tune_analysis" + experiment_analysis[group_name] = one_experiment_analysis + msg = "end load_experiment_analysis" print(msg) logger.info(msg) - return tune_analysis + return experiment_analysis def modify_hyperparams_for_the_selected_env(hp): @@ -220,77 +228,86 @@ def modify_hyperparams_for_the_selected_env(hp): amTFT.PLOT_ASSEMBLAGE_TAGS + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS ) - mul_temp = 1.0 - hp["punishment_multiplier"] = 3.0 hp["buf_frac"] = 0.125 - hp["training_intensity"] = 10 - # hp["rollout_length"] = 40 - # hp["n_rollout_replicas"] = 20 - hp["rollout_length"] = 4 - hp["n_rollout_replicas"] = 5 + hp["training_intensity"] = 1 if hp["debug"] else 40 + hp["rollout_length"] = 20 + hp["n_rollout_replicas"] = 10 + hp["beta_steps_config"] = [ + (0, 0.125), + (1.0, 0.25), + ] if "CoinGame" in hp["env_name"]: - hp["plot_keys"] += vectorized_coin_game.PLOT_KEYS - hp["plot_assemblage_tags"] += vectorized_coin_game.PLOT_ASSEMBLAGE_TAGS hp["n_steps_per_epi"] = 20 if hp["debug"] else 100 hp["n_epi"] = 10 if hp["debug"] else 4000 - hp["base_lr"] = 
0.1 - hp["bs_epi_mul"] = 1 + hp["eval_over_n_epi"] = 1 + # hp["base_lr"] = 0.1 + + hp["bs_epi_mul"] = 4 hp["both_players_can_pick_the_same_coin"] = False hp["sgd_momentum"] = 0.9 hp["lambda"] = 0.96 hp["alpha"] = 0.0 - hp["beta"] = 0.5 + hp["beta"] = config_helper.configurable_linear_scheduler( + "beta_steps_config" + ) - hp["debit_threshold"] = 30.0 - hp["jitter"] = 0.02 + hp["debit_threshold"] = 3.0 + hp["punishment_multiplier"] = 6.0 + hp["jitter"] = 0.0 hp["filter_utilitarian"] = False - hp["target_network_update_freq"] = 100 * hp["n_steps_per_epi"] - hp["last_exploration_temp_value"] = 0.03 * mul_temp - - hp["temperature_schedule"] = PiecewiseSchedule( - endpoints=[ - (0, 2.0 * mul_temp), - ( - int(hp["n_steps_per_epi"] * hp["n_epi"] * 0.20), - 0.5 * mul_temp, - ), - ( - int(hp["n_steps_per_epi"] * hp["n_epi"] * 0.60), - hp["last_exploration_temp_value"], - ), - ], - outside_value=hp["last_exploration_temp_value"], - framework="torch", - ) - - if "AsymCoinGame" in hp["env_name"]: - hp["x_limits"] = (-0.5, 3.0) - hp["y_limits"] = (-1.1, 0.6) - hp["env_class"] = vectorized_coin_game.AsymVectorizedCoinGame - elif "MixedMotiveCoinGame" in hp["env_name"]: - if "SSDMixedMotiveCoinGame" in hp["env_name"]: - hp["debit_threshold"] = 3.0 - hp["x_limits"] = (-0.25, 1.0) - hp["y_limits"] = (-0.25, 1.5) - hp[ - "env_class" - ] = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame - else: - hp["x_limits"] = (-2.0, 2.0) - hp["y_limits"] = (-0.5, 3.0) - hp[ - "env_class" - ] = vectorized_mixed_motive_coin_game.VectMixedMotiveCG - hp["both_players_can_pick_the_same_coin"] = True + # hp["buf_frac"] = 0.5 + hp["target_network_update_freq"] = 30 * hp["n_steps_per_epi"] + hp["last_exploration_temp_value"] = 0.003 / 4 + + # hp["lr_steps_config"] = [ + # (0, 1.0), + # (0.25, 0.5), + # (1.0, 1e-9), + # ] + # hp["temperature_steps_config"] = [ + # (0, 0.75), + # (0.2, 0.45), + # (0.9, hp["last_exploration_temp_value"]), + # ] + + hp["lr_steps_config"] = [ + (0, 1e-9), + (0.2, 1.0), + # (0.5, 1.0/3), + # (0.5, 1.0/10), + # (0.75, 1.0/10/3), + (1.0, 1e-9), + # (1.0, 1.0), + # (1.0, 0.33), + ] + + hp["buf_frac"] = 0.1 + # hp["last_exploration_temp_value"] = 0.1 + hp["temperature_steps_config"] = [ + (0, 0.75), + # (0, 0.25), + (0.2, 0.25), + # (0.5, 0.1), + # (0.7, 0.1), + (0.9, hp["last_exploration_temp_value"]), + ] + + # hp["buf_frac"] = 0.25 + hp["base_lr"] = 0.1 / 4 + + if "ABCoinGame" in hp["env_name"]: + raise NotImplementedError() else: - hp["x_limits"] = (-0.5, 0.6) - hp["y_limits"] = (-0.5, 0.6) - hp["env_class"] = vectorized_coin_game.VectorizedCoinGame + hp["plot_keys"] += coin_game.PLOT_KEYS + hp["plot_assemblage_tags"] += coin_game.PLOT_ASSEMBLAGE_TAGS + hp["x_limits"] = (-0.1, 0.6) + hp["y_limits"] = (-0.1, 0.6) + hp["env_class"] = coin_game.CoinGame else: hp["plot_keys"] += matrix_sequential_social_dilemma.PLOT_KEYS @@ -299,36 +316,33 @@ def modify_hyperparams_for_the_selected_env(hp): ] += matrix_sequential_social_dilemma.PLOT_ASSEMBLAGE_TAGS hp["base_lr"] = 0.03 - hp["bs_epi_mul"] = 1 - hp["n_steps_per_epi"] = 20 - hp["n_epi"] = 10 if hp["debug"] else 800 + hp["bs_epi_mul"] = 4 + hp["n_steps_per_epi"] = 10 if hp["debug"] else 20 + hp["n_epi"] = 5 if hp["debug"] else 800 + hp["eval_over_n_epi"] = 5 hp["lambda"] = 0.96 hp["alpha"] = 0.0 hp["beta"] = 1.0 hp["sgd_momentum"] = 0.0 hp["debit_threshold"] = 10.0 + hp["punishment_multiplier"] = 3.0 hp["target_network_update_freq"] = 30 * hp["n_steps_per_epi"] - hp["last_exploration_temp_value"] = 0.1 * mul_temp - - hp["temperature_schedule"] = 
PiecewiseSchedule( - endpoints=[ - (0, 2.0 * mul_temp), - ( - int(hp["n_steps_per_epi"] * hp["n_epi"] * 0.33), - 0.5 * mul_temp, - ), - ( - int(hp["n_steps_per_epi"] * hp["n_epi"] * 0.66), - hp["last_exploration_temp_value"], - ), - ], - outside_value=hp["last_exploration_temp_value"], - framework="torch", - ) - - if "IteratedPrisonersDilemma" in hp["env_name"]: + hp["last_exploration_temp_value"] = 0.1 + + hp["temperature_steps_config"] = [ + (0, 2.0), + (0.33, 0.5), + (0.66, hp["last_exploration_temp_value"]), + ] + hp["lr_steps_config"] = [ + (0, 0.0), + (0.05, 1.0), + (1.0, 1e-9), + ] + + if "IteratedPrisonersDilemma" == hp["env_name"]: hp["filter_utilitarian"] = False hp["x_limits"] = (-3.5, 0.5) hp["y_limits"] = (-3.5, 0.5) @@ -336,35 +350,55 @@ def modify_hyperparams_for_the_selected_env(hp): hp[ "env_class" ] = matrix_sequential_social_dilemma.IteratedPrisonersDilemma - elif "IteratedAsymBoS" in hp["env_name"]: + elif "IteratedAsymBoS" == hp["env_name"]: hp["x_limits"] = (-0.1, 4.1) hp["y_limits"] = (-0.1, 4.1) hp["utilitarian_filtering_threshold"] = 3.2 hp["env_class"] = matrix_sequential_social_dilemma.IteratedAsymBoS + elif "IteratedAsymBoSandPD" == hp["env_name"]: + # hp["x_limits"] = (-3.1, 4.1) + # hp["y_limits"] = (-3.1, 4.1) + hp["x_limits"] = (-6.1, 5.1) + hp["y_limits"] = (-6.1, 5.1) + hp["utilitarian_filtering_threshold"] = 3.2 + hp[ + "env_class" + ] = matrix_sequential_social_dilemma.IteratedAsymBoSandPD + elif "IteratedBoS" == hp["env_name"]: + hp["x_limits"] = (-0.1, 3.1) + hp["y_limits"] = (-0.1, 3.1) + hp["utilitarian_filtering_threshold"] = 2.6 + hp["env_class"] = matrix_sequential_social_dilemma.IteratedBoS else: raise NotImplementedError(f'hp["env_name"]: {hp["env_name"]}') - hp["lr_schedule"] = [ - (0, 0.0), - (int(hp["n_steps_per_epi"] * hp["n_epi"] * 0.05), hp["base_lr"]), - (int(hp["n_steps_per_epi"] * hp["n_epi"]), hp["base_lr"] / 1e9), - ] - hp["plot_axis_scale_multipliers"] = ( (1 / hp["n_steps_per_epi"]), # for x axis (1 / hp["n_steps_per_epi"]), ) # for y axis - hp["env_class"] = add_RewardUncertaintyEnvClassWrapper( - env_class=hp["env_class"], - reward_uncertainty_std=hp["reward_uncertainty"], - ) + if "reward_uncertainty" in hp.keys() and hp["reward_uncertainty"] != 0.0: + hp["env_class"] = add_RewardUncertaintyEnvClassWrapper( + env_class=hp["env_class"], + reward_uncertainty_std=hp["reward_uncertainty"], + ) + + hp["temperature_schedule"] = config_helper.get_temp_scheduler() + hp["lr_schedule"] = config_helper.get_lr_scheduler() return hp def train_for_each_welfare_function(hp): - tune_analysis_per_welfare = {} + experiment_analysis_per_welfare = {} + + if hp["use_r2d2"]: + trainer = dqn.r2d2.R2D2Trainer + elif hp["use_policy_gratient"]: + trainer = PPOTrainer + else: + trainer = dqn.DQNTrainer + for welfare_fn, welfare_group_name in hp["welfare_functions"]: print("==============================================") print( @@ -375,24 +409,33 @@ def train_for_each_welfare_function(hp): hp = preprocess_utilitarian_config(hp) stop, env_config, rllib_config = get_rllib_config(hp, welfare_fn) + rllib_config_copy = copy.deepcopy(rllib_config) + if hp["using_wandb"]: + rllib_config_copy["logger_config"]["wandb"][ + "group" + ] += f"_{welfare_group_name}" + exp_name = os.path.join(hp["exp_name"], welfare_fn) results = amTFT.train_amtft( stop_config=stop, - rllib_config=rllib_config, + rllib_config=rllib_config_copy, name=exp_name, - TrainerClass=dqn.DQNTrainer, + TrainerClass=trainer, plot_keys=hp["plot_keys"], plot_assemblage_tags=hp["plot_assemblage_tags"], 
debug=hp["debug"], log_to_file=not hp["debug"], - loggers=None if hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,), + punish_instead_of_selfish=hp["amTFT_punish_instead_of_selfish"], + loggers=DEFAULT_LOGGERS + (WandbLogger,) + if hp["using_wandb"] + else DEFAULT_LOGGERS, ) if welfare_fn == postprocessing.WELFARE_UTILITARIAN: results, hp = postprocess_utilitarian_results( results, env_config, hp ) - tune_analysis_per_welfare[welfare_group_name] = results - return tune_analysis_per_welfare + experiment_analysis_per_welfare[welfare_group_name] = results + return experiment_analysis_per_welfare def preprocess_utilitarian_config(hp): @@ -406,11 +449,12 @@ def preprocess_utilitarian_config(hp): def get_rllib_config(hp, welfare_fn, eval=False): - stop = { + stop_config = { "episodes_total": hp["n_epi"], } env_config = get_env_config(hp) + policies = get_policies(hp, env_config, welfare_fn, eval) selected_seeds = hp["seeds"][: hp["train_n_replicates"]] @@ -422,13 +466,6 @@ def get_rllib_config(hp, welfare_fn, eval=False): "multiagent": { "policies": policies, "policy_mapping_fn": lambda agent_id: agent_id, - # When replay_mode=lockstep, RLlib will replay all the agent - # transitions at a particular timestep together in a batch. - # This allows the policy to implement differentiable shared - # computations between agents it controls at that timestep. - # When replay_mode=independent, - # transitions are replayed independently per policy. - # "replay_mode": "lockstep", "observation_fn": amTFT.observation_fn, }, "gamma": hp["gamma"], @@ -446,8 +483,17 @@ def get_rllib_config(hp, welfare_fn, eval=False): # Size of a batch sampled from replay buffer for training. Note that # if async_updates is set, then each worker returns gradients for a # batch of this size. - "train_batch_size": int(hp["n_steps_per_epi"] * hp["bs_epi_mul"]), - "training_intensity": hp["training_intensity"], + "train_batch_size": tune.sample_from( + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.config["env_config"]["bs_epi_mul"] + ) + ), + "training_intensity": tune.sample_from( + lambda spec: spec.config["num_envs_per_worker"] + * max(1, spec.config["num_workers"]) + * hp["training_intensity"] + ), # Minimum env steps to optimize for per train call. This value does # not affect learning, only the length of iterations. "timesteps_per_iteration": hp["n_steps_per_epi"] @@ -456,12 +502,10 @@ def get_rllib_config(hp, welfare_fn, eval=False): "min_iter_time_s": 0.0, # General config "framework": "torch", - # LE supports only 1 worker only otherwise - # it would be mixing several opponents trajectories "num_workers": 0, # LE supports only 1 env per worker only otherwise # several episodes would be played at the same time - "num_envs_per_worker": 1, + "num_envs_per_worker": hp["num_envs_per_worker"], # Callbacks that will be run during various phases of training. 
See the # `DefaultCallbacks` class and # `examples/custom_metrics_and_callbacks.py` for more usage @@ -469,27 +513,25 @@ def get_rllib_config(hp, welfare_fn, eval=False): "callbacks": callbacks.merge_callbacks( amTFT.AmTFTCallbacks, log.get_logging_callbacks_class( - log_full_epi=True, log_full_epi_interval=100 + log_full_epi=hp["num_envs_per_worker"] == 1, + log_model_sumamry=True, ), ), - "logger_config": { - "wandb": { - "project": "amTFT", - "group": hp["exp_name"], - "api_key_file": os.path.join( - os.path.dirname(__file__), "../../../api_key_wandb" - ), - "log_config": True, - }, - }, # === DQN Models === # Update the target network every `target_network_update_freq` steps. "target_network_update_freq": hp["target_network_update_freq"], # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then # each worker will have a replay buffer of this size. - "buffer_size": max( - int(hp["n_steps_per_epi"] * hp["n_epi"] * hp["buf_frac"]), 5 + "buffer_size": tune.sample_from( + lambda spec: max( + int( + spec.config["env_config"]["max_steps"] + * spec.config["env_config"]["buf_frac"] + * spec.stop["episodes_total"] + ), + 5, + ) ), # Whether to use dueling dqn "dueling": True, @@ -502,12 +544,17 @@ def get_rllib_config(hp, welfare_fn, eval=False): "prioritized_replay": False, "model": { # Number of hidden layers for fully connected net - "fcnet_hiddens": hp["hiddens"], + "fcnet_hiddens": [64], # Nonlinearity for fully connected net (tanh, relu) "fcnet_activation": "relu", }, # How many steps of the model to sample before learning starts. - "learning_starts": int(hp["n_steps_per_epi"] * hp["bs_epi_mul"]), + "learning_starts": tune.sample_from( + lambda spec: int( + spec.config["env_config"]["max_steps"] + * spec.config["env_config"]["bs_epi_mul"] + ) + ), # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). @@ -526,16 +573,22 @@ def get_rllib_config(hp, welfare_fn, eval=False): # Add constructor kwargs here (if any). 
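+                # NOTE: hp["temperature_schedule"] is produced by
+                # config_helper.get_temp_scheduler() (see
+                # modify_hyperparams_for_the_selected_env). Presumably it
+                # maps the (training-fraction, temperature) pairs of
+                # env_config["temperature_steps_config"] onto absolute steps;
+                # e.g. with n_steps_per_epi=20 and n_epi=800 the pair
+                # (0.33, 0.5) would land near step 20 * 800 * 0.33 = 5_280.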
"temperature_schedule": hp["temperature_schedule"], }, + "optimizer": { + "sgd_momentum": hp["sgd_momentum"], + }, + # "log_level": "DEBUG", + "evaluation_interval": None, + "evaluation_parallel_to_training": False, } - if "CoinGame" in hp["env_name"]: - rllib_config["model"] = { - "dim": env_config["grid_size"], - "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]], - # [Channel, [Kernel, Kernel], Stride]] - } - - return stop, env_config, rllib_config + rllib_config = _modify_config_for_coin_game(rllib_config, env_config, hp) + rllib_config, stop_config = _modify_config_for_r2d2( + rllib_config, hp, stop_config, eval + ) + rllib_config, stop_config = _modify_config_for_policy_gradient( + rllib_config, hp, stop_config, eval + ) + return stop_config, env_config, rllib_config def get_env_config(hp): @@ -547,23 +600,34 @@ def get_env_config(hp): "both_players_can_pick_the_same_coin": hp[ "both_players_can_pick_the_same_coin" ], + "punishment_helped": hp["punishment_helped"], } else: env_config = { "players_ids": ["player_row", "player_col"], "max_steps": hp["n_steps_per_epi"], } + env_config["bs_epi_mul"] = hp.get("bs_epi_mul", None) + env_config["buf_frac"] = hp.get("buf_frac", None) + env_config["temperature_steps_config"] = hp.get( + "temperature_steps_config", None + ) + env_config["lr_steps_config"] = hp.get("lr_steps_config", None) + env_config["beta_steps_config"] = hp.get("beta_steps_config", None) + env_config["use_other_play"] = hp["use_other_play"] return env_config def get_policies(hp, env_config, welfare_fn, eval=False): - PolicyClass = hp["amTFTPolicy"] - NestedPolicyClass, CoopNestedPolicyClass = get_nested_policy_class( - hp, welfare_fn - ) + policy_class = hp["amTFTPolicy"] + ( + nested_policy_class, + coop_nested_policy_class, + nested_selfish_policy_class, + ) = get_nested_policy_class(hp) if eval: - NestedPolicyClass = CoopNestedPolicyClass + nested_policy_class = coop_nested_policy_class amTFT_config_update = merge_dicts( amTFT.DEFAULT_CONFIG, @@ -572,23 +636,52 @@ def get_policies(hp, env_config, welfare_fn, eval=False): "working_state": "train_coop", "welfare_key": welfare_fn, "verbose": 1 if hp["debug"] else 0, - # "verbose": 1 if hp["debug"] else 2, + # "verbose": 2, "punishment_multiplier": hp["punishment_multiplier"], "debit_threshold": hp["debit_threshold"], - "rollout_length": min(hp["n_steps_per_epi"], hp["rollout_length"]), - "n_rollout_replicas": hp["n_rollout_replicas"], + "rollout_length": 2 + if hp["debug"] + else min(hp["n_steps_per_epi"], hp["rollout_length"]), + "n_rollout_replicas": 2 + if hp["debug"] + else hp["n_rollout_replicas"], + "punish_instead_of_selfish": hp["amTFT_punish_instead_of_selfish"], + "use_short_debit_rollout": hp["use_short_debit_rollout"], "optimizer": { "sgd_momentum": hp["sgd_momentum"], }, "nested_policies": [ - {"Policy_class": CoopNestedPolicyClass, "config_update": {}}, - {"Policy_class": NestedPolicyClass, "config_update": {}}, - {"Policy_class": CoopNestedPolicyClass, "config_update": {}}, - {"Policy_class": NestedPolicyClass, "config_update": {}}, + { + "Policy_class": coop_nested_policy_class, + "config_update": { + postprocessing.ADD_INEQUITY_AVERSION_WELFARE: [ + welfare_fn + == postprocessing.WELFARE_INEQUITY_AVERSION, + hp["alpha"], + hp["beta"], + hp["gamma"], + hp["lambda"], + ], + postprocessing.ADD_UTILITARIAN_WELFARE: ( + welfare_fn == postprocessing.WELFARE_UTILITARIAN + ), + }, + }, + {"Policy_class": nested_policy_class, "config_update": {}}, + { + "Policy_class": coop_nested_policy_class, + "config_update": {}, 
+ }, + {"Policy_class": nested_policy_class, "config_update": {}}, ], }, ) + if hp["amTFT_punish_instead_of_selfish"]: + amTFT_config_update["nested_policies"].append( + {"Policy_class": nested_selfish_policy_class, "config_update": {}}, + ) + policy_1_config = copy.deepcopy(amTFT_config_update) policy_1_config["own_policy_id"] = env_config["players_ids"][0] policy_1_config["opp_policy_id"] = env_config["players_ids"][1] @@ -598,49 +691,237 @@ def get_policies(hp, env_config, welfare_fn, eval=False): policy_2_config["opp_policy_id"] = env_config["players_ids"][0] policies = { - env_config["players_ids"][0]: ( + env_config["players_ids"][0]: [ # The default policy is DQN defined in DQNTrainer but # we overwrite it to use the LE policy - PolicyClass, + policy_class, hp["env_class"](env_config).OBSERVATION_SPACE, hp["env_class"].ACTION_SPACE, policy_1_config, - ), - env_config["players_ids"][1]: ( - PolicyClass, + ], + env_config["players_ids"][1]: [ + policy_class, hp["env_class"](env_config).OBSERVATION_SPACE, hp["env_class"].ACTION_SPACE, policy_2_config, - ), + ], } return policies -def get_nested_policy_class(hp, welfare_fn): - NestedPolicyClass = amTFT.DEFAULT_NESTED_POLICY_SELFISH - CoopNestedPolicyClass = NestedPolicyClass.with_updates( - # TODO problem: this prevent to use HP searches on gamma etc. +def get_nested_policy_class(hp): + nested_selfish_policy_class = _select_base_policy(hp) + + if hp["use_policy_gratient"]: + original_postprocess_fn = compute_gae_for_sample_batch + else: + original_postprocess_fn = postprocess_nstep_and_prio + + coop_nested_policy_class = nested_selfish_policy_class.with_updates( postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( - postprocessing.welfares_postprocessing_fn( - add_utilitarian_welfare=( - welfare_fn == postprocessing.WELFARE_UTILITARIAN - ), - add_inequity_aversion_welfare=( - welfare_fn == postprocessing.WELFARE_INEQUITY_AVERSION + postprocessing.welfares_postprocessing_fn(), + original_postprocess_fn, + ) + ) + + if hp["amTFT_punish_instead_of_selfish"]: + nested_policy_class = nested_selfish_policy_class.with_updates( + # TODO problem: this prevent to use HP searches on gamma etc. 
+ postprocess_fn=miscellaneous.merge_policy_postprocessing_fn( + postprocessing.welfares_postprocessing_fn( + add_opponent_neg_reward=True, ), - inequity_aversion_alpha=hp["alpha"], - inequity_aversion_beta=hp["beta"], - inequity_aversion_gamma=hp["gamma"], - inequity_aversion_lambda=hp["lambda"], - ), - postprocess_nstep_and_prio, + original_postprocess_fn, + ) ) + else: + nested_policy_class = nested_selfish_policy_class + return ( + nested_policy_class, + coop_nested_policy_class, + nested_selfish_policy_class, ) - return NestedPolicyClass, CoopNestedPolicyClass + + +def _select_base_policy(hp): + if hp["use_policy_gratient"]: + print("using PPOTorchPolicy") + assert not hp["use_r2d2"] + # nested_policy_class = PPOTorchPolicy + nested_policy_class = MyPPOTorchPolicy + elif hp["use_r2d2"]: + print("using augmented_r2d2.MyR2D2TorchPolicy") + if hp["use_MSE_in_r2d2"]: + nested_policy_class = augmented_r2d2.MyR2D2TorchPolicyWtMSELoss + else: + nested_policy_class = augmented_r2d2.MyR2D2TorchPolicy + + else: + nested_policy_class = amTFT.DEFAULT_NESTED_POLICY_SELFISH + + # if hp["use_other_play"]: + # print("use_other_play with_updates") + # + # if "CoinGame" in hp["env_name"]: + # symetries_available = coin_game.CoinGame.SYMMETRIES + # elif hp["env_name"] == "IteratedBoS": + # symetries_available = ( + # matrix_sequential_social_dilemma.IteratedBoS.SYMMETRIES + # ) + # else: + # raise NotImplementedError() + # + # nested_policy_class = nested_policy_class.with_updates( + # _after_loss_init=partial( + # other_play.after_init_wrap_model_other_play, + # symetries_available=symetries_available, + # ) + # ) + return nested_policy_class + + +def _modify_config_for_coin_game(rllib_config, env_config, hp): + if "CoinGame" in hp["env_name"]: + rllib_config["hiddens"] = [32] + rllib_config["model"] = { + "dim": env_config["grid_size"], + "conv_filters": [[64, [3, 3], 1], [64, [3, 3], 1]], + # [Channel, [Kernel, Kernel], Stride]] + "fcnet_hiddens": [64, 64], + } + return rllib_config + + +def _modify_config_for_r2d2(rllib_config, hp, stop_config, eval=False): + if hp["use_r2d2"]: + rllib_config["model"]["use_lstm"] = True + rllib_config["use_h_function"] = False + rllib_config["burn_in"] = 0 + rllib_config["zero_init_states"] = False + rllib_config["model"]["lstm_cell_size"] = 16 + if hp["debug"]: + rllib_config["model"]["max_seq_len"] = 2 + else: + rllib_config["model"]["max_seq_len"] = 20 + rllib_config["env_config"]["bs_epi_mul"] = 4 + if "CoinGame" in hp["env_name"]: + rllib_config["training_intensity"] = tune.sample_from( + lambda spec: spec.config["num_envs_per_worker"] + * max(1, spec.config["num_workers"]) + * hp["training_intensity"] + ) + if not eval: + stop_config["episodes_total"] = 8000 * hp["n_epi"] / 4000 + + return rllib_config, stop_config + + +def _modify_config_for_policy_gradient( + rllib_config, hp, stop_config, eval=False +): + if hp["use_policy_gratient"]: + # rllib_config.pop("lr_schedule") + rllib_config.pop("target_network_update_freq") + rllib_config.pop("buffer_size") + rllib_config.pop("dueling") + rllib_config.pop("double_q") + rllib_config.pop("prioritized_replay") + rllib_config.pop("training_intensity") + rllib_config.pop("hiddens") + rllib_config.pop("learning_starts") + + if hp["debug"]: + rllib_config["train_batch_size"] = int(128 * 2) + rllib_config["num_sgd_iter"] = 2 + elif not hp["debug"] and not eval: + # if hp["hyperparameter_search"]: + # rllib_config["train_batch_size"] = tune.grid_search( + # [1024, 4096] + # ) + # else: + rllib_config["train_batch_size"] = 
4096 + + stop_config["episodes_total"] = 5000 * hp["n_epi"] / 4000 + + rllib_config["exploration_config"] = { + # The Exploration class to use. In the simplest case, + # this is the name (str) of any class present in the + # `rllib.utils.exploration` package. + # You can also provide the python class directly or + # the full location of your class (e.g. + # "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": StochasticSampling, + # Add constructor kwargs here (if any). + # "temperature_schedule": hp["temperature_schedule"], + } + + if hp["hyperparameter_search"]: + # rllib_config["lr"] = 1e-4 + rllib_config["lr"] = tune.grid_search([1e-3, 3e-4]) + rllib_config["vf_loss_coeff"] = tune.grid_search([3.0, 1.0, 0.3]) + # rllib_config["model"]["vf_share_layers"] = tune.grid_search( + # [False, True] + # ) + # rllib_config["use_gae"] = tune.grid_search([True, False]) + # rllib_config["batch_mode"] = tune.grid_search( + # ["truncate_episodes", "complete_episodes"] + # ) + # rllib_config["batch_mode"] = "complete_episodes" + + rllib_config["env_config"]["beta_steps_config"] = [ + (0, 0.125 * 3), + (1.0, 0.25 * 3), + ] + # rllib_config["env_config"]["beta_steps_config"] = [ + # (0, 0.125*4), + # (1.0, 0.25*4), + # ] + + else: + rllib_config["lr"] = 0.1 / 300 + # hp["sgd_momentum"] = 0.5 + + # Coefficient of the value function loss. IMPORTANT: you must tune this if + # you set vf_share_layers=True inside your model's config. + # "vf_loss_coeff": 1.0, + # # Coefficient of the entropy regularizer. + # "entropy_coeff": 0.0, + # # Decay schedule for the entropy regularizer. + # "entropy_coeff_schedule": None, + # # PPO clip parameter. + # "clip_param": 0.3, + # # Clip param for the value function. Note that this is sensitive to the + # # scale of the rewards. If your expected V is large, increase this. + # "vf_clip_param": 10.0, + # # If specified, clip the global norm of gradients by this amount. + # "grad_clip": None, + # # Target value for KL divergence. 
+ # "kl_target": 0.01, + + # if hp["hyperparameter_search"]: + # rllib_config["kl_target"] = tune.grid_search([0.003, 0.01, 0.03]) + # rllib_config["vf_clip_param"] = 1.0 + # rllib_config["vf_loss_coeff"] = 0.1 + + rllib_config["logger_config"] = { + "wandb": { + "project": "amTFT", + "group": hp["exp_name"], + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../api_key_wandb" + ), + "log_config": True, + }, + } + + return rllib_config, stop_config def postprocess_utilitarian_results(results, env_config, hp): + """Reverse the changes made by preprocess_utilitarian_results""" + hp_cp = copy.deepcopy(hp) if hp["filter_utilitarian"]: @@ -648,7 +929,7 @@ def postprocess_utilitarian_results(results, env_config, hp): hp_cp["train_n_replicates"] // hp_cp["n_times_more_utilitarians_seeds"] ) - results = miscellaneous.filter_tune_results( + results = exp_analysis.filter_trials( results, metric=f"policy_reward_mean/{env_config['players_ids'][0]}", metric_threshold=hp_cp["utilitarian_filtering_threshold"] @@ -663,18 +944,25 @@ def postprocess_utilitarian_results(results, env_config, hp): return results, hp_cp -def config_and_evaluate_cross_play(tune_analysis_per_welfare, hp): - config_eval, env_config, stop, hp_eval = generate_eval_config(hp) +def config_and_evaluate_cross_play(experiment_analysis_per_welfare, hp): + config_eval, env_config, stop, hp_eval = _generate_eval_config(hp) return evaluate_self_play_cross_play( - tune_analysis_per_welfare, config_eval, env_config, stop, hp_eval + experiment_analysis_per_welfare, config_eval, env_config, stop, hp_eval ) def evaluate_self_play_cross_play( - tune_analysis_per_welfare, config_eval, env_config, stop, hp_eval + experiment_analysis_per_welfare, config_eval, env_config, stop, hp_eval ): + if hp_eval["use_r2d2"]: + trainer = dqn.r2d2.R2D2Trainer + elif hp_eval["use_policy_gratient"]: + trainer = PPOTrainer + else: + trainer = dqn.DQNTrainer + exp_name = os.path.join(hp_eval["exp_name"], "eval") - evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator( + evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator( exp_name=exp_name, local_mode=hp_eval["debug"], ) @@ -684,17 +972,23 @@ def evaluate_self_play_cross_play( policies_to_load_from_checkpoint=copy.deepcopy( env_config["players_ids"] ), - tune_analysis_per_exp=tune_analysis_per_welfare, - TrainerClass=dqn.DQNTrainer, + experiment_analysis_per_welfare=experiment_analysis_per_welfare, + rllib_trainer_class=trainer, n_self_play_per_checkpoint=hp_eval["n_self_play_per_checkpoint"], n_cross_play_per_checkpoint=hp_eval["n_cross_play_per_checkpoint"], to_load_path=hp_eval["load_plot_data"], ) + return plot_evaluation( + hp_eval, evaluator, analysis_metrics_per_mode, env_config + ) + + +def plot_evaluation(hp_eval, evaluator, analysis_metrics_per_mode, env_config): if "CoinGame" in hp_eval["env_name"]: background_area_coord = None else: - background_area_coord = hp_eval["env_class"].PAYOUT_MATRIX + background_area_coord = hp_eval["env_class"].PAYOFF_MATRIX plot_config = plot.PlotConfig( xlim=hp_eval["x_limits"], ylim=hp_eval["y_limits"], @@ -715,29 +1009,32 @@ def evaluate_self_play_cross_play( y_axis_metric=f"policy_reward_mean/{env_config['players_ids'][1]}", ) - print_inequity_aversion_welfare(env_config, analysis_metrics_per_mode) + # print_inequity_aversion_welfare(env_config, analysis_metrics_per_mode) return analysis_metrics_per_mode -def generate_eval_config(hp): - hp_eval = modify_hp_for_evaluation(hp) +def _generate_eval_config(hp): + hp_eval = modify_hp_for_evaluation(hp, 
hp["eval_over_n_epi"]) fake_welfare_function = postprocessing.WELFARE_INEQUITY_AVERSION - stop, env_config, rllib_config = get_rllib_config( + stop_config, env_config, rllib_config = get_rllib_config( hp_eval, fake_welfare_function, eval=True ) - config_eval = modify_config_for_evaluation( - rllib_config, hp_eval, env_config + config_eval, stop_config = modify_config_for_evaluation( + rllib_config, hp_eval, env_config, stop_config ) - return config_eval, env_config, stop, hp_eval + return config_eval, env_config, stop_config, hp_eval def modify_hp_for_evaluation(hp: dict, eval_over_n_epi: int = 1): hp_eval = copy.deepcopy(hp) # TODO is the overwrite_reward hp useless? hp_eval["overwrite_reward"] = False + hp_eval["num_envs_per_worker"] = 1 hp_eval["n_epi"] = eval_over_n_epi - hp_eval["n_steps_per_epi"] = 5 if hp_eval["debug"] else 100 + if hp_eval["debug"]: + hp_eval["n_epi"] = 1 + hp_eval["n_steps_per_epi"] = 5 hp_eval["bs_epi_mul"] = 1 hp_eval["plot_axis_scale_multipliers"] = ( # for x axis @@ -756,9 +1053,17 @@ def modify_hp_for_evaluation(hp: dict, eval_over_n_epi: int = 1): return hp_eval -def modify_config_for_evaluation(config_eval, hp, env_config): +def modify_config_for_evaluation(config_eval, hp, env_config, stop_config): config_eval["explore"] = False - config_eval["seed"] = None + config_eval.pop("seed") + # = None + config_eval["num_workers"] = 0 + assert ( + config_eval["num_envs_per_worker"] == 1 + ), f'num_envs_per_worker {config_eval["num_envs_per_worker"]}' + assert ( + stop_config["episodes_total"] <= 20 + ), f'episodes_total {stop_config["episodes_total"]}' policies = config_eval["multiagent"]["policies"] for policy_id in policies.keys(): policy_config = policies[policy_id][3] @@ -768,7 +1073,7 @@ def modify_config_for_evaluation(config_eval, hp, env_config): naive_player_policy_config = policies[naive_player_id][3] naive_player_policy_config["working_state"] = "eval_naive_selfish" - if hp["explore_during_evaluation"]: + if hp["explore_during_evaluation"] and not hp["use_policy_gratient"]: tmp_mul = 1.0 config_eval["explore"] = (miscellaneous.OVERWRITE_KEY, True) config_eval["exploration_config"] = { @@ -788,11 +1093,11 @@ def modify_config_for_evaluation(config_eval, hp, env_config): policies[policy_id][3]["debit_threshold"] = 0.5 policies[policy_id][3]["last_k"] = hp["n_steps_per_epi"] - 1 - return config_eval + return config_eval, stop_config def print_inequity_aversion_welfare(env_config, analysis_metrics_per_mode): - plotter = self_and_cross_perf.SelfAndCrossPlayPlotter() + plotter = cross_play.evaluator.SelfAndCrossPlayPlotter() plotter._reset( x_axis_metric=f"nested_policy/{env_config['players_ids'][0]}/worker_0/" f"policy_0/sum_over_epi_inequity_aversion_welfare", @@ -808,5 +1113,21 @@ def print_inequity_aversion_welfare(env_config, analysis_metrics_per_mode): if __name__ == "__main__": - debug_mode = True - main(debug_mode) + use_r2d2 = True + use_policy_gratient = False + # + # use_r2d2 = False + # use_policy_gratient = True + # + debug_mode = False + # debug_mode = True + + # hyperparameter_search = True + hyperparameter_search = False + + main( + debug_mode, + use_r2d2=use_r2d2, + use_policy_gratient=use_policy_gratient, + hyperparameter_search=hyperparameter_search, + ) diff --git a/marltoolbox/experiments/rllib_api/amtft_vs_lvl1_exploiter.py b/marltoolbox/experiments/rllib_api/amtft_vs_lvl1_exploiter.py index adadd63..4dc5141 100644 --- a/marltoolbox/experiments/rllib_api/amtft_vs_lvl1_exploiter.py +++ 
b/marltoolbox/experiments/rllib_api/amtft_vs_lvl1_exploiter.py @@ -42,29 +42,31 @@ def main(debug, train_n_replicates=None, filter_utilitarian=None, env=None): # Train if hparams["load_policy_data"] is None: - tune_analysis_per_welfare = train_for_each_welfare_function( + experiment_analysis_per_welfare = train_for_each_welfare_function( hparams ) else: - tune_analysis_per_welfare = amtft_various_env.load_tune_analysis( - hparams["load_policy_data"] + experiment_analysis_per_welfare = ( + amtft_various_env.load_experiment_analysis( + hparams["load_policy_data"] + ) ) _evaluate_perf_wt_and_without_exploiter( - hparams, tune_analysis_per_welfare + hparams, experiment_analysis_per_welfare ) ray.shutdown() else: - tune_analysis_per_welfare = None + experiment_analysis_per_welfare = None _evaluate_perf_wt_and_without_exploiter( - hparams, tune_analysis_per_welfare + hparams, experiment_analysis_per_welfare ) - return tune_analysis_per_welfare + return experiment_analysis_per_welfare def train_for_each_welfare_function(hp): - tune_analysis_per_welfare = {} + experiment_analysis_per_welfare = {} for welfare_fn, welfare_group_name in hp["welfare_functions"]: print("==============================================") print( @@ -96,8 +98,8 @@ def train_for_each_welfare_function(hp): results, hp = amtft_various_env.postprocess_utilitarian_results( results, env_config, hp ) - tune_analysis_per_welfare[welfare_group_name] = results - return tune_analysis_per_welfare + experiment_analysis_per_welfare[welfare_group_name] = results + return experiment_analysis_per_welfare def modify_hp_for_lvl_exploiter(hp): @@ -109,7 +111,6 @@ def modify_hp_for_lvl_exploiter(hp): hp["punishment_multiplier_range"] = range(2, 4, 1) else: hp["lookahead_n_times"] = 5 - # hp["punishment_multiplier_range"] = [1, 2, 4, 8] hp["punishment_multiplier_range"] = [4] if "CoinGame" in hp["env_name"]: hp["debit_threshold_range"] = [2, 4, 8, 16, 32, 64] @@ -151,10 +152,19 @@ def modify_hp_for_lvl_exploiter(hp): if "CoinGame" in hp["env_name"]: hp["player_1_metric"] = "policy_reward_mean/player_red" hp["player_2_metric"] = "policy_reward_mean/player_blue" + if "AsymCoinGame" in hp["env_name"]: + hp["ylim"] = (-1.0, 3.0) + else: + hp["ylim"] = (0, 1.0) else: hp["player_1_metric"] = "policy_reward_mean/player_row" hp["player_2_metric"] = "policy_reward_mean/player_col" - + if "IteratedPrisonersDilemma" in hp["env_name"]: + hp["ylim"] = (-6, 0) + elif "IteratedAsymBoS" in hp["env_name"]: + hp["ylim"] = (0, 4) + else: + raise ValueError() # if hp["debug"]: # hp["load_policy_data"] = { # "Util": [ @@ -206,7 +216,7 @@ def _update_exploiter_policy_config(hp, exploiter_config, welfare_fn): ( NestedPolicyClass, CoopNestedPolicyClass, - ) = amtft_various_env.get_nested_policy_class(hp, welfare_fn) + ) = amtft_various_env.get_nested_policy_class(hp) exploiter_config[0] = amTFT.Level1amTFTExploiterTorchPolicy exploiter_config[3] = merge_dicts( @@ -251,15 +261,17 @@ def _set_exploiter_policy_config(hp, rllib_config, exploiter_config): return rllib_config -def _evaluate_perf_wt_and_without_exploiter(hp, tune_analysis_per_welfare): +def _evaluate_perf_wt_and_without_exploiter( + hp, experiment_analysis_per_welfare +): for exploiter_activated in (True, False): _evaluate_perf_inself_or_cross_play( - hp, tune_analysis_per_welfare, exploiter_activated + hp, experiment_analysis_per_welfare, exploiter_activated ) def _evaluate_perf_inself_or_cross_play( - hp, tune_analysis_per_welfare, exploiter_activated + hp, experiment_analysis_per_welfare, 
exploiter_activated ): for in_cross_play in (False, True): hp_eval = copy.deepcopy(hp) @@ -276,14 +288,14 @@ def _evaluate_perf_inself_or_cross_play( # ) _evaluate_perf_over_all_possible_hp( hp_eval, - tune_analysis_per_welfare, + experiment_analysis_per_welfare, exploiter_activated=exploiter_activated, in_cross_play=in_cross_play, ) def _evaluate_perf_over_all_possible_hp( - hp, tune_analysis_per_welfare, exploiter_activated, in_cross_play + hp, experiment_analysis_per_welfare, exploiter_activated, in_cross_play ): for punishment_multiplier in hp["punishment_multiplier_range"]: hp_eval = copy.deepcopy(hp) @@ -297,7 +309,7 @@ def _evaluate_perf_over_all_possible_hp( rows_data_p1_ia_vs_ia, rows_data_p2_ia_vs_ia, ) = _gather_data_for_one_group_of_hp_sets( - hp_eval, tune_analysis_per_welfare, punishment_multiplier + hp_eval, experiment_analysis_per_welfare, punishment_multiplier ) data_groups["p1_util_vs_util"] = pd.DataFrame( @@ -322,7 +334,7 @@ def _evaluate_perf_over_all_possible_hp( def _gather_data_for_one_group_of_hp_sets( - hp_eval, tune_analysis_per_welfare, punishment_multiplier + hp_eval, experiment_analysis_per_welfare, punishment_multiplier ): rows_data_p1_util_vs_util = [] rows_data_p2_util_vs_util = [] @@ -337,7 +349,7 @@ def _gather_data_for_one_group_of_hp_sets( analysis_metrics_per_mode, hp_eval_updated, ) = _evaluate_perf_for_one_hp_set( - tune_analysis_per_welfare, hp_eval_updated + experiment_analysis_per_welfare, hp_eval_updated ) ( @@ -385,7 +397,7 @@ def _gather_data_for_one_group_of_hp_sets( ) -def _evaluate_perf_for_one_hp_set(tune_analysis_per_welfare, hp_eval): +def _evaluate_perf_for_one_hp_set(experiment_analysis_per_welfare, hp_eval): config_eval, env_config, stop, hp_eval_updated = _generate_eval_config( hp_eval ) @@ -393,7 +405,7 @@ def _evaluate_perf_for_one_hp_set(tune_analysis_per_welfare, hp_eval): # Eval & Plot analysis_metrics_per_mode = ( amtft_various_env.evaluate_self_play_cross_play( - tune_analysis_per_welfare, + experiment_analysis_per_welfare, config_eval, env_config, stop, @@ -415,8 +427,8 @@ def _generate_eval_config(hp): hp, rllib_config, fake_welfare_function ) hp_eval["debit_threshold_debug_override"] = False - config_eval = amtft_various_env.modify_config_for_evaluation( - rllib_config, hp_eval, env_config + config_eval, stop = amtft_various_env.modify_config_for_evaluation( + rllib_config, hp_eval, env_config, stop ) return config_eval, env_config, stop, hp_eval @@ -472,6 +484,9 @@ def _plot_final_results(hp: dict, data_groups: dict, suffix: str = ""): save_dir_path=save_dir_path, filename_prefix=f"plot_{suffix}", x_use_log_scale=True, + ylim=hp["ylim"], + # x_scale_multiplier=hp["plot_axis_scale_multipliers"][0], + y_scale_multiplier=hp["plot_axis_scale_multipliers"][0], ) plotter = plot.PlotHelper(plot_config) plotter.plot_lines(data_groups=data_groups) diff --git a/marltoolbox/experiments/rllib_api/dqn_coin_game_speed_search.py b/marltoolbox/experiments/rllib_api/dqn_coin_game_speed_search.py index 0e6d802..5857d61 100644 --- a/marltoolbox/experiments/rllib_api/dqn_coin_game_speed_search.py +++ b/marltoolbox/experiments/rllib_api/dqn_coin_game_speed_search.py @@ -2,15 +2,11 @@ from ray.rllib.utils.schedules import PiecewiseSchedule from marltoolbox.envs import coin_game, ssd_mixed_motive_coin_game -from marltoolbox.examples.rllib_api.dqn_coin_game import ( - _get_hyperparameters, - _get_rllib_configs, - _train_dqn_and_plot_logs, -) +from marltoolbox.examples.rllib_api import dqn_coin_game from 
marltoolbox.examples.rllib_api.dqn_wt_welfare import ( - _modify_policy_to_use_welfare, + modify_dqn_rllib_config_to_use_welfare, ) -from marltoolbox.utils import log, miscellaneous, postprocessing, exploration +from marltoolbox.utils import log, miscellaneous, postprocessing def main(debug): @@ -19,31 +15,31 @@ def main(debug): exp_name, _ = log.log_in_current_day_dir("DQN_CG_speed_search") env = "CoinGame" - # env = "SSDMixedMotiveCoinGame" - # welfare_to_use = None - # welfare_to_use = postprocessing.WELFARE_UTILITARIAN - welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION + env = "SSDMixedMotiveCoinGame" + welfare_to_use = None + welfare_to_use = postprocessing.WELFARE_UTILITARIAN + # welfare_to_use = postprocessing.WELFARE_INEQUITY_AVERSION if "SSDMixedMotiveCoinGame" in env: env_class = ssd_mixed_motive_coin_game.SSDMixedMotiveCoinGame else: env_class = coin_game.CoinGame - hparams = _get_hyperparameters(seeds, debug, exp_name) + hparams = dqn_coin_game.get_hyperparameters(seeds, debug, exp_name) - rllib_config, stop_config = _get_rllib_configs( + rllib_config, stop_config = dqn_coin_game.get_rllib_configs( hparams, env_class=env_class ) if welfare_to_use is not None: - rllib_config = _modify_policy_to_use_welfare( + rllib_config = modify_dqn_rllib_config_to_use_welfare( rllib_config, welfare_to_use ) rllib_config, stop_config = _add_search_to_config( rllib_config, stop_config, hparams ) - tune_analysis = _train_dqn_and_plot_logs( + tune_analysis = dqn_coin_game.train_dqn_and_plot_logs( hparams, rllib_config, stop_config ) @@ -51,67 +47,98 @@ def main(debug): def _add_search_to_config(rllib_config, stop_config, hp): - rllib_config["num_envs_per_worker"] = tune.grid_search([1, 4, 8, 16, 32]) - rllib_config["lr"] = tune.grid_search([0.1, 0.1 * 2, 0.1 * 4]) - rllib_config["model"] = { - "dim": 3, - "conv_filters": [[16, [3, 3], 1], [16, [3, 3], 1]], - "fcnet_hiddens": [256, 256], - } - rllib_config["hiddens"] = [32] - rllib_config["env_config"] = { - "players_ids": ["player_red", "player_blue"], - "max_steps": 100, - "grid_size": 3, - "get_additional_info": True, - "temp_mid_step": 0.6, - "bs_epi_mul": tune.grid_search([2, 4, 8]), - } - rllib_config["training_intensity"] = 10 - - stop_config["episodes_total"] = tune.grid_search([1000, 2000]) - - rllib_config["exploration_config"] = { - # The Exploration class to use. In the simplest case, - # this is the name (str) of any class present in the - # `rllib.utils.exploration` package. - # You can also provide the python class directly or - # the full location of your class (e.g. - # "ray.rllib.utils.exploration.epsilon_greedy.EpsilonGreedy"). - # "type": exploration.SoftQSchedule, - "type": exploration.SoftQSchedule, - # Add constructor kwargs here (if any). 
- "temperature_schedule": tune.sample_from( - lambda spec: PiecewiseSchedule( - endpoints=[ - (0, 2.0), - ( - int( - spec.config["env_config"]["max_steps"] - * spec.stop["episodes_total"] - * 0.20 - ), - 0.5, - ), - ( - int( - spec.config["env_config"]["max_steps"] - * spec.stop["episodes_total"] - * spec.config["env_config"]["temp_mid_step"] - ), - hp["last_exploration_temp_value"], - ), - ], - outside_value=hp["last_exploration_temp_value"], - framework="torch", - ) - ), - } - rllib_config["train_batch_size"] = tune.sample_from( - lambda spec: int( - spec.config["env_config"]["max_steps"] - * spec.config["env_config"]["bs_epi_mul"] - ) + assert hp["last_exploration_temp_value"] == 0.01 + + stop_config["episodes_total"] = 10 if hp["debug"] else 8000 + # rllib_config["lr"] = 0.1 + # rllib_config["env_config"]["training_intensity"] = ( + # 20 if hp["debug"] else 40 + # ) + # rllib_config["training_intensity"] = tune.sample_from( + # lambda spec: spec.config["num_envs_per_worker"] + # * max(spec.config["num_workers"], 1) + # * spec.config["env_config"]["training_intensity"] + # ) + # rllib_config["env_config"]["temp_ratio"] = ( + # 1.0 if hp["debug"] else tune.grid_search([1.0, 0.5, 2.0]) + # ) + # rllib_config["env_config"]["interm_temp_ratio"] = ( + # 1.0 if hp["debug"] else tune.grid_search([1.0, 5.0, 2.0, 3.0, 10.0]) + # ) + # rllib_config["env_config"]["last_exploration_t"] = ( + # 0.6 if hp["debug"] else 0.9 + # ) + # rllib_config["env_config"]["last_exploration_temp_value"] = ( + # 1.0 if hp["debug"] else 0.003 + # ) + # rllib_config["exploration_config"][ + # "temperature_schedule" + # ] = tune.sample_from( + # lambda spec: PiecewiseSchedule( + # endpoints=[ + # (0, 0.5 * spec.config["env_config"]["temp_ratio"]), + # ( + # int( + # spec.config["env_config"]["max_steps"] + # * spec.stop["episodes_total"] + # * 0.20 + # ), + # 0.1 + # * spec.config["env_config"]["temp_ratio"] + # * spec.config["env_config"]["interm_temp_ratio"], + # ), + # ( + # int( + # spec.config["env_config"]["max_steps"] + # * spec.stop["episodes_total"] + # * spec.config["env_config"]["last_exploration_t"] + # ), + # spec.config["env_config"]["last_exploration_temp_value"], + # ), + # ], + # outside_value=spec.config["env_config"][ + # "last_exploration_temp_value" + # ], + # framework="torch", + # ) + # ) + + rllib_config["env_config"]["bs_epi_mul"] = ( + 4 if hp["debug"] else tune.grid_search([4, 8, 16]) + ) + rllib_config["env_config"]["interm_lr_ratio"] = ( + 0.5 if hp["debug"] else tune.grid_search([0.5 * 3, 0.5, 0.5 / 3]) + ) + rllib_config["lr"] = ( + 0.1 if hp["debug"] else tune.grid_search([0.1, 0.2, 0.4]) + ) + rllib_config["lr_schedule"] = tune.sample_from( + lambda spec: [ + (0, 0.0), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.05 + ), + spec.config.lr, + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * 0.5 + ), + spec.config.lr * spec.config["env_config"]["interm_lr_ratio"], + ), + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + ), + spec.config.lr / 1e9, + ), + ] ) return rllib_config, stop_config diff --git a/marltoolbox/experiments/rllib_api/l1br_amtft.py b/marltoolbox/experiments/rllib_api/l1br_amtft.py index a727496..2fffabb 100644 --- a/marltoolbox/experiments/rllib_api/l1br_amtft.py +++ b/marltoolbox/experiments/rllib_api/l1br_amtft.py @@ -4,12 +4,11 @@ import ray from ray import tune from ray.rllib.agents import dqn -from ray.rllib.agents.dqn.dqn_torch_policy import 
DQNTorchPolicy -from ray.rllib.utils.framework import try_import_torch - -torch, nn = try_import_torch() +from marltoolbox import utils from marltoolbox.algos import amTFT +from marltoolbox.algos.augmented_dqn import MyDQNTorchPolicy +from marltoolbox.experiments.rllib_api import amtft_various_env from marltoolbox.utils import ( log, postprocessing, @@ -17,7 +16,6 @@ miscellaneous, callbacks, ) -from marltoolbox.experiments.rllib_api import amtft_various_env def main(debug, env=None): @@ -40,31 +38,29 @@ def get_hyperparameters(debug, env): train_n_replicates = 4 if debug else 8 pool_of_seeds = miscellaneous.get_random_seeds(train_n_replicates) - hparams = { - "debug": debug, - "filter_utilitarian": False, - "train_n_replicates": train_n_replicates, - "seeds": pool_of_seeds, - "exp_name": exp_name, - "welfare_functions": [ - (postprocessing.WELFARE_UTILITARIAN, "utilitarian") - ], - "amTFTPolicy": amTFT.AmTFTRolloutsTorchPolicy, - "explore_during_evaluation": True, - "n_seeds_lvl0": train_n_replicates, - "n_seeds_lvl1": train_n_replicates // 2, - "gamma": 0.96, - "temperature_schedule": False, - "jitter": 0.05, - "hiddens": [64], - "env_name": "IteratedPrisonersDilemma", - # "env_name": "IteratedAsymBoS", - # "env_name": "IteratedAsymChicken", - # "env_name": "CoinGame", - # "env_name": "AsymCoinGame", - "overwrite_reward": True, - "reward_uncertainty": 0.0, - } + + hparams = amtft_various_env.get_hyperparameters(debug) + hparams.update( + { + "debug": debug, + "use_r2d2": False, + "filter_utilitarian": False, + "train_n_replicates": train_n_replicates, + "seeds": pool_of_seeds, + "exp_name": exp_name, + "num_envs_per_worker": 16, + "welfare_functions": [ + (postprocessing.WELFARE_UTILITARIAN, "utilitarian") + ], + "n_seeds_lvl0": train_n_replicates, + "n_seeds_lvl1": train_n_replicates // 2, + "env_name": "IteratedPrisonersDilemma", + # "env_name": "IteratedAsymBoS", + # "env_name": "IteratedAsymChicken", + # "env_name": "CoinGame", + # "env_name": "AsymCoinGame", + } + ) if env is not None: hparams["env_name"] = env @@ -96,7 +92,11 @@ def train_lvl1_agents(hp_lvl1, tune_analysis_lvl0): stop, env_config, rllib_config = amtft_various_env.get_rllib_config( hp_lvl1, hp_lvl1["welfare_functions"][0][0] ) - checkpoints_lvl0 = miscellaneous.extract_checkpoints(tune_analysis_lvl0) + checkpoints_lvl0 = ( + utils.restore.extract_checkpoints_from_experiment_analysis( + tune_analysis_lvl0 + ) + ) rllib_config = modify_conf_for_lvl1_training( hp_lvl1, env_config, rllib_config, checkpoints_lvl0 ) @@ -124,7 +124,7 @@ def modify_conf_for_lvl1_training( # Use a simple DQN as lvl1 agent (instead of amTFT with nested DQN) rllib_config_lvl1["multiagent"]["policies"][lvl1_policy_id] = ( - DQNTorchPolicy, + MyDQNTorchPolicy, hp_lvl1["env_class"](env_config).OBSERVATION_SPACE, hp_lvl1["env_class"].ACTION_SPACE, {}, @@ -132,9 +132,7 @@ def modify_conf_for_lvl1_training( rllib_config_lvl1["callbacks"] = callbacks.merge_callbacks( amTFT.AmTFTCallbacks, - log.get_logging_callbacks_class( - log_full_epi=False, log_full_epi_interval=100 - ), + log.get_logging_callbacks_class(), ) l1br_configuration_helper = lvl1_best_response.L1BRConfigurationHelper( diff --git a/marltoolbox/experiments/rllib_api/ltft_various_env.py b/marltoolbox/experiments/rllib_api/ltft_various_env.py index f637530..f16e805 100644 --- a/marltoolbox/experiments/rllib_api/ltft_various_env.py +++ b/marltoolbox/experiments/rllib_api/ltft_various_env.py @@ -5,8 +5,7 @@ from ray import tune from ray.rllib.agents.dqn.dqn_torch_policy import 
postprocess_nstep_and_prio from ray.rllib.utils.schedules import PiecewiseSchedule -from ray.tune.integration.wandb import WandbLogger -from ray.tune.logger import DEFAULT_LOGGERS +from ray.tune.integration.wandb import WandbLoggerCallback from marltoolbox.algos import ltft, augmented_dqn from marltoolbox.algos.exploiters.influence_evader import ( @@ -15,7 +14,6 @@ from marltoolbox.envs import ( matrix_sequential_social_dilemma, vectorized_coin_game, - mixed_motive_coin_game, ) from marltoolbox.envs.utils.wrappers import ( add_RewardUncertaintyEnvClassWrapper, @@ -68,6 +66,13 @@ def _get_hyparameters( "seeds": seeds, "debug": debug, "exp_name": exp_name, + "wandb": { + "project": "LTFT", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, "hiddens": [64], "log_n_points": 260, "clustering_distance": 0.2, @@ -132,11 +137,6 @@ def _modify_hyperparams_for_the_selected_env(hp): hp["x_limits"] = (-0.5, 3.0) hp["y_limits"] = (-1.1, 0.6) hp["env_class"] = vectorized_coin_game.AsymVectorizedCoinGame - elif "MixedMotiveCoinGame" in hp["env_name"]: - hp["x_limits"] = (-0.5, 1.0) - hp["y_limits"] = (-0.5, 1.0) - hp["env_class"] = mixed_motive_coin_game.MixedMotiveCoinGame - hp["both_players_can_pick_the_same_coin"] = True else: hp["x_limits"] = (-0.5, 0.6) hp["y_limits"] = (-0.5, 0.6) @@ -231,13 +231,6 @@ def _get_rllib_config(hp: dict): ), }, "policy_mapping_fn": lambda agent_id: agent_id, - # When replay_mode=lockstep, RLlib will replay all the agent - # transitions at a particular timestep together in a batch. - # This allows the policy to implement differentiable shared - # computations between agents it controls at that timestep. When - # replay_mode=independent, - # transitions are replayed independently per policy. - # "replay_mode": "lockstep", "observation_fn": ltft.observation_fn, }, # === DQN Models === @@ -322,18 +315,7 @@ def _get_rllib_config(hp: dict): # time "num_envs_per_worker": 1, "batch_mode": "complete_episodes", - "logger_config": { - "wandb": { - "project": "LTFT", - "group": hp["exp_name"], - "api_key_file": os.path.join( - os.path.dirname(__file__), "../../../api_key_wandb" - ), - "log_config": True, - }, - }, # === Debug Settings === - "log_level": "INFO", # Callbacks that will be run during various phases of training. 
See the # `DefaultCallbacks` class and # `examples/custom_metrics_and_callbacks.py` @@ -401,12 +383,20 @@ def _train_in_self_play(rllib_config, stop, exp_name, hp): tune_analysis_self_play = ray.tune.run( ltft.LTFTTrainer, config=rllib_config, - verbose=1, stop=stop, checkpoint_at_end=True, name=full_exp_name, log_to_file=not hp["debug"], - loggers=None if hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,), + callbacks=None + if hp["debug"] + else [ + WandbLoggerCallback( + project=hp["wandb"]["project"], + group=hp["wandb"]["group"], + api_key_file=hp["wandb"]["api_key_file"], + log_config=True, + ), + ], ) if not hp["debug"]: @@ -430,12 +420,20 @@ def _train_against_opponent(hp, rllib_config, stop, exp_name, env_config): tune_analysis_naive_opponent = ray.tune.run( ltft.LTFTTrainer, config=rllib_config, - verbose=1, stop=stop, checkpoint_at_end=True, name=full_exp_name, log_to_file=not hp["debug"], - loggers=None if hp["debug"] else DEFAULT_LOGGERS + (WandbLogger,), + callbacks=None + if hp["debug"] + else [ + WandbLoggerCallback( + project=hp["wandb"]["project"], + group=hp["wandb"]["group"], + api_key_file=hp["wandb"]["api_key_file"], + log_config=True, + ) + ], ) if not hp["debug"]: diff --git a/marltoolbox/experiments/tune_class_api/alternating_offers/train.py b/marltoolbox/experiments/tune_class_api/alternating_offers/train.py index 39fdcdb..241d1f3 100644 --- a/marltoolbox/experiments/tune_class_api/alternating_offers/train.py +++ b/marltoolbox/experiments/tune_class_api/alternating_offers/train.py @@ -8,7 +8,10 @@ from ray import tune import os -from marltoolbox.algos.alternating_offers.alt_offers_training import AltOffersTraining +import utils.restore +from marltoolbox.algos.alternating_offers.alt_offers_training import ( + AltOffersTraining, +) from marltoolbox.utils import miscellaneous, log diff --git a/marltoolbox/experiments/tune_class_api/bootstrapped_replicates_prosociality_coeff_0.3 b/marltoolbox/experiments/tune_class_api/bootstrapped_replicates_prosociality_coeff_0.3 new file mode 100644 index 0000000..ce93bfa Binary files /dev/null and b/marltoolbox/experiments/tune_class_api/bootstrapped_replicates_prosociality_coeff_0.3 differ diff --git a/marltoolbox/experiments/tune_class_api/empirical_game_matrices_prosociality_coeff_0.3 b/marltoolbox/experiments/tune_class_api/empirical_game_matrices_prosociality_coeff_0.3 new file mode 100644 index 0000000..9623710 Binary files /dev/null and b/marltoolbox/experiments/tune_class_api/empirical_game_matrices_prosociality_coeff_0.3 differ diff --git a/marltoolbox/experiments/tune_class_api/l1br_lola_pg.py b/marltoolbox/experiments/tune_class_api/l1br_lola_pg.py index 76baf48..a70205a 100644 --- a/marltoolbox/experiments/tune_class_api/l1br_lola_pg.py +++ b/marltoolbox/experiments/tune_class_api/l1br_lola_pg.py @@ -16,10 +16,12 @@ from ray.rllib.agents.dqn.dqn_torch_policy import ( DQNTorchPolicy, build_q_stats, - after_init, + before_loss_init, ) from ray.rllib.utils.schedules import PiecewiseSchedule +from ray.tune.integration.wandb import WandbLoggerCallback +import marltoolbox.utils.restore from marltoolbox.algos.lola.train_cg_tune_class_API import LOLAPGCG from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice from marltoolbox.envs.matrix_sequential_social_dilemma import ( @@ -41,6 +43,7 @@ restore, callbacks, ) +from marltoolbox.experiments.tune_class_api import lola_exact_official # TODO make it work for all env (not only ACG and CG)? 
or only for them @@ -55,56 +58,67 @@ def main(debug): ] lvl1_seeds = list(range(n_lvl1)) - exp_name, _ = log.log_in_current_day_dir("L1BR_LOLA_PG") - - tune_hparams = { - "exp_name": exp_name, - "load_data": None, - # Example: "load_data": ".../lvl1_results.p", - "load_population": None, - # Example: "load_population": - # [".../checkpoint.json", ".../checkpoint.json", ...] - "num_episodes": 5 if debug else 2000, - "trace_length": 5 if debug else 20, - "lr": None, - "gamma": 0.5, - "batch_size": 5 if debug else 512, - # "env_name": "IteratedPrisonersDilemma", - # "env_name": "IteratedBoS", - # "env_name": "IteratedAsymBoS", - "env_name": "VectorizedCoinGame", - # "env_name": "AsymVectorizedCoinGame", - "pseudo": False, - "grid_size": 3, - "lola_update": True, - "opp_model": False, - "mem_efficient": True, - "lr_correction": 1, - "bs_mul": 1 / 10, - "simple_net": True, - "hidden": 32, - "reg": 0, - "set_zero": 0, - "exact": False, - "warmup": 1, - "lvl0_seeds": lvl0_seeds, - "lvl1_seeds": lvl1_seeds, - "changed_config": False, - "ac_lr": 1.0, - "summary_len": 1, - "use_MAE": False, - "use_toolbox_env": True, - "clip_loss_norm": False, - "clip_lola_update_norm": False, - "clip_lola_correction_norm": 3.0, - "clip_lola_actor_norm": 10.0, - "entropy_coeff": 0.001, - "weigth_decay": 0.03, - "lola_correction_multiplier": 1, - "lr_decay": True, - "correction_reward_baseline_per_step": False, - "use_critic": False, - } + # exp_name, _ = log.log_in_current_day_dir("L1BR_LOLA_PG") + + tune_hparams = lola_exact_official.get_hyperparameters( + debug, n_in_lvl0_population + ) + # tune_hparams = { + # "debug": debug, + # "exp_name": exp_name, + # "wandb": { + # "project": "L1BR_LOLA_PG", + # "group": exp_name, + # "api_key_file": os.path.join( + # os.path.dirname(__file__), "../../../api_key_wandb" + # ), + # }, + # "load_data": None, + # # Example: "load_data": ".../lvl1_results.p", + # "load_population": None, + # # Example: "load_population": + # # [".../checkpoint.json", ".../checkpoint.json", ...] 
+ # "num_episodes": 5 if debug else 2000, + # "trace_length": 5 if debug else 20, + # "lr": None, + # "gamma": 0.5, + # "batch_size": 5 if debug else 512, + # # "env_name": "IteratedPrisonersDilemma", + # # "env_name": "IteratedBoS", + # "env_name": "IteratedAsymBoS", + # # "env_name": "VectorizedCoinGame", + # # "env_name": "AsymVectorizedCoinGame", + # "pseudo": False, + # "grid_size": 3, + # "lola_update": True, + # "opp_model": False, + # "mem_efficient": True, + # "lr_correction": 1, + # "bs_mul": 1 / 10, + # "simple_net": True, + # "hidden": 32, + # "reg": 0, + # "set_zero": 0, + # "exact": False, + # "warmup": 1, + # "lvl0_seeds": lvl0_seeds, + # "lvl1_seeds": lvl1_seeds, + # "changed_config": False, + # "ac_lr": 1.0, + # "summary_len": 1, + # "use_MAE": False, + # "use_toolbox_env": True, + # "clip_loss_norm": False, + # "clip_lola_update_norm": False, + # "clip_lola_correction_norm": 3.0, + # "clip_lola_actor_norm": 10.0, + # "entropy_coeff": 0.001, + # "weigth_decay": 0.03, + # "lola_correction_multiplier": 1, + # "lr_decay": True, + # "correction_reward_baseline_per_step": False, + # "use_critic": False, + # } rllib_hparams = { "debug": debug, @@ -241,6 +255,16 @@ def train_lvl0_population(tune_hp): stop=stop, metric=tune_config["metric"], mode="max", + callbacks=None + if tune_hp["debug"] + else [ + WandbLoggerCallback( + project=tune_hp["wandb"]["project"], + group=tune_hp["wandb"]["group"], + api_key_file=tune_hp["wandb"]["api_key_file"], + log_config=True, + ) + ], ) @@ -267,11 +291,15 @@ def train_lvl1_agents(tune_hp, rllib_hp, results_list_lvl0): endpoints=[ (0, 10.0), ( - int(tune_hp["n_steps_per_epi"] * tune_hp["n_epi"] * 0.33), + int( + rllib_hp["n_steps_per_epi"] * rllib_hp["n_epi"] * 0.33 + ), 2.0, ), ( - int(tune_hp["n_steps_per_epi"] * tune_hp["n_epi"] * 0.66), + int( + rllib_hp["n_steps_per_epi"] * rllib_hp["n_epi"] * 0.66 + ), 0.1, ), ], @@ -306,7 +334,11 @@ def train_lvl1_agents(tune_hp, rllib_hp, results_list_lvl0): ) if tune_hp["load_population"] is None: - lvl0_checkpoints = miscellaneous.extract_checkpoints(results_list_lvl0) + lvl0_checkpoints = ( + utils.restore.extract_checkpoints_from_experiment_analysis( + results_list_lvl0 + ) + ) else: lvl0_checkpoints = tune_hp["load_population"] lvl0_policy_id = env_config["players_ids"][lvl0_policy_idx] @@ -331,6 +363,16 @@ def train_lvl1_agents(tune_hp, rllib_hp, results_list_lvl0): checkpoint_at_end=True, metric="episode_reward_mean", mode="max", + # callbacks=None + # if tune_hp["debug"] + # else [ + # WandbLoggerCallback( + # project=tune_hp["wandb"]["project"], + # group=tune_hp["wandb"]["group"], + # api_key_file=tune_hp["wandb"]["api_key_file"], + # log_config=True, + # ) + # ], ) return results @@ -340,13 +382,16 @@ def get_rllib_config(hp: dict, lvl1_idx: list, lvl1_training: bool): assert lvl1_training tune_config, _, env_config = get_tune_config(hp=hp) - tune_config["seed"] = 2020 + # tune_config["seed"] = 2020 stop = {"episodes_total": hp["n_epi"]} - after_init_fn = functools.partial( + before_loss_init_fn = functools.partial( miscellaneous.sequence_of_fn_wt_same_args, - function_list=[restore.after_init_load_policy_checkpoint, after_init], + function_list=[ + before_loss_init, + restore.before_loss_init_load_policy_checkpoint, + ], ) def sgd_optimizer_dqn(policy, config) -> "torch.optim.Optimizer": @@ -359,7 +404,7 @@ def sgd_optimizer_dqn(policy, config) -> "torch.optim.Optimizer": MyDQNTorchPolicy = DQNTorchPolicy.with_updates( stats_fn=log.augment_stats_fn_wt_additionnal_logs(build_q_stats), 
optimizer_fn=sgd_optimizer_dqn, - after_init=after_init_fn, + before_loss_init=before_loss_init_fn, ) if tune_config["env_class"] in ( diff --git a/marltoolbox/experiments/tune_class_api/lola_dice_official.py b/marltoolbox/experiments/tune_class_api/lola_dice_official.py index ee23242..69e079f 100644 --- a/marltoolbox/experiments/tune_class_api/lola_dice_official.py +++ b/marltoolbox/experiments/tune_class_api/lola_dice_official.py @@ -10,6 +10,7 @@ import ray from ray import tune from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy +from ray.tune.integration.wandb import WandbLoggerCallback from marltoolbox.algos.lola_dice.train_tune_class_API import LOLADICE from marltoolbox.envs.coin_game import CoinGame, AsymCoinGame @@ -36,6 +37,13 @@ def main(debug): # Example: "load_plot_data": ".../SameAndCrossPlay_save.p", "exp_name": exp_name, "train_n_replicates": train_n_replicates, + "wandb": { + "project": "LOLA_DICE", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, "env_name": "IPD", # "env_name": "IMP", # "env_name": "AsymBoS", @@ -77,7 +85,7 @@ def main(debug): def train(hp): tune_config, stop, _ = get_tune_config(hp) # Train with the Tune Class API (not RLLib Class) - tune_analysis = tune.run( + experiment_analysis = tune.run( LOLADICE, name=hp["exp_name"], config=tune_config, @@ -85,9 +93,19 @@ def train(hp): stop=stop, metric=hp["metric"], mode="max", + callbacks=None + if hp["debug"] + else [ + WandbLoggerCallback( + project=hp["wandb"]["project"], + group=hp["wandb"]["group"], + api_key_file=hp["wandb"]["api_key_file"], + log_config=True, + ) + ], ) - tune_analysis_per_exp = {"": tune_analysis} - return tune_analysis_per_exp + experiment_analysis_per_welfare = {"": experiment_analysis} + return experiment_analysis_per_welfare def get_tune_config(hp: dict) -> dict: @@ -163,7 +181,7 @@ def get_tune_config(hp: dict) -> dict: return config, stop, env_config -def evaluate(tune_analysis_per_exp, hp, debug): +def evaluate(experiment_analysis_per_welfare, hp, debug): ( rllib_hp, rllib_config_eval, @@ -180,7 +198,7 @@ def evaluate(tune_analysis_per_exp, hp, debug): trainable_class, stop, env_config, - tune_analysis_per_exp, + experiment_analysis_per_welfare, ) diff --git a/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py b/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py new file mode 100644 index 0000000..9081e33 --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/lola_exact_meta_game.py @@ -0,0 +1,483 @@ +########## +# Additional dependencies are needed: +# Follow the LOLA installation described in the +# tune_class_api/lola_pg_official.py file +########## + +import copy +import logging +import os + +import numpy as np +import ray +from ray import tune +from ray.rllib.agents.pg import PGTrainer + +from marltoolbox import utils +from marltoolbox.algos import welfare_coordination +from marltoolbox.experiments.rllib_api import amtft_meta_game +from marltoolbox.experiments.tune_class_api import lola_exact_official +from marltoolbox.utils import ( + cross_play, + restore, + path, + callbacks, + log, + miscellaneous, +) + +logger = logging.getLogger(__name__) + + +EGALITARIAN = "egalitarian" +MIXED = "mixed" +UTILITARIAN = "utilitarian" +FAILURE = "failure" + + +def main(debug): + # amtft_meta_game._extract_stats_on_welfare_announced( + # players_ids=["player_row", "player_col"], + # exp_dir="/home/maxime/dev-maxime/CLR/vm-data/instance-10-cpu-2" + # 
"/LOLA_Exact/2021_04_21/17_53_53", + # nested_info=True, + # ) + + hp = get_hyperparameters(debug) + + results = [] + ray.init(num_cpus=os.cpu_count(), local_mode=hp["debug"]) + for tau in hp["tau_range"]: + hp["tau"] = tau + ( + all_rllib_config, + hp_eval, + env_config, + stop_config, + ) = _produce_rllib_config_for_each_replicates(hp) + + mixed_rllib_configs = ( + cross_play.utils.mix_policies_in_given_rllib_configs( + all_rllib_config, hp_eval["n_cross_play_in_final_meta_game"] + ) + ) + + experiment_analysis = ray.tune.run( + PGTrainer, + config=mixed_rllib_configs, + verbose=1, + stop=stop_config, + name=hp_eval["exp_name"], + log_to_file=not hp_eval["debug"], + ) + + ( + mean_player_1_payoffs, + mean_player_2_payoffs, + player_1_payoffs, + player_2_payoffs, + ) = amtft_meta_game.extract_metrics(experiment_analysis, hp_eval) + + results.append( + ( + tau, + (mean_player_1_payoffs, mean_player_2_payoffs), + (player_1_payoffs, player_2_payoffs), + ) + ) + amtft_meta_game.save_to_json(exp_name=hp["exp_name"], object=results) + amtft_meta_game.plot_results( + exp_name=hp["exp_name"], + results=results, + hp_eval=hp_eval, + format_fn=amtft_meta_game.format_result_for_plotting, + ) + amtft_meta_game.extract_stats_on_welfare_announced( + players_ids=env_config["players_ids"], + exp_name=hp["exp_name"], + nested_info=True, + ) + + +def get_hyperparameters(debug): + """Get hyperparameters for meta game with LOLA-Exact policies in base + game""" + # env = "IPD" + env = "IteratedAsymBoS" + + hp = lola_exact_official.get_hyperparameters( + debug, train_n_replicates=1, env=env + ) + + hp.update( + { + "n_replicates_over_full_exp": 2 if debug else 20, + "final_base_game_eval_over_n_epi": 1 if debug else 200, + "tau_range": np.arange(0.0, 1.1, 0.5) + if hp["debug"] + else np.arange(0.0, 1.1, 0.1), + "n_self_play_in_final_meta_game": 0, + "n_cross_play_in_final_meta_game": 1 if debug else 10, + "welfare_functions": [ + (EGALITARIAN, EGALITARIAN), + (MIXED, MIXED), + (UTILITARIAN, UTILITARIAN), + ], + } + ) + return hp + + +def _produce_rllib_config_for_each_replicates(hp): + all_rllib_config = [] + for replicate_i in range(hp["n_replicates_over_full_exp"]): + hp_eval = _load_base_game_results( + copy.deepcopy(hp), load_base_replicate_i=replicate_i + ) + + ( + rllib_config, + hp_eval, + env_config, + stop_config, + ) = _get_vanilla_lola_exact_eval_config( + hp_eval, hp_eval["final_base_game_eval_over_n_epi"] + ) + + rllib_config = _modify_config_to_use_welfare_coordinators( + rllib_config, env_config, hp_eval + ) + all_rllib_config.append(rllib_config) + return all_rllib_config, hp_eval, env_config, stop_config + + +def _load_base_game_results(hp, load_base_replicate_i): + + # In local machine + # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-10-cpu-2/" + # prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-2-preemtible/" + prefix = "~/dev-maxime/CLR/vm-data/instance-60-cpu-3-preemtible/" + prefix2 = "~/dev-maxime/CLR/vm-data/instance-60-cpu-4-preemtible/" + + # In VM + # prefix = "~/ray_results/" + # prefix2 = prefix + + prefix = os.path.expanduser(prefix) + prefix2 = os.path.expanduser(prefix2) + if "IteratedAsymBoS" in hp["env_name"]: + hp["data_dir"] = ( + # instance-60-cpu-4-preemtible + prefix2 + "LOLA_Exact/2021_05_07/07_52_32", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/08_02_38", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/08_02_49", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/08_03_03", # 30 replicates + prefix2 
+ "LOLA_Exact/2021_05_07/08_54_58", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/08_55_34", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/09_04_07", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/09_09_30", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/09_09_42", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/10_02_15", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/10_02_30", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/10_02_39", # 30 replicates + prefix2 + "LOLA_Exact/2021_05_07/10_02_50", # 30 replicates + # instance-60-cpu-3-preemtible & instance-60-cpu-4-preemtible + prefix + "LOLA_Exact/2021_05_05/14_49_18", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/14_50_39", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/14_51_01", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/14_53_56", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/14_56_32", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/15_46_08", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/15_46_23", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/15_46_59", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/15_47_22", # 30 replicates + prefix + "LOLA_Exact/2021_05_05/15_48_22", # 30 replicates + )[load_base_replicate_i] + else: + raise ValueError(f'bad env_name: {hp["env_name"]}') + + assert os.path.exists(hp["data_dir"]), ( + "Path doesn't exist. Probably that the prefix need to " + f"be changed to fit the current machine used. path: {hp['data_dir']}" + ) + + print("==== Going to process data_dir", hp["data_dir"], "====") + + hp["ckpt_per_welfare"] = _get_checkpoints_for_each_welfare_in_dir( + hp["data_dir"], hp + ) + + return hp + + +def _get_checkpoints_for_each_welfare_in_dir(data_dir, hp): + all_replicates_save_dir = amtft_meta_game.get_dir_of_each_replicate( + data_dir, str_in_dir="LOLAExactTrainer_" + ) + assert len(all_replicates_save_dir) > 0 + welfares = _classify_base_replicates_into_welfares(all_replicates_save_dir) + + ckpt_per_welfare = {} + for welfare_fn, welfare_name in hp["welfare_functions"]: + replicates_save_dir_for_welfare = _filter_replicate_dir_by_welfare( + all_replicates_save_dir, welfares, welfare_name + ) + ckpts = restore.get_checkpoint_for_each_replicates( + replicates_save_dir_for_welfare + ) + ckpt_per_welfare[welfare_name] = [ckpt + ".json" for ckpt in ckpts] + return ckpt_per_welfare + + +def _classify_base_replicates_into_welfares(all_replicates_save_dir): + welfares = [] + for replicate_dir in all_replicates_save_dir: + reward_player_1, reward_player_2 = _get_last_episode_rewards( + replicate_dir + ) + welfare_name = classify_into_welfare_based_on_rewards( + reward_player_1, reward_player_2 + ) + welfares.append(welfare_name) + return welfares + + +def classify_into_welfare_based_on_rewards(reward_player_1, reward_player_2): + + ratio = reward_player_1 / reward_player_2 + if ratio < 1.5: + return EGALITARIAN + elif ratio < 2.5: + return MIXED + else: + return UTILITARIAN + + +def _filter_replicate_dir_by_welfare( + all_replicates_save_dir, welfares, welfare_name +): + replicates_save_dir_for_welfare = [ + replicate_dir + for welfare, replicate_dir in zip(welfares, all_replicates_save_dir) + if welfare == welfare_name + ] + return replicates_save_dir_for_welfare + + +def _get_last_episode_rewards(replicate_dir): + results = utils.path.get_results_for_replicate(replicate_dir) + last_epsiode_results = results[-1] + return last_epsiode_results["ret1"], last_epsiode_results["ret2"] + + +def _get_vanilla_lola_exact_eval_config(hp, 
final_eval_over_n_epi): + ( + hp_eval, + rllib_config, + policies_to_load, + trainable_class, + stop_config, + env_config, + ) = lola_exact_official.generate_eval_config(hp) + + hp_eval["n_self_play_per_checkpoint"] = None + hp_eval["n_cross_play_per_checkpoint"] = None + hp_eval[ + "x_axis_metric" + ] = f"policy_reward_mean/{env_config['players_ids'][0]}" + hp_eval[ + "y_axis_metric" + ] = f"policy_reward_mean/{env_config['players_ids'][1]}" + hp_eval["plot_axis_scale_multipliers"] = ( + 1 / hp_eval["trace_length"], + 1 / hp_eval["trace_length"], + ) + hp_eval["num_episodes"] = final_eval_over_n_epi + stop_config["episodes_total"] = final_eval_over_n_epi + rllib_config["callbacks"] = callbacks.merge_callbacks( + callbacks.PolicyCallbacks, + log.get_logging_callbacks_class( + log_full_epi=True, + log_from_policy_in_evaluation=True, + ), + ) + rllib_config["seed"] = miscellaneous.get_random_seeds(1)[0] + rllib_config["log_level"] = "INFO" + + return rllib_config, hp_eval, env_config, stop_config + + +def _modify_config_to_use_welfare_coordinators( + rllib_config, env_config, hp_eval +): + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_cross_play_payoffs( + hp_eval, env_config["players_ids"] + ) + ) + + rllib_config["multiagent"]["policies_to_train"] = ["None"] + policies = rllib_config["multiagent"]["policies"] + for policy_idx, policy_id in enumerate(env_config["players_ids"]): + policy_config_items = list(policies[policy_id]) + opp_policy_idx = (policy_idx + 1) % 2 + + meta_policy_config = copy.deepcopy(welfare_coordination.DEFAULT_CONFIG) + meta_policy_config.update( + { + "nested_policies": [ + { + "Policy_class": copy.deepcopy(policy_config_items[0]), + "config_update": copy.deepcopy(policy_config_items[3]), + }, + ], + "solve_meta_game_after_init": True, + "tau": hp_eval["tau"], + "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs, + "own_player_idx": policy_idx, + "opp_player_idx": opp_policy_idx, + "own_default_welfare_fn": EGALITARIAN + if policy_idx == 1 + else UTILITARIAN, + "opp_default_welfare_fn": EGALITARIAN + if opp_policy_idx == 1 + else UTILITARIAN, + "policy_id_to_load": policy_id, + "policy_checkpoints": hp_eval["ckpt_per_welfare"], + } + ) + policy_config_items[ + 0 + ] = welfare_coordination.WelfareCoordinationTorchPolicy + policy_config_items[3] = meta_policy_config + policies[policy_id] = tuple(policy_config_items) + + return rllib_config + + +def _get_all_welfare_pairs_wt_cross_play_payoffs(hp, player_ids): + all_eval_replicates_dirs = _get_list_of_replicates_path_in_eval(hp) + + raw_data_points_wt_welfares = {} + for eval_replicate_path in all_eval_replicates_dirs: + players_ckpts = _extract_checkpoints_used_for_each_players( + player_ids, eval_replicate_path + ) + if _is_cross_play(players_ckpts): + players_welfares = _convert_checkpoint_names_to_welfares( + hp, players_ckpts + ) + raw_players_perf = _extract_performance( + eval_replicate_path, player_ids + ) + play_mode = _get_play_mode(players_welfares) + if play_mode not in raw_data_points_wt_welfares.keys(): + raw_data_points_wt_welfares[play_mode] = [] + raw_data_points_wt_welfares[play_mode].append(raw_players_perf) + all_welfare_pairs_wt_payoffs = _average_perf_per_play_mode( + raw_data_points_wt_welfares, hp + ) + print("all_welfare_pairs_wt_payoffs", all_welfare_pairs_wt_payoffs) + return all_welfare_pairs_wt_payoffs + + +def _get_list_of_replicates_path_in_eval(hp): + child_dirs = utils.path.get_children_paths_wt_discarding_filter( + hp["data_dir"], _filter="LOLAExact" + ) + 
child_dirs = utils.path.keep_dirs_only(child_dirs) + assert len(child_dirs) == 1, f"{child_dirs}" + eval_dir = utils.path.get_unique_child_dir(child_dirs[0]) + eval_replicates_dir = utils.path.get_unique_child_dir(eval_dir) + possible_nested_dir = utils.path.try_get_unique_child_dir( + eval_replicates_dir + ) + if possible_nested_dir is not None: + eval_replicates_dir = possible_nested_dir + all_eval_replicates_dirs = ( + utils.path.get_children_paths_wt_selecting_filter( + eval_replicates_dir, _filter="PG_" + ) + ) + return all_eval_replicates_dirs + + +def _extract_checkpoints_used_for_each_players( + player_ids, eval_replicate_path +): + params = utils.path.get_params_for_replicate(eval_replicate_path) + policies_config = params["multiagent"]["policies"] + ckps = [ + policies_config[player_id][3]["checkpoint_to_load_from"][0] + for player_id in player_ids + ] + return ckps + + +def _is_cross_play(players_ckpts): + return players_ckpts[0] != players_ckpts[1] + + +def _convert_checkpoint_names_to_welfares(hp, players_ckpts): + players_welfares = [] + for player_ckpt in players_ckpts: + player_ckpt_wtout_root = "/".join(player_ckpt.split("/")[-4:]) + for welfare, ckpts_for_welfare in hp["ckpt_per_welfare"].items(): + if any( + player_ckpt_wtout_root in ckpt for ckpt in ckpts_for_welfare + ): + players_welfares.append(welfare) + break + + assert len(players_welfares) == len( + players_ckpts + ), f"{len(players_welfares)} == {len(players_ckpts)}" + return players_welfares + + +def _extract_performance(eval_replicate_path, player_ids): + results_per_epi = utils.path.get_results_for_replicate(eval_replicate_path) + players_avg_reward = _extract_and_average_perf(results_per_epi, player_ids) + return players_avg_reward + + +def _extract_and_average_perf(results_per_epi, player_ids): + players_avg_reward = [] + for player_id in player_ids: + player_rewards = [] + for result_in_one_epi in results_per_epi: + total_player_reward_in_one_epi = result_in_one_epi[ + "policy_reward_mean" + ][player_id] + player_rewards.append(total_player_reward_in_one_epi) + players_avg_reward.append(sum(player_rewards) / len(player_rewards)) + return players_avg_reward + + +def _get_play_mode(players_welfares): + return f"{players_welfares[0]}-{players_welfares[1]}" + + +def _average_perf_per_play_mode(raw_data_points_wt_welfares, hp): + all_welfare_pairs_wt_payoffs = {} + for ( + play_mode, + values_per_replicates, + ) in raw_data_points_wt_welfares.items(): + player_1_values = [ + value_replicate[0] for value_replicate in values_per_replicates + ] + player_2_values = [ + value_replicate[1] for value_replicate in values_per_replicates + ] + all_welfare_pairs_wt_payoffs[play_mode] = ( + sum(player_1_values) / len(player_1_values) / hp["trace_length"], + sum(player_2_values) / len(player_2_values) / hp["trace_length"], + ) + return all_welfare_pairs_wt_payoffs + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/lola_exact_official.py b/marltoolbox/experiments/tune_class_api/lola_exact_official.py index 1122d30..9fefc95 100644 --- a/marltoolbox/experiments/tune_class_api/lola_exact_official.py +++ b/marltoolbox/experiments/tune_class_api/lola_exact_official.py @@ -9,9 +9,12 @@ import ray from ray import tune +from ray.tune.analysis import ExperimentAnalysis from ray.rllib.agents.pg import PGTorchPolicy +from ray.tune.integration.wandb import WandbLoggerCallback +from marltoolbox.experiments.tune_class_api import lola_exact_meta_game -from 
marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExact +from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer from marltoolbox.envs.matrix_sequential_social_dilemma import ( IteratedPrisonersDilemma, IteratedMatchingPennies, @@ -19,10 +22,31 @@ ) from marltoolbox.experiments.tune_class_api import lola_pg_official from marltoolbox.utils import policy, log, miscellaneous +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data def main(debug): - train_n_replicates = 2 if debug else 40 + hparams = get_hyperparameters(debug) + + if hparams["load_plot_data"] is None: + ray.init( + num_cpus=os.cpu_count(), + num_gpus=0, + local_mode=debug, + ) + experiment_analysis_per_welfare = train(hparams) + else: + experiment_analysis_per_welfare = None + + evaluate(experiment_analysis_per_welfare, hparams) + ray.shutdown() + + +def get_hyperparameters(debug, train_n_replicates=None, env=None): + """Get hyperparameters for LOLA-Exact for matrix games""" + + if train_n_replicates is None: + train_n_replicates = 2 if debug else int(3 * 1) seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("LOLA_Exact") @@ -32,12 +56,24 @@ def main(debug): "load_plot_data": None, # Example "load_plot_data": ".../SelfAndCrossPlay_save.p", "exp_name": exp_name, + "classify_into_welfare_fn": True, "train_n_replicates": train_n_replicates, - "env_name": "IPD", - # "env_name": "IMP", - # "env_name": "AsymBoS", + "wandb": { + "project": "LOLA_Exact", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, + # "env_name": "IPD" if env is None else env, + # "env_name": "IMP" if env is None else env, + "env_name": "IteratedAsymBoS" if env is None else env, "num_episodes": 5 if debug else 50, "trace_length": 5 if debug else 200, + "re_init_every_n_epi": 1, + # "num_episodes": 5 if debug else 50 * 200, + # "trace_length": 1, + # "re_init_every_n_epi": 50, "simple_net": True, "corrections": True, "pseudo": False, @@ -53,46 +89,68 @@ def main(debug): # "with_linear_LR_decay_to_zero": True, # "clip_update": 0.1, # "lr": 0.001, + "plot_keys": aggregate_and_plot_tensorboard_data.PLOT_KEYS + ["ret"], + "plot_assemblage_tags": aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + + [("ret",)], + "x_limits": (-0.1, 4.1), + "y_limits": (-0.1, 4.1), } - if hparams["load_plot_data"] is None: - ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) - tune_analysis_per_exp = train(hparams) - else: - tune_analysis_per_exp = None - - evaluate(tune_analysis_per_exp, hparams) - ray.shutdown() + hparams["plot_axis_scale_multipliers"] = ( + 1 / hparams["trace_length"], + 1 / hparams["trace_length"], + ) + return hparams def train(hp): - tune_config, stop, _ = get_tune_config(hp) + tune_config, stop_config, _ = get_tune_config(hp) # Train with the Tune Class API (not an RLLib Trainer) - tune_analysis = tune.run( - LOLAExact, + experiment_analysis = tune.run( + LOLAExactTrainer, name=hp["exp_name"], config=tune_config, checkpoint_at_end=True, - stop=stop, + stop=stop_config, metric=hp["metric"], mode="max", + # callbacks=None + # if hp["debug"] + # else [ + # WandbLoggerCallback( + # project=hp["wandb"]["project"], + # group=hp["wandb"]["group"], + # api_key_file=hp["wandb"]["api_key_file"], + # log_config=True, + # ) + # ], ) - tune_analysis_per_exp = {"": tune_analysis} - return tune_analysis_per_exp + if hp["classify_into_welfare_fn"]: + experiment_analysis_per_welfare = ( + 
_classify_trials_in_function_of_welfare(experiment_analysis) + ) + else: + experiment_analysis_per_welfare = {"": experiment_analysis} + return experiment_analysis_per_welfare -def get_tune_config(hp: dict) -> dict: + +def get_tune_config(hp: dict): tune_config = copy.deepcopy(hp) - assert tune_config["env_name"] in ("IPD", "IMP", "BoS", "AsymBoS") + assert tune_config["env_name"] in ("IPD", "IMP", "BoS", "IteratedAsymBoS") - if tune_config["env_name"] in ("IPD", "IMP", "BoS", "AsymBoS"): - env_config = { - "players_ids": ["player_row", "player_col"], - "max_steps": tune_config["trace_length"], - "get_additional_info": True, - } + env_config = { + "players_ids": ["player_row", "player_col"], + "max_steps": tune_config["trace_length"], + "get_additional_info": True, + } - if tune_config["env_name"] in ("IPD", "BoS", "AsymBoS"): + if tune_config["env_name"] == "IteratedAsymBoS": + tune_config["Q_net_std"] = 3.0 + else: + tune_config["Q_net_std"] = 1.0 + + if tune_config["env_name"] in ("IPD", "BoS", "IteratedAsymBoS"): tune_config["gamma"] = ( 0.96 if tune_config["gamma"] is None else tune_config["gamma"] ) @@ -103,17 +161,17 @@ def get_tune_config(hp: dict) -> dict: ) tune_config["save_dir"] = "dice_results_imp" - stop = {"episodes_total": tune_config["num_episodes"]} - return tune_config, stop, env_config + stop_config = {"episodes_total": tune_config["num_episodes"]} + return tune_config, stop_config, env_config -def evaluate(tune_analysis_per_exp, hp): +def evaluate(experiment_analysis_per_welfare, hp): ( rllib_hp, rllib_config_eval, policies_to_load, trainable_class, - stop, + stop_config, env_config, ) = generate_eval_config(hp) @@ -122,9 +180,12 @@ def evaluate(tune_analysis_per_exp, hp): rllib_config_eval, policies_to_load, trainable_class, - stop, + stop_config, env_config, - tune_analysis_per_exp, + experiment_analysis_per_welfare, + n_cross_play_per_checkpoint=min(15, hp["train_n_replicates"] - 1) + if hp["classify_into_welfare_fn"] + else None, ) @@ -136,8 +197,8 @@ def generate_eval_config(hp): hp_eval["batch_size"] = 1 hp_eval["num_episodes"] = 100 - tune_config, stop, env_config = get_tune_config(hp_eval) - tune_config["TuneTrainerClass"] = LOLAExact + tune_config, stop_config, env_config = get_tune_config(hp_eval) + tune_config["TuneTrainerClass"] = LOLAExactTrainer hp_eval["group_names"] = ["lola"] hp_eval["scale_multipliers"] = ( @@ -154,7 +215,7 @@ def generate_eval_config(hp): hp_eval["env_class"] = IteratedMatchingPennies hp_eval["x_limits"] = (-1.0, 1.0) hp_eval["y_limits"] = (-1.0, 1.0) - elif hp_eval["env_name"] == "AsymBoS": + elif hp_eval["env_name"] == "IteratedAsymBoS": hp_eval["env_class"] = IteratedAsymBoS hp_eval["x_limits"] = (-0.1, 4.1) hp_eval["y_limits"] = (-0.1, 4.1) @@ -184,21 +245,57 @@ def generate_eval_config(hp): }, "seed": hp_eval["seed"], "min_iter_time_s": hp_eval["min_iter_time_s"], + "num_workers": 0, + "num_envs_per_worker": 1, } policies_to_load = copy.deepcopy(env_config["players_ids"]) - trainable_class = LOLAExact + trainable_class = LOLAExactTrainer return ( hp_eval, rllib_config_eval, policies_to_load, trainable_class, - stop, + stop_config, env_config, ) +def _classify_trials_in_function_of_welfare( + experiment_analysis, +): + experiment_analysis_per_welfare = {} + for trial in experiment_analysis.trials: + welfare_name = _get_trial_welfare(trial) + if welfare_name not in experiment_analysis_per_welfare.keys(): + _add_empty_experiment_analysis( + experiment_analysis_per_welfare, + welfare_name, + experiment_analysis, + ) + 
experiment_analysis_per_welfare[welfare_name].trials.append(trial) + return experiment_analysis_per_welfare + + +def _get_trial_welfare(trial): + reward_player_1 = trial.last_result["ret1"] + reward_player_2 = trial.last_result["ret2"] + welfare_name = lola_exact_meta_game.classify_into_welfare_based_on_rewards( + reward_player_1, reward_player_2 + ) + return welfare_name + + +def _add_empty_experiment_analysis( + experiment_analysis_per_welfare, welfare_name, experiment_analysis +): + experiment_analysis_per_welfare[welfare_name] = copy.deepcopy( + experiment_analysis + ) + experiment_analysis_per_welfare[welfare_name].trials = [] + + if __name__ == "__main__": - debug_mode = True + debug_mode = False main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/lola_pg_official.py b/marltoolbox/experiments/tune_class_api/lola_pg_official.py index 080976b..7d7ec85 100644 --- a/marltoolbox/experiments/tune_class_api/lola_pg_official.py +++ b/marltoolbox/experiments/tune_class_api/lola_pg_official.py @@ -11,25 +11,30 @@ ########## import copy +import logging import os import time import ray from ray import tune from ray.rllib.agents.dqn import DQNTorchPolicy -from ray.tune.integration.wandb import WandbLogger -from ray.tune.logger import DEFAULT_LOGGERS +from ray.tune.integration.wandb import WandbLoggerCallback -from marltoolbox.algos.lola import train_cg_tune_class_API -from marltoolbox.algos.lola.train_pg_tune_class_API import LOLAPGMatrice +from marltoolbox.algos.lola import ( + train_cg_tune_class_API, + train_pg_tune_class_API, +) from marltoolbox.envs import ( vectorized_coin_game, - vectorized_mixed_motive_coin_game, + vectorized_ssd_mm_coin_game, matrix_sequential_social_dilemma, ) from marltoolbox.scripts import aggregate_and_plot_tensorboard_data -from marltoolbox.utils import policy, log, self_and_cross_perf +from marltoolbox.utils import policy, log, cross_play, exp_analysis, callbacks from marltoolbox.utils.plot import PlotConfig +from marltoolbox.experiments.tune_class_api import lola_exact_official + +logger = logging.getLogger(__name__) def main(debug: bool, env=None): @@ -40,68 +45,64 @@ def main(debug: bool, env=None): :param debug: selection of debug mode using less compute :param env: option to overwrite the env selection """ - train_n_replicates = 2 if debug else 1 + train_n_replicates = 2 if debug else 20 timestamp = int(time.time()) seeds = [seed + timestamp for seed in list(range(train_n_replicates))] exp_name, _ = log.log_in_current_day_dir("LOLA_PG") + tune_hparams = _get_hyperparameters( + debug, train_n_replicates, seeds, exp_name, env + ) + + if tune_hparams["load_plot_data"] is None: + ray.init(num_cpus=10, num_gpus=0, local_mode=debug) + experiment_analysis_per_welfare = _train(tune_hparams) + else: + experiment_analysis_per_welfare = None + + _evaluate(tune_hparams, debug, experiment_analysis_per_welfare) + ray.shutdown() + + +def _get_hyperparameters(debug, train_n_replicates, seeds, exp_name, env): # The InfluenceEvader(like) use_best_exploiter = False - # use_best_exploiter = True - high_coop_speed_hp = True if use_best_exploiter else False - # high_coop_speed_hp = True + + gamma = 0.9 tune_hparams = { "debug": debug, "exp_name": exp_name, "train_n_replicates": train_n_replicates, - # wandb configuration - "wandb": None - if debug - else { + "wandb": { "project": "LOLA_PG", "group": exp_name, "api_key_file": os.path.join( os.path.dirname(__file__), "../../../api_key_wandb" ), - "log_config": True, }, + "classify_into_welfare_fn": True, # Print 
metrics "load_plot_data": None, # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p", # - # "gamma": 0.5, - # "num_episodes": 3 if debug else 4000 if high_coop_speed_hp else 2000, - # "trace_length": 4 if debug else 20, - # "lr": None, - # - # "gamma": 0.875, - # "lr": 0.005 / 4, - # "num_episodes": 3 if debug else 4000, - # "trace_length": 4 if debug else 20, - # - "gamma": 0.9375, - "lr": 0.005 / 4 - if debug - else tune.grid_search([0.005 / 4, 0.005 / 4 / 2, 0.005 / 4 / 2 / 2]), - "num_episodes": 3 if debug else tune.grid_search([4000, 8000]), - "trace_length": 4 if debug else tune.grid_search([40, 80]), - # - "batch_size": 8 if debug else 512, # "env_name": "IteratedPrisonersDilemma" if env is None else env, # "env_name": "IteratedAsymBoS" if env is None else env, - "env_name": "VectorizedCoinGame" if env is None else env, + # "env_name": "VectorizedCoinGame" if env is None else env, # "env_name": "AsymVectorizedCoinGame" if env is None else env, - # "env_name": "VectorizedMixedMotiveCoinGame" if env is None else env, + "env_name": "VectorizedSSDMixedMotiveCoinGame" if env is None else env, + # "remove_trials_below_speed": False, + "remove_trials_below_speed": 0.2, + "remove_trials_below_speed_for_both": True, "pseudo": False, "grid_size": 3, "lola_update": True, "opp_model": False, "mem_efficient": True, "lr_correction": 1, - "bs_mul": 1 / 10 * 3 if use_best_exploiter else 1 / 10, + "global_lr_divider": 1 / 10 * 3 if use_best_exploiter else 1 / 10, "simple_net": True, "hidden": 32, "reg": 0, @@ -117,20 +118,7 @@ def main(debug: bool, env=None): "clip_loss_norm": False, "clip_lola_update_norm": False, "clip_lola_correction_norm": 3.0, - # "clip_lola_correction_norm": - # tune.grid_search([3.0 / 2, 3.0, 3.0 * 2]), "clip_lola_actor_norm": 10.0, - # "clip_lola_actor_norm": tune.grid_search([10.0 / 2, 10.0, 10.0 * 2]), - "entropy_coeff": 0.001, - # "entropy_coeff": tune.grid_search([0.001/2/2, 0.001/2, 0.001]), - # "weigth_decay": 0.03, - "weigth_decay": 0.03 - if debug - else tune.grid_search([0.03 / 8 / 2 / 2, 0.03 / 8 / 2, 0.03 / 8]), - # "lola_correction_multiplier": 1, - "lola_correction_multiplier": 1 - if debug - else tune.grid_search([1 * 4, 1 * 4 * 2, 1 * 4 * 2 * 2]), "lr_decay": True, "correction_reward_baseline_per_step": False, "use_critic": False, @@ -143,32 +131,91 @@ def main(debug: bool, env=None): ("total_reward",), ("entrop",), ], + "use_normalized_rewards": False, + "use_centered_reward": False, + "use_rolling_avg_actor_grad": False, + "process_reward_after_rolling": False, + "only_process_reward": False, + "use_rolling_avg_reward": False, + "reward_processing_bais": False, + "center_and_normalize_with_rolling_avg": False, + "punishment_helped": True, } - # Add exploiter hyperparameters - tune_hparams.update( - { - "start_using_exploiter_at_update_n": 1 - if debug - else 3000 - if high_coop_speed_hp - else 1500, - # PG exploiter - "use_PG_exploiter": True if use_best_exploiter else False, - "every_n_updates_copy_weights": 1 if debug else 100, - # "adding_scaled_weights": False, - # "adding_scaled_weights": 0.33, - } - ) - - if tune_hparams["load_plot_data"] is None: - ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) - tune_analysis_per_exp = _train(tune_hparams) - else: - tune_analysis_per_exp = None + if gamma == 0.5: + tune_hparams.update( + { + "gamma": 0.5, + "num_episodes": 3 + if debug + else 4000 + if high_coop_speed_hp + else 2000, + "trace_length": 10 if debug else 20, + "lr": None, + "weigth_decay": 0.03, + "lola_correction_multiplier": 1, 
+ "entropy_coeff": 0.001, + "batch_size": 12 if debug else 512, + } + ) + elif gamma == 0.875: + tune_hparams.update( + { + "gamma": 0.875, + "lr": 0.005 / 4, + "num_episodes": 3 if debug else 4000, + "trace_length": 10 if debug else 20, + "weigth_decay": 0.03 / 8, + "lola_correction_multiplier": 4, + "entropy_coeff": 0.001, + "batch_size": 12 if debug else 512, + } + ) + elif gamma == 0.9375: + tune_hparams.update( + { + "gamma": 0.9375, + "lr": 0.005 / 4, + "num_episodes": 3 if debug else 2000, + "trace_length": 10 if debug else 40, + "weigth_decay": 0.03 / 32, + "lola_correction_multiplier": 4, + "entropy_coeff": 0.002, + "batch_size": 12 if debug else 1024, + } + ) + elif gamma == 0.9: + tune_hparams.update( + { + "gamma": 0.9, + "lr": 0.01, + "num_episodes": 3 if debug else 2000, + "trace_length": 10 if debug else 40, + "weigth_decay": 0.001875, + "lola_correction_multiplier": 8, + "entropy_coeff": 0.02, + "batch_size": 12 if debug else 1024, + "use_normalized_rewards": False, + "reward_processing_bais": 0.1, + "center_and_normalize_with_rolling_avg": False, + } + ) - _evaluate(tune_hparams, debug, tune_analysis_per_exp) - ray.shutdown() + if use_best_exploiter: + # Add exploiter hyperparameters + tune_hparams.update( + { + "start_using_exploiter_at_update_n": 1 + if debug + else 3000 + if high_coop_speed_hp + else 1500, + "use_PG_exploiter": True if use_best_exploiter else False, + "every_n_updates_copy_weights": 1 if debug else 100, + } + ) + return tune_hparams def _train(tune_hp): @@ -177,10 +224,10 @@ def _train(tune_hp): if "CoinGame" in tune_config["env_name"]: trainable_class = train_cg_tune_class_API.LOLAPGCG else: - trainable_class = LOLAPGMatrice + trainable_class = train_pg_tune_class_API.LOLAPGMatrice # Train with the Tune Class API (not RLLib Class) - tune_analysis = tune.run( + experiment_analysis = tune.run( trainable_class, name=tune_hp["exp_name"], config=tune_config, @@ -189,18 +236,71 @@ def _train(tune_hp): metric=tune_config["metric"], mode="max", log_to_file=not tune_hp["debug"], - loggers=DEFAULT_LOGGERS + (WandbLogger,), + callbacks=None + if tune_hp["debug"] + else [ + WandbLoggerCallback( + project=tune_hp["wandb"]["project"], + group=tune_hp["wandb"]["group"], + api_key_file=tune_hp["wandb"]["api_key_file"], + log_config=True, + ) + ], ) - tune_analysis_per_exp = {"": tune_analysis} - # if not tune_hp["debug"]: + if tune_hp["remove_trials_below_speed"]: + experiment_analysis = _remove_failed_trials( + experiment_analysis, tune_hp + ) + + if tune_hp["classify_into_welfare_fn"]: + experiment_analysis_per_welfare = ( + _classify_trials_in_function_of_welfare( + experiment_analysis, tune_hp + ) + ) + else: + experiment_analysis_per_welfare = {"": experiment_analysis} + aggregate_and_plot_tensorboard_data.add_summary_plots( main_path=os.path.join("~/ray_results/", tune_config["exp_name"]), plot_keys=tune_config["plot_keys"], plot_assemble_tags_in_one_plot=tune_config["plot_assemblage_tags"], ) - return tune_analysis_per_exp + return experiment_analysis_per_welfare + + +def _remove_failed_trials(results, tune_hp): + if tune_hp["remove_trials_below_speed_for_both"]: + results = exp_analysis.filter_trials_wt_n_metrics( + results, + metrics=("player_red_pick_speed", "player_blue_pick_speed"), + metric_thresholds=( + tune_hp["remove_trials_below_speed"], + tune_hp["remove_trials_below_speed"], + ), + # metrics=("total_reward_player_blue",), + # metric_thresholds=(10000,), + metric_modes=("last-5-avg", "last-5-avg"), + threshold_modes=(exp_analysis.ABOVE, 
exp_analysis.ABOVE), + ) + else: + results = exp_analysis.filter_trials( + results, + metric="player_red_pick_speed", + metric_threshold=tune_hp["remove_trials_below_speed"], + metric_mode="last-5-avg", + threshold_mode=exp_analysis.ABOVE, + ) + results = exp_analysis.filter_trials( + results, + metric="player_blue_pick_speed", + metric_threshold=tune_hp["remove_trials_below_speed"], + metric_mode="last-5-avg", + threshold_mode=exp_analysis.ABOVE, + ) + return results def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): @@ -216,36 +316,21 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): tune_config[ "env_class" ] = vectorized_coin_game.AsymVectorizedCoinGame - elif tune_config["env_name"] == "VectorizedMixedMotiveCoinGame": + elif tune_config["env_name"] == "VectorizedSSDMixedMotiveCoinGame": tune_config[ "env_class" - ] = vectorized_mixed_motive_coin_game.VectMixedMotiveCG + ] = vectorized_ssd_mm_coin_game.VectSSDMixedMotiveCG else: raise ValueError() - tune_config["num_episodes"] = ( - 100000 - if tune_config["num_episodes"] is None - else tune_config["num_episodes"] - ) - tune_config["trace_length"] = ( - 150 - if tune_config["trace_length"] is None - else tune_config["trace_length"] - ) - tune_config["batch_size"] = ( - 4000 - if tune_config["batch_size"] is None - else tune_config["batch_size"] - ) tune_config["lr"] = ( 0.005 if tune_config["lr"] is None else tune_config["lr"] ) tune_config["gamma"] = ( 0.96 if tune_config["gamma"] is None else tune_config["gamma"] ) - tune_hp["x_limits"] = (-1.0, 1.0) - tune_hp["y_limits"] = (-1.0, 1.0) + tune_hp["x_limits"] = (-0.1, 0.6) + tune_hp["y_limits"] = (-0.1, 0.6) if ( tune_config["env_class"] == vectorized_coin_game.AsymVectorizedCoinGame @@ -253,21 +338,29 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): tune_hp["x_limits"] = (-1.0, 3.0) elif ( tune_config["env_class"] - == vectorized_mixed_motive_coin_game.VectMixedMotiveCG + == vectorized_ssd_mm_coin_game.VectSSDMixedMotiveCG ): - tune_hp["x_limits"] = (-2.0, 4.0) - tune_hp["y_limits"] = (-2.0, 4.0) - tune_hp["jitter"] = 0.02 + tune_hp["x_limits"] = (-0.02, 0.8) + tune_hp["y_limits"] = (-0.02, 1.5) + + tune_hp["jitter"] = 0.00 env_config = { "players_ids": ["player_red", "player_blue"], - "batch_size": tune_config["batch_size"], - "max_steps": tune_config["trace_length"], + "batch_size": tune.sample_from( + lambda spec: spec.config["batch_size"] + ), + "max_steps": tune.sample_from( + lambda spec: spec.config["trace_length"] + ), "grid_size": tune_config["grid_size"], "get_additional_info": True, - "both_players_can_pick_the_same_coin": tune_config["env_name"] - == "VectorizedMixedMotiveCoinGame", + "both_players_can_pick_the_same_coin": True, + # tune_config["env_name"] + # == "VectorizedMixedMotiveCoinGame" + # or tune_config["env_name"] == "VectorizedSSDMixedMotiveCoinGame", "force_vectorize": False, "same_obs_for_each_player": True, + "punishment_helped": tune_config["punishment_helped"], } tune_config["metric"] = "player_blue_pick_speed" tune_config["plot_keys"] += ( @@ -324,17 +417,13 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): } tune_config["metric"] = "player_row_CC_freq" - tune_hp["scale_multipliers"] = ( - 1 / tune_config["trace_length"], - 1 / tune_config["trace_length"], + # For hyperparameter search + tune_hp["scale_multipliers"] = tune.sample_from( + lambda spec: ( + 1 / spec.config["trace_length"], + 1 / spec.config["trace_length"], + ) ) - # For HP search - # 
tune_hp["scale_multipliers"] = tune.sample_from( - # lambda spec: ( - # 1 / spec.config["trace_length"], - # 1 / spec.config["trace_length"], - # ) - # ) tune_config["env_config"] = env_config if stop_on_epi_number: @@ -345,7 +434,7 @@ def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): return tune_config, stop, env_config -def _evaluate(tune_hp, debug, tune_analysis_per_exp): +def _evaluate(tune_hp, debug, experiment_analysis_per_exp): ( rllib_hp, rllib_config_eval, @@ -362,7 +451,7 @@ def _evaluate(tune_hp, debug, tune_analysis_per_exp): trainable_class, stop, env_config, - tune_analysis_per_exp, + experiment_analysis_per_exp, ) @@ -379,7 +468,17 @@ def _generate_eval_config(tune_hp, debug): env_config["batch_size"] = 1 tune_config["TuneTrainerClass"] = train_cg_tune_class_API.LOLAPGCG else: - tune_config["TuneTrainerClass"] = LOLAPGMatrice + tune_config["TuneTrainerClass"] = train_pg_tune_class_API.LOLAPGMatrice + tune_config["env_config"].update( + { + "batch_size": env_config["batch_size"], + "max_steps": rllib_hp["trace_length"], + } + ) + rllib_hp["scale_multipliers"] = ( + 1 / rllib_hp["trace_length"], + 1 / rllib_hp["trace_length"], + ) rllib_config_eval = { "env": rllib_hp["env_class"], @@ -387,8 +486,6 @@ def _generate_eval_config(tune_hp, debug): "multiagent": { "policies": { env_config["players_ids"][0]: ( - # The default policy is DQN defined in DQNTrainer - # but we overwrite it to use the LE policy policy.get_tune_policy_class(DQNTorchPolicy), rllib_hp["env_class"](env_config).OBSERVATION_SPACE, rllib_hp["env_class"].ACTION_SPACE, @@ -406,9 +503,17 @@ def _generate_eval_config(tune_hp, debug): }, "seed": rllib_hp["seed"], "min_iter_time_s": 3.0, - "callbacks": log.get_logging_callbacks_class( - log_full_epi=True, + # "callbacks": log.get_logging_callbacks_class( + # log_full_epi=True, + # ), + "callbacks": callbacks.merge_callbacks( + log.get_logging_callbacks_class( + log_full_epi=True, + ), + callbacks.PolicyCallbacks, ), + "num_envs_per_worker": 1, + "num_workers": 0, } policies_to_load = copy.deepcopy(env_config["players_ids"]) @@ -421,7 +526,7 @@ def _generate_eval_config(tune_hp, debug): "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]], } else: - trainable_class = LOLAPGMatrice + trainable_class = train_pg_tune_class_API.LOLAPGMatrice return ( rllib_hp, @@ -440,20 +545,23 @@ def _evaluate_self_and_cross_perf( trainable_class, stop, env_config, - tune_analysis_per_exp, + experiment_analysis_per_welfare, + n_cross_play_per_checkpoint=None, ): - evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator( - exp_name=rllib_hp["exp_name"], + exp_name = os.path.join(rllib_hp["exp_name"], "eval") + evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator( + exp_name=exp_name, local_mode=rllib_hp["debug"], - use_wandb=not rllib_hp["debug"], ) analysis_metrics_per_mode = evaluator.perform_evaluation_or_load_data( evaluation_config=rllib_config_eval, stop_config=stop, policies_to_load_from_checkpoint=policies_to_load, - tune_analysis_per_exp=tune_analysis_per_exp, - TuneTrainerClass=trainable_class, - n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1), + experiment_analysis_per_welfare=experiment_analysis_per_welfare, + tune_trainer_class=trainable_class, + n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1) + if n_cross_play_per_checkpoint is None + else n_cross_play_per_checkpoint, to_load_path=rllib_hp["load_plot_data"], ) @@ -461,7 +569,7 @@ def _evaluate_self_and_cross_perf( rllib_hp["env_class"], 
matrix_sequential_social_dilemma.MatrixSequentialSocialDilemma, ): - background_area_coord = rllib_hp["env_class"].PAYOUT_MATRIX + background_area_coord = rllib_hp["env_class"].PAYOFF_MATRIX else: background_area_coord = None @@ -485,6 +593,56 @@ def _evaluate_self_and_cross_perf( ) +FAILURE = "failures" +EGALITARIAN = "egalitarian" +UTILITARIAN = "utilitarian" + + +def _classify_trials_in_function_of_welfare(experiment_analysis, hp): + experiment_analysis_per_welfare = {} + for trial in experiment_analysis.trials: + welfare_name = _get_trial_welfare(trial, hp) + if welfare_name not in experiment_analysis_per_welfare.keys(): + lola_exact_official._add_empty_experiment_analysis( + experiment_analysis_per_welfare, + welfare_name, + experiment_analysis, + ) + experiment_analysis_per_welfare[welfare_name].trials.append(trial) + return experiment_analysis_per_welfare + + +def _get_trial_welfare(trial, hp): + pick_own_player_1 = trial.last_result["player_red_pick_own_color"] + pick_own_player_2 = trial.last_result["player_blue_pick_own_color"] + reward_player_1 = trial.last_result["total_reward_player_red"] + reward_player_2 = trial.last_result["total_reward_player_blue"] + welfare_name = lola_pg_classify_fn( + pick_own_player_1, + pick_own_player_2, + hp, + reward_player_1, + reward_player_2, + ) + return welfare_name + + +def lola_pg_classify_fn( + pick_own_player_1, pick_own_player_2, hp, reward_player_1, reward_player_2 +): + if reward_player_2 != 0.0 and reward_player_1 != 0.0: + if hp["env_name"] == "VectorizedSSDMixedMotiveCoinGame": + ratio = reward_player_2 / reward_player_1 + else: + ratio = max( + reward_player_1 / reward_player_2, + reward_player_2 / reward_player_1, + ) + if ratio > 1.2: + return UTILITARIAN + return EGALITARIAN + + if __name__ == "__main__": debug_mode = True main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/meta_solver_cross_play.py b/marltoolbox/experiments/tune_class_api/meta_solver_cross_play.py new file mode 100644 index 0000000..30659d8 --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/meta_solver_cross_play.py @@ -0,0 +1,236 @@ +import copy +import json +import os +import random +import torch +import ray +from ray import tune + +from marltoolbox.experiments.tune_class_api import ( + lola_exact_official, + various_algo_meta_game, +) +from marltoolbox.utils import log, miscellaneous, exp_analysis +from marltoolbox.experiments.tune_class_api.various_algo_meta_game import ( + META_UNIFORM, + META_SOS, + META_RANDOM, + META_LOLA_EXACT, + META_PG, + META_APLHA_RANK, + META_APLHA_PURE, + META_REPLICATOR_DYNAMIC, + META_REPLICATOR_DYNAMIC_ZERO_INIT, + BASE_NEGOTIATION, + BASE_LOLA_EXACT, + BASE_AMTFT, + POLICY_ID_PL0, + POLICY_ID_PL1, + META_MINIMUM, +) + +payoffs_per_groups = None + +prefix = ( + "~/dev-maxime/CLR/vm-data/instance-60-cpu-4-preemtible/meta_game_compare" +) +# prefix = "~/ray_results/meta_game_compare" +prefix = os.path.expanduser(prefix) +META_POLICY_SAVE_PATHS = { + META_APLHA_RANK: prefix + + "/2021_05_26/19_16_35/meta_game/meta_policies.json", + META_APLHA_PURE: prefix + + "/2021_05_26/19_26_28/meta_game/meta_policies.json", + META_REPLICATOR_DYNAMIC: prefix + + "/2021_05_26/19_36_11/meta_game/meta_policies.json", + META_REPLICATOR_DYNAMIC_ZERO_INIT: prefix + + "/2021_05_26/19_47_52/meta_game/meta_policies.json", + META_RANDOM: prefix + "/2021_05_26/19_59_35/meta_game/meta_policies.json", + META_PG: prefix + "/2021_05_26/20_09_19/meta_game/meta_policies.json", + META_LOLA_EXACT: prefix + + 
"/2021_05_26/20_24_26/meta_game/meta_policies.json", + META_SOS: prefix + "/2021_05_26/22_10_37/meta_game/meta_policies.json", + META_UNIFORM: prefix + "/2021_05_26/22_21_07/meta_game/meta_policies.json", + META_MINIMUM: prefix + "/2021_05_27/19_24_36/meta_game/meta_policies.json", +} + + +def main(debug, base_game_algo=None, pair_of_meta_game_algo=None): + """Evaluate meta game performances""" + + n_replicates_over_full_exp = 2 if debug else 20 + train_n_replicates = 1 + n_cross_play = 4 + eval_over_n_epi = 2 if debug else 10 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("meta_game_compare") + + ( + rllib_configs_by_meta_solvers, + stop_config, + hp, + hp_eval, + trainer, + ) = _get_rllib_config_by_meta_solvers( + pair_of_meta_game_algo, + debug, + seeds, + exp_name, + n_replicates_over_full_exp, + base_game_algo, + ) + stop_config["episodes_total"] = eval_over_n_epi + master_rllib_config = _mix_rllib_config( + hp_eval, rllib_configs_by_meta_solvers, n_cross_play + ) + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) + tune_analysis = various_algo_meta_game._train_with_tune( + master_rllib_config, + stop_config, + hp_eval, + trainer, + plot_aggregates=False, + ) + ray.shutdown() + + various_algo_meta_game._extract_metric_and_log_and_plot( + tune_analysis, + hp, + hp_eval, + title=f"BASE({base_game_algo}) META({pair_of_meta_game_algo})", + ) + + +def _get_rllib_config_by_meta_solvers( + pair_of_meta_game_algo, + debug, + seeds, + exp_name, + n_replicates_over_full_exp, + base_game_algo, +): + rllib_configs_by_meta_solvers = [] + for meta_game_algo in pair_of_meta_game_algo: + hp = various_algo_meta_game._get_hyperparameters( + debug, seeds, exp_name, base_game_algo, meta_game_algo + ) + hp["n_replicates_over_full_exp"] = n_replicates_over_full_exp + global payoffs_per_groups + ( + hp["payoff_matrices"], + hp["actions_possible"], + hp["base_ckpt_per_replicat"], + payoffs_per_groups, + ) = various_algo_meta_game._form_n_matrices_from_base_game_payoffs(hp) + hp = _load_meta_policies(hp, meta_game_algo) + rllib_configs, stop_config, trainer, hp_eval = _get_rllib_configs(hp) + rllib_configs_by_meta_solvers.append(rllib_configs) + return rllib_configs_by_meta_solvers, stop_config, hp, hp_eval, trainer + + +def _load_meta_policies(hp, meta_game_algo): + meta_policy_save_path = META_POLICY_SAVE_PATHS[meta_game_algo] + with open(meta_policy_save_path) as json_file: + json_content = json.load(json_file) + clamped_meta_policies = copy.deepcopy( + json_content["clamped_meta_policies"] + ) + clamped_meta_policies = [ + {k: torch.tensor(v) for k, v in el.items()} + for el in clamped_meta_policies + ] + hp["meta_game_policy_distributions"] = clamped_meta_policies + + return hp + + +def _mix_rllib_config(hp, rllib_configs_by_meta_solvers, n_cross_play): + all_config_mix = [] + assert n_cross_play % 2 == 0 + n_meta_policies_meta_solver_1 = len(rllib_configs_by_meta_solvers[0]) + for meta_solver_1_idx in range(n_meta_policies_meta_solver_1): + for player_order in range(2): + for i in range(n_cross_play // 2): + meta_solver_1_config = rllib_configs_by_meta_solvers[ + player_order + ] + meta_solver_2_config = rllib_configs_by_meta_solvers[ + (player_order + 1) % 2 + ] + + meta_solver_2_idx = random.randint( + 0, len(meta_solver_2_config) - 1 + ) + + pl_1_config = meta_solver_1_config[meta_solver_1_idx][ + "multiagent" + ]["policies"][POLICY_ID_PL0] + pl_2_config = meta_solver_2_config[meta_solver_2_idx][ + "multiagent" + 
]["policies"][POLICY_ID_PL1] + mix_policies = { + POLICY_ID_PL0: copy.deepcopy(pl_1_config), + POLICY_ID_PL1: copy.deepcopy(pl_2_config), + } + all_config_mix.append(mix_policies) + + master_config = copy.deepcopy(rllib_configs_by_meta_solvers[0][0]) + print("len(all_config_mix)", len(all_config_mix)) + master_config["multiagent"]["policies"] = tune.grid_search(all_config_mix) + return master_config + + +def _get_rllib_configs(hp): + rllib_configs_for_one_meta_solver = [] + for meta_game_idx in range(hp["n_replicates_over_full_exp"]): + ( + rllib_config, + stop_config, + trainer, + hp_eval, + ) = various_algo_meta_game._get_final_base_game_rllib_config( + copy.deepcopy(hp), meta_game_idx + ) + + rllib_configs_for_one_meta_solver.append(rllib_config) + return rllib_configs_for_one_meta_solver, stop_config, trainer, hp_eval + + +if __name__ == "__main__": + debug_mode = True + loop_over_main = True + + if loop_over_main: + base_game_algo_to_eval = ( + BASE_LOLA_EXACT, + # BASE_NEGOTIATION, + ) + meta_game_algo_to_eval = ( + # META_APLHA_RANK, + # META_APLHA_PURE, + # META_REPLICATOR_DYNAMIC, + # META_REPLICATOR_DYNAMIC_ZERO_INIT, + META_RANDOM, + META_PG, + # META_LOLA_EXACT, + # META_SOS, + META_UNIFORM, + META_MINIMUM, + ) + pairs_seen = [] + for base_game_algo_ in base_game_algo_to_eval: + for meta_game_algo_1 in meta_game_algo_to_eval: + for meta_game_algo_2 in meta_game_algo_to_eval: + if meta_game_algo_1 != meta_game_algo_2: + meta_pair = [meta_game_algo_1, meta_game_algo_2] + if sorted(meta_pair) not in pairs_seen: + main( + debug_mode, + base_game_algo_, + meta_pair, + ) + pairs_seen.append(sorted(meta_pair)) + else: + print("skipping pair", meta_pair) + else: + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/meta_solver_exploitability.py b/marltoolbox/experiments/tune_class_api/meta_solver_exploitability.py new file mode 100644 index 0000000..242efcd --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/meta_solver_exploitability.py @@ -0,0 +1,292 @@ +import copy +import random +import torch +import numpy as np +from marltoolbox.experiments.tune_class_api import ( + lola_exact_official, + various_algo_meta_game, +) +from marltoolbox.utils import log, miscellaneous, exp_analysis +from marltoolbox.experiments.tune_class_api.various_algo_meta_game import ( + META_UNIFORM, + META_SOS, + META_RANDOM, + META_LOLA_EXACT, + META_PG, + META_APLHA_RANK, + META_APLHA_PURE, + META_REPLICATOR_DYNAMIC, + META_REPLICATOR_DYNAMIC_ZERO_INIT, + BASE_NEGOTIATION, + BASE_LOLA_EXACT, + BASE_AMTFT, + POLICY_ID_PL0, + POLICY_ID_PL1, + META_MINIMUM, +) + +from marltoolbox.experiments.tune_class_api import meta_solver_cross_play +from marltoolbox.experiments.rllib_api import amtft_meta_game + +POLICY_IDS = [POLICY_ID_PL0, POLICY_ID_PL1] + + +def main( + debug, + base_game_algo=None, + meta_game_algo=None, + exploiter_idx=None, + being_exploited_idx=None, +): + """Evaluate meta game performances""" + + print(f"==== base_game_algo {base_game_algo} =====") + print(f"==== meta_game_algo {meta_game_algo} =====") + train_n_replicates = 1 + tau = 0.0 + n_in_bootstrap = 10 + n_replicates = 40 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("meta_game_compare") + + hp = _get_meta_payoff_matrices_and_policies( + meta_game_algo, + debug, + seeds, + exp_name, + base_game_algo, + ) + pl_1_r = [] + pl_2_r = [] + br_idx = [] + bootstrapped_idx = [] + sets_announced = [] + for repl_i in range(n_replicates): + bootstrapped_hp = 
_bootstrap_payoff_mat(hp, n_in_bootstrap) + meta_payoff_mat_avg = _average_payoff_matrices(bootstrapped_hp) + player_0_meta_policy = _average_meta_opponent_policies( + bootstrapped_hp, being_exploited_idx + ) + ( + best_reponse_action, + player_0_payoff_in_br, + player_1_payoff_in_br, + ) = _get_best_response( + meta_payoff_mat_avg, + player_0_meta_policy, + tau, + exploiter_idx, + being_exploited_idx, + ) + pl_1_r.append(player_0_payoff_in_br) + pl_2_r.append(player_1_payoff_in_br) + br_idx.append(int(best_reponse_action)) + bootstrapped_idx.append(bootstrapped_hp["bootstrapped_idx"]) + best_reponse_set = bootstrapped_hp["actions_possible"][ + best_reponse_action + ] + print( + "best_reponse_action", + best_reponse_action, + best_reponse_set, + ) + sets_announced.append(list(best_reponse_set)) + + pl_1_r = np.array(pl_1_r) + print("pl_1_r.shape", pl_1_r.shape) + pl_1_r_mean = pl_1_r.mean() + pl_1_r_std_err = pl_1_r.std() / np.sqrt(pl_1_r.shape[0]) + + pl_2_r = np.array(pl_2_r) + print("pl_2_r.shape", pl_2_r.shape) + pl_2_r_mean = pl_2_r.mean() + pl_2_r_std_err = pl_2_r.std() / np.sqrt(pl_2_r.shape[0]) + + print("players payoffs", player_0_payoff_in_br, player_1_payoff_in_br) + result_to_json = { + "br_idx": br_idx, + "sets_announced": sets_announced, + "pl_1_r_mean": float(pl_1_r_mean), + "pl_2_r_mean": float(pl_2_r_mean), + "pl_1_r_std_err": float(pl_1_r_std_err), + "pl_2_r_std_err": float(pl_2_r_std_err), + "bootstrapped_idx": bootstrapped_idx, + "player_0_meta_policy": player_0_meta_policy.tolist(), + "payoff_matrices": [el.tolist() for el in hp["payoff_matrices"]], + "meta_game_policy_distributions": [ + {k: v.tolist() for k, v in el.items()} + for el in hp["meta_game_policy_distributions"] + ], + } + amtft_meta_game.save_to_json( + exp_name=hp["exp_name"], object=result_to_json + ) + return ( + float(pl_1_r_mean), + float(pl_2_r_mean), + float(pl_1_r_std_err), + float(pl_2_r_std_err), + ) + + +def _bootstrap_payoff_mat(hp, n_in_bootstrap): + bootstrapped_hp = copy.deepcopy(hp) + meta_payoff_matrices = bootstrapped_hp["payoff_matrices"] + print("len before bootstrapped", len(meta_payoff_matrices)) + bootstrapped_hp["bootstrapped_idx"] = random.choices( + range(len(meta_payoff_matrices)), k=n_in_bootstrap + ) + bootstrapped_hp["payoff_matrices"] = [ + meta_payoff_matrices[idx] + for idx in bootstrapped_hp["bootstrapped_idx"] + ] + print("len after bootstrapped", len(bootstrapped_hp["payoff_matrices"])) + return bootstrapped_hp + + +def _average_payoff_matrices(hp): + print("len(payoff_matrices)", len(hp["payoff_matrices"])) + print("payoff_matrices[0].shape", hp["payoff_matrices"][0].shape) + + meta_payoff_mat_avg = copy.deepcopy(hp["payoff_matrices"]) + meta_payoff_mat_avg = np.array(meta_payoff_mat_avg).mean(axis=0) + print("meta_payoff_mat_avg", meta_payoff_mat_avg.shape) + return meta_payoff_mat_avg + + +def _average_meta_opponent_policies(hp, being_exploited_idx): + pl_being_exploited_id = POLICY_IDS[being_exploited_idx] + player_0_policies = [ + el[pl_being_exploited_id] + for el in hp["meta_game_policy_distributions"] + ] + print("player_0_policies", player_0_policies) + player_0_meta_policy = torch.stack(player_0_policies, dim=0).mean(dim=0) + print("mean player_0_policies", player_0_meta_policy) + + return player_0_meta_policy + + +def _get_best_response( + meta_payoff_mat_avg, + player_0_meta_policy, + tau=0.0, + exploiter_idx=None, + being_exploited_idx=None, +): + meta_payoff_mat_avg = torch.tensor(meta_payoff_mat_avg) + exploiter_payoffs = meta_payoff_mat_avg[..., 
exploiter_idx] + print("player_0_meta_policy", player_0_meta_policy) + matrix_opp_policy_prob = torch.stack( + [player_0_meta_policy] * exploiter_payoffs.shape[0], dim=1 + ) + print("matrix_opp_policy_prob", matrix_opp_policy_prob) + exploiter_expected_payoffs = matrix_opp_policy_prob * exploiter_payoffs + print( + "exploiter_expected_payoffs", + exploiter_expected_payoffs, + exploiter_expected_payoffs.shape, + ) + mean_exploiter_payoff = exploiter_expected_payoffs.mean(dim=0) + print("mean_exploiter_payoff", mean_exploiter_payoff) + best_reponse_action = np.argmax(mean_exploiter_payoff) + print("best_reponse_action", best_reponse_action) + exploiter_payoff_in_br = ( + exploiter_payoffs[:, best_reponse_action] * player_0_meta_policy + ) + being_exploited_payoff_in_br = ( + meta_payoff_mat_avg[..., being_exploited_idx][:, best_reponse_action] + * player_0_meta_policy + ) + being_exploited_payoff_in_br = being_exploited_payoff_in_br.sum() + exploiter_payoff_in_br = exploiter_payoff_in_br.sum() + if exploiter_idx == 1: + return ( + best_reponse_action, + being_exploited_payoff_in_br, + exploiter_payoff_in_br, + ) + elif exploiter_idx == 0: + return ( + best_reponse_action, + exploiter_payoff_in_br, + being_exploited_payoff_in_br, + ) + else: + raise ValueError() + + +def _get_meta_payoff_matrices_and_policies( + meta_game_algo, + debug, + seeds, + exp_name, + base_game_algo, +): + hp = various_algo_meta_game._get_hyperparameters( + debug, seeds, exp_name, base_game_algo, meta_game_algo + ) + global payoffs_per_groups + ( + hp["payoff_matrices"], + hp["actions_possible"], + hp["base_ckpt_per_replicat"], + payoffs_per_groups, + ) = various_algo_meta_game._form_n_matrices_from_base_game_payoffs(hp) + hp = meta_solver_cross_play._load_meta_policies(hp, meta_game_algo) + return hp + + +if __name__ == "__main__": + debug_mode = False + loop_over_main = True + + if loop_over_main: + base_game_algo_to_eval = ( + BASE_LOLA_EXACT, + # BASE_NEGOTIATION, + ) + meta_game_algo_to_eval = ( + META_APLHA_RANK, + META_APLHA_PURE, + META_REPLICATOR_DYNAMIC, + # META_REPLICATOR_DYNAMIC_ZERO_INIT, + META_RANDOM, + META_PG, + # META_LOLA_EXACT, + META_SOS, + META_UNIFORM, + META_MINIMUM, + ) + pairs_seen = [] + results = [] + for base_game_algo_ in base_game_algo_to_eval: + for meta_game_algo_ in meta_game_algo_to_eval: + pair_result = [] + for exploiter_idx in range(2): + being_exploited_idx = (exploiter_idx + 1) % 2 + pl_1_mean, pl_2_mean, pl_1_std_err, pl_2_std_err = main( + debug_mode, + base_game_algo_, + meta_game_algo_, + exploiter_idx, + being_exploited_idx, + ) + pair_result.append( + { + "exploiter_player": exploiter_idx + 1, + "player_being_exploited": being_exploited_idx + 1, + "base_game": base_game_algo_, + "meta_game": meta_game_algo_, + "pl_1_mean": pl_1_mean, + "pl_2_mean": pl_2_mean, + "pl_1_std_err": pl_1_std_err, + "pl_2_std_err": pl_2_std_err, + } + ) + results.append(pair_result) + + print("final results") + print(results) + else: + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/negociation_game_replicates_list.pickle b/marltoolbox/experiments/tune_class_api/negociation_game_replicates_list.pickle new file mode 100644 index 0000000..797730a Binary files /dev/null and b/marltoolbox/experiments/tune_class_api/negociation_game_replicates_list.pickle differ diff --git a/marltoolbox/experiments/tune_class_api/psro_simple_bargaining.py b/marltoolbox/experiments/tune_class_api/psro_simple_bargaining.py new file mode 100644 index 0000000..8b62ad7 --- /dev/null +++ 
b/marltoolbox/experiments/tune_class_api/psro_simple_bargaining.py @@ -0,0 +1,385 @@ +import copy +import logging +import os +import time + +import ray +from ray import tune +from ray.rllib.agents.dqn import DQNTorchPolicy +from ray.tune.integration.wandb import WandbLoggerCallback + +from marltoolbox.algos.lola import ( + train_cg_tune_class_API, + train_pg_tune_class_API, +) +from marltoolbox.envs import ( + simple_bargaining, +) +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import policy, log, cross_play, exp_analysis +from marltoolbox.utils.plot import PlotConfig +from marltoolbox.experiments.tune_class_api import lola_exact_official +from marltoolbox.algos.psro import PSROTrainer + +logger = logging.getLogger(__name__) + + +def main(debug: bool, env=None): + """ + Train several LOLA_PG pairs of agent on the selected environment and + plot their performances in self-play and cross-play. + + :param debug: selection of debug mode using less compute + :param env: option to overwrite the env selection + """ + train_n_replicates = 2 if debug else 40 + timestamp = int(time.time()) + seeds = [seed + timestamp for seed in list(range(train_n_replicates))] + + exp_name, _ = log.log_in_current_day_dir("PSRO") + + tune_hparams = _get_hyperparameters( + debug, train_n_replicates, seeds, exp_name, env + ) + + if tune_hparams["load_plot_data"] is None: + ray.init(num_cpus=10, num_gpus=0, local_mode=debug) + experiment_analysis_per_welfare = _train(tune_hparams) + else: + experiment_analysis_per_welfare = None + + _evaluate(tune_hparams, debug, experiment_analysis_per_welfare) + ray.shutdown() + + +def _get_hyperparameters(debug, train_n_replicates, seeds, exp_name, env): + + tune_hparams = { + "debug": debug, + "exp_name": exp_name, + "train_n_replicates": train_n_replicates, + "wandb": { + "project": "LOLA_PG", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, + # "classify_into_welfare_fn": True, + "seed": tune.grid_search(seeds), + "load_plot_data": None, + # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p", + "n_players": 2, + # "game_name": "kuhn_poker", + "game_name": "python_simple_bargaining", + "oracle_type": "PG", + "training_strategy_selector": "probabilistic", + "rectifier": "", + "sims_per_entry": 1000, + "number_policies_selected": 1, + "meta_strategy_method": "alpharank", + "symmetric_game": False, + "verbose": True, + "loss_str": "qpg", + "hidden_layer_size": 256, + "n_hidden_layers": 4, + "batch_size": 32, + "entropy_cost": 0.001, + "critic_learning_rate": 1e-2, + "pi_learning_rate": 1e-3, + "num_q_before_pi": 8, + "optimizer_str": "adam", + "self_play_proportion": 0.0, + "number_training_episodes": int(1e4), + "play_proportion": 0.0, + "sigma": 0.0, + "dqn_learning_rate": 1e-2, + "update_target_network_every": 1000, + "learn_every": 10, + "num_iterations": 100, + "plot_keys": [ + "reward", + "total_reward", + "entrop", + ], + "plot_assemblage_tags": [ + ("total_reward",), + ("entrop",), + ], + } + return tune_hparams + + +def _train(tune_hp): + tune_config, stop, env_config = _get_tune_config(tune_hp) + + # Train with the Tune Class API (not RLLib Class) + experiment_analysis = tune.run( + PSROTrainer, + name=tune_hp["exp_name"], + config=tune_config, + checkpoint_at_end=True, + stop=stop, + metric=tune_config["metric"], + mode="max", + log_to_file=True, # not tune_hp["debug"], + callbacks=None + if tune_hp["debug"] + else [ + WandbLoggerCallback( + 
project=tune_hp["wandb"]["project"], + group=tune_hp["wandb"]["group"], + api_key_file=tune_hp["wandb"]["api_key_file"], + log_config=True, + ) + ], + ) + + if tune_hp["classify_into_welfare_fn"]: + experiment_analysis_per_welfare = ( + _classify_trials_in_function_of_welfare( + experiment_analysis, tune_hp + ) + ) + else: + experiment_analysis_per_welfare = {"": experiment_analysis} + + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", tune_config["exp_name"]), + plot_keys=tune_config["plot_keys"], + plot_assemble_tags_in_one_plot=tune_config["plot_assemblage_tags"], + ) + + return experiment_analysis_per_welfare + + +def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): + tune_config = copy.deepcopy(tune_hp) + + tune_config["plot_keys"] += ( + train_cg_tune_class_API.PLOT_KEYS + + aggregate_and_plot_tensorboard_data.PLOT_KEYS + ) + tune_config["plot_assemblage_tags"] += ( + train_cg_tune_class_API.PLOT_ASSEMBLAGE_TAGS + + aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + ) + tune_hp["x_limits"] = (-3.0, 3.0) + tune_hp["y_limits"] = (-3.0, 3.0) + tune_hp["jitter"] = 0.05 + tune_config["metric"] = "training_iteration" + env_config = {} + # For hyperparameter search + tune_hp["scale_multipliers"] = tune.sample_from( + lambda spec: ( + 1 / 1, + 1 / 1, + ) + ) + if stop_on_epi_number: + stop = {"episodes_total": tune_config["num_iterations"]} + else: + stop = {"finished": True} + + return tune_config, stop, env_config + + +def _evaluate(tune_hp, debug, experiment_analysis_per_exp): + ( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + ) = _generate_eval_config(tune_hp, debug) + + _evaluate_self_and_cross_perf( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + experiment_analysis_per_exp, + ) + + +def _generate_eval_config(tune_hp, debug): + rllib_hp = copy.deepcopy(tune_hp) + rllib_hp["seed"] = 2020 + rllib_hp["num_episodes"] = 1 if debug else 100 + tune_config, stop, env_config = _get_tune_config( + rllib_hp, stop_on_epi_number=True + ) + rllib_hp["env_class"] = tune_config["env_class"] + + if "CoinGame" in tune_config["env_name"]: + env_config["batch_size"] = 1 + tune_config["TuneTrainerClass"] = train_cg_tune_class_API.LOLAPGCG + else: + tune_config["TuneTrainerClass"] = train_pg_tune_class_API.LOLAPGMatrice + tune_config["env_config"].update( + { + "batch_size": env_config["batch_size"], + "max_steps": rllib_hp["trace_length"], + } + ) + rllib_hp["scale_multipliers"] = ( + 1 / rllib_hp["trace_length"], + 1 / rllib_hp["trace_length"], + ) + + rllib_config_eval = { + "env": rllib_hp["env_class"], + "env_config": env_config, + "multiagent": { + "policies": { + env_config["players_ids"][0]: ( + # The default policy is DQN defined in DQNTrainer + # but we overwrite it to use the LE policy + policy.get_tune_policy_class(DQNTorchPolicy), + rllib_hp["env_class"](env_config).OBSERVATION_SPACE, + rllib_hp["env_class"].ACTION_SPACE, + {"tune_config": tune_config}, + ), + env_config["players_ids"][1]: ( + policy.get_tune_policy_class(DQNTorchPolicy), + rllib_hp["env_class"](env_config).OBSERVATION_SPACE, + rllib_hp["env_class"].ACTION_SPACE, + {"tune_config": tune_config}, + ), + }, + "policy_mapping_fn": lambda agent_id: agent_id, + "policies_to_train": ["None"], + }, + "seed": rllib_hp["seed"], + "min_iter_time_s": 3.0, + "callbacks": log.get_logging_callbacks_class( + log_full_epi=True, + ), + "num_envs_per_worker": 1, + 
"num_workers": 0, + } + + policies_to_load = copy.deepcopy(env_config["players_ids"]) + + if "CoinGame" in rllib_hp["env_name"]: + trainable_class = train_cg_tune_class_API.LOLAPGCG + rllib_config_eval["model"] = { + "dim": env_config["grid_size"], + # [Channel, [Kernel, Kernel], Stride]] + "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]], + } + else: + trainable_class = train_pg_tune_class_API.LOLAPGMatrice + + return ( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + ) + + +def _evaluate_self_and_cross_perf( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + experiment_analysis_per_welfare, + n_cross_play_per_checkpoint=None, +): + exp_name = os.path.join(rllib_hp["exp_name"], "eval") + evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator( + exp_name=exp_name, + local_mode=rllib_hp["debug"], + ) + analysis_metrics_per_mode = evaluator.perform_evaluation_or_load_data( + evaluation_config=rllib_config_eval, + stop_config=stop, + policies_to_load_from_checkpoint=policies_to_load, + experiment_analysis_per_welfare=experiment_analysis_per_welfare, + tune_trainer_class=trainable_class, + n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1) + if n_cross_play_per_checkpoint is None + else n_cross_play_per_checkpoint, + to_load_path=rllib_hp["load_plot_data"], + ) + + plot_config = PlotConfig( + xlim=rllib_hp["x_limits"], + ylim=rllib_hp["y_limits"], + markersize=5, + jitter=rllib_hp["jitter"], + xlabel="player 1 payoffs", + ylabel="player 2 payoffs", + plot_max_n_points=rllib_hp["train_n_replicates"], + x_scale_multiplier=rllib_hp["scale_multipliers"][0], + y_scale_multiplier=rllib_hp["scale_multipliers"][1], + ) + evaluator.plot_results( + analysis_metrics_per_mode, + plot_config=plot_config, + x_axis_metric=f"policy_reward_mean/{env_config['players_ids'][0]}", + y_axis_metric=f"policy_reward_mean/{env_config['players_ids'][1]}", + ) + + +FAILURE = "failures" +EGALITARIAN = "egalitarian" +UTILITARIAN = "utilitarian" + + +def _classify_trials_in_function_of_welfare(experiment_analysis, hp): + experiment_analysis_per_welfare = {} + for trial in experiment_analysis.trials: + welfare_name = _get_trial_welfare(trial, hp) + if welfare_name not in experiment_analysis_per_welfare.keys(): + lola_exact_official._add_empty_experiment_analysis( + experiment_analysis_per_welfare, + welfare_name, + experiment_analysis, + ) + experiment_analysis_per_welfare[welfare_name].trials.append(trial) + return experiment_analysis_per_welfare + + +def _get_trial_welfare(trial, hp): + pick_own_player_1 = trial.last_result["player_red_pick_own_color"] + pick_own_player_2 = trial.last_result["player_blue_pick_own_color"] + reward_player_1 = trial.last_result["total_reward_player_red"] + reward_player_2 = trial.last_result["total_reward_player_blue"] + welfare_name = lola_pg_classify_fn( + pick_own_player_1, + pick_own_player_2, + hp, + reward_player_1, + reward_player_2, + ) + return welfare_name + + +def lola_pg_classify_fn( + pick_own_player_1, pick_own_player_2, hp, reward_player_1, reward_player_2 +): + if reward_player_2 != 0.0 and reward_player_1 != 0.0: + if hp["env_name"] == "VectorizedSSDMixedMotiveCoinGame": + ratio = reward_player_2 / reward_player_1 + else: + ratio = max( + reward_player_1 / reward_player_2, + reward_player_2 / reward_player_1, + ) + if ratio > 1.2: + return UTILITARIAN + return EGALITARIAN + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git 
a/marltoolbox/experiments/tune_class_api/psro_simple_bargaining_hardcoded.py b/marltoolbox/experiments/tune_class_api/psro_simple_bargaining_hardcoded.py new file mode 100644 index 0000000..8aec6ef --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/psro_simple_bargaining_hardcoded.py @@ -0,0 +1,367 @@ +import copy +import logging +import os +import time + +import ray +from ray import tune +from ray.rllib.agents import pg +from ray.rllib.agents.pg import PGTorchPolicy + +from marltoolbox.algos.psro_hardcoded import PSROTrainer +from marltoolbox.envs import ( + simple_bargaining, +) +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import ( + policy, + log, + cross_play, + miscellaneous, + callbacks, +) +from marltoolbox.utils.plot import PlotConfig + +logger = logging.getLogger(__name__) + + +def main(debug: bool, env=None): + """ + Train several PSRO pairs of agent on the selected environment and + plot their performances in self-play and cross-play. + + :param debug: selection of debug mode using less compute + :param env: option to overwrite the env selection + """ + train_n_replicates = 2 if debug else 4 + timestamp = int(time.time()) + seeds = [seed + timestamp for seed in list(range(train_n_replicates))] + + exp_name, _ = log.log_in_current_day_dir("PSRO_hardcoded") + + tune_hparams = _get_hyperparameters( + debug, train_n_replicates, seeds, exp_name, env + ) + + if tune_hparams["load_plot_data"] is None: + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) + experiment_analysis_per_welfare = _train(tune_hparams) + else: + experiment_analysis_per_welfare = None + + _evaluate(tune_hparams, debug, experiment_analysis_per_welfare) + ray.shutdown() + + +def _get_hyperparameters(debug, train_n_replicates, seeds, exp_name, env): + from ray.rllib.models.catalog import MODEL_DEFAULTS + + oracle_model_config = copy.deepcopy(MODEL_DEFAULTS) + oracle_model_config.update( + { + # "fcnet_hiddens": [16, 16], + # "fcnet_activation": "relu", + "fcnet_hiddens": [], + } + ) + + oracle_config = copy.deepcopy(pg.DEFAULT_CONFIG) + oracle_config.update( + { + "gamma": 0.96, + "train_batch_size": 1, + "model": oracle_model_config, + "lr": 0.001, + } + ) + + tune_hparams = { + "debug": debug, + "exp_name": exp_name, + "train_n_replicates": train_n_replicates, + "wandb": { + "project": "LOLA_PG", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, + "classify_into_welfare_fn": False, + "seed": tune.grid_search(seeds), + "load_plot_data": None, + # Example: "load_plot_data": ".../SelfAndCrossPlay_save.p", + # "game_name": "kuhn_poker", + "training": True, + "eval_cell_over_n_epi": 10 if debug else 100, + "train_oracle_n_epi": 100 if debug else 4000, + "num_iterations": 3 if debug else 10, + "oracle_config": oracle_config, + "verbose": debug, + "center_returns": False, + "env_class": simple_bargaining.SimpleBargaining, + "env_config": { + "players_ids": ["player_0", "player_1"], + "n_steps_by_epi": 1, + }, + "plot_keys": [ + "reward", + "total_reward", + "entrop", + ], + "plot_assemblage_tags": [ + ("total_reward",), + ("entrop",), + ], + } + return tune_hparams + + +def _train(tune_hp): + tune_config, stop, env_config = _get_tune_config(tune_hp) + + # Train with the Tune Class API (not RLLib Class) + experiment_analysis = tune.run( + PSROTrainer, + name=tune_hp["exp_name"], + config=tune_config, + checkpoint_at_end=True, + checkpoint_freq=0, + stop=stop, + 
metric=tune_config["metric"], + mode="max", + log_to_file=not tune_hp["debug"], + # callbacks=None + # if tune_hp["debug"] + # else [ + # WandbLoggerCallback( + # project=tune_hp["wandb"]["project"], + # group=tune_hp["wandb"]["group"], + # api_key_file=tune_hp["wandb"]["api_key_file"], + # log_config=True, + # ) + # ], + ) + + # if tune_hp["classify_into_welfare_fn"]: + # experiment_analysis_per_welfare = ( + # _classify_trials_in_function_of_welfare( + # experiment_analysis, tune_hp + # ) + # ) + # else: + experiment_analysis_per_welfare = {"": experiment_analysis} + + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", tune_config["exp_name"]), + plot_keys=tune_config["plot_keys"], + plot_assemble_tags_in_one_plot=tune_config["plot_assemblage_tags"], + ) + + return experiment_analysis_per_welfare + + +def _get_tune_config(tune_hp: dict, stop_on_epi_number: bool = False): + tune_config = copy.deepcopy(tune_hp) + + tune_config["plot_keys"] += aggregate_and_plot_tensorboard_data.PLOT_KEYS + tune_config[ + "plot_assemblage_tags" + ] += aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + tune_hp["x_limits"] = (-0.1, 3.0) + tune_hp["y_limits"] = (-0.1, 3.0) + tune_hp["jitter"] = 0.00 + tune_config["metric"] = "training_iteration" + env_config = tune_hp["env_config"] + # For hyperparameter search + tune_hp["scale_multipliers"] = ( + 1 / 1, + 1 / 1, + ) + if stop_on_epi_number: + stop = {"episodes_total": tune_config["num_iterations"]} + else: + stop = {"finished": True} + + return tune_config, stop, env_config + + +def _evaluate(tune_hp, debug, experiment_analysis_per_exp): + ( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + ) = _generate_eval_config(tune_hp, debug) + + _evaluate_self_and_cross_perf( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + experiment_analysis_per_exp, + ) + + +def _generate_eval_config(tune_hp, debug): + rllib_hp = copy.deepcopy(tune_hp) + rllib_hp["seed"] = miscellaneous.get_random_seeds(1)[0] + rllib_hp["num_episodes"] = 5 if debug else 100 + tune_config, stop, env_config = _get_tune_config( + rllib_hp, stop_on_epi_number=True + ) + rllib_hp["env_class"] = tune_config["env_class"] + + tune_config["TuneTrainerClass"] = PSROTrainer + tune_config["training"] = False + + rllib_config_eval = { + "env": rllib_hp["env_class"], + "env_config": env_config, + "multiagent": { + "policies": { + env_config["players_ids"][0]: ( + policy.get_tune_policy_class(PGTorchPolicy), + rllib_hp["env_class"].OBSERVATION_SPACE, + rllib_hp["env_class"].ACTION_SPACE, + {"tune_config": tune_config}, + ), + env_config["players_ids"][1]: ( + policy.get_tune_policy_class(PGTorchPolicy), + rllib_hp["env_class"].OBSERVATION_SPACE, + rllib_hp["env_class"].ACTION_SPACE, + {"tune_config": tune_config}, + ), + }, + "policy_mapping_fn": lambda agent_id: agent_id, + "policies_to_train": ["None"], + }, + "seed": rllib_hp["seed"], + "min_iter_time_s": 3.0, + "callbacks": callbacks.merge_callbacks( + log.get_logging_callbacks_class( + log_full_epi=True, + ), + callbacks.PolicyCallbacks, + ), + "num_envs_per_worker": 1, + "num_workers": 0, + "framework": "torch", + } + + policies_to_load = copy.deepcopy(env_config["players_ids"]) + + trainable_class = PSROTrainer + + return ( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop, + env_config, + ) + + +def _evaluate_self_and_cross_perf( + rllib_hp, + rllib_config_eval, + policies_to_load, 
+ trainable_class, + stop, + env_config, + experiment_analysis_per_welfare, + n_cross_play_per_checkpoint=None, +): + exp_name = os.path.join(rllib_hp["exp_name"], "eval") + evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator( + exp_name=exp_name, + local_mode=rllib_hp["debug"], + ) + analysis_metrics_per_mode = evaluator.perform_evaluation_or_load_data( + evaluation_config=rllib_config_eval, + stop_config=stop, + policies_to_load_from_checkpoint=policies_to_load, + experiment_analysis_per_welfare=experiment_analysis_per_welfare, + tune_trainer_class=trainable_class, + n_cross_play_per_checkpoint=min(5, rllib_hp["train_n_replicates"] - 1) + if n_cross_play_per_checkpoint is None + else n_cross_play_per_checkpoint, + to_load_path=rllib_hp["load_plot_data"], + ) + + plot_config = PlotConfig( + xlim=rllib_hp["x_limits"], + ylim=rllib_hp["y_limits"], + markersize=5, + jitter=rllib_hp["jitter"], + xlabel="player 1 payoffs", + ylabel="player 2 payoffs", + plot_max_n_points=rllib_hp["train_n_replicates"], + x_scale_multiplier=rllib_hp["scale_multipliers"][0], + y_scale_multiplier=rllib_hp["scale_multipliers"][1], + ) + evaluator.plot_results( + analysis_metrics_per_mode, + plot_config=plot_config, + x_axis_metric=f"policy_reward_mean/{env_config['players_ids'][0]}", + y_axis_metric=f"policy_reward_mean/{env_config['players_ids'][1]}", + ) + + +# FAILURE = "failures" +# EGALITARIAN = "egalitarian" +# UTILITARIAN = "utilitarian" + + +# def _classify_trials_in_function_of_welfare(experiment_analysis, hp): +# experiment_analysis_per_welfare = {} +# for trial in experiment_analysis.trials: +# welfare_name = _get_trial_welfare(trial, hp) +# if welfare_name not in experiment_analysis_per_welfare.keys(): +# lola_exact_official._add_empty_experiment_analysis( +# experiment_analysis_per_welfare, +# welfare_name, +# experiment_analysis, +# ) +# experiment_analysis_per_welfare[welfare_name].trials.append(trial) +# return experiment_analysis_per_welfare +# +# +# def _get_trial_welfare(trial, hp): +# pick_own_player_1 = trial.last_result["player_red_pick_own_color"] +# pick_own_player_2 = trial.last_result["player_blue_pick_own_color"] +# reward_player_1 = trial.last_result["total_reward_player_red"] +# reward_player_2 = trial.last_result["total_reward_player_blue"] +# welfare_name = lola_pg_classify_fn( +# pick_own_player_1, +# pick_own_player_2, +# hp, +# reward_player_1, +# reward_player_2, +# ) +# return welfare_name +# +# +# def lola_pg_classify_fn( +# pick_own_player_1, pick_own_player_2, hp, reward_player_1, reward_player_2 +# ): +# if reward_player_2 != 0.0 and reward_player_1 != 0.0: +# if hp["env_name"] == "VectorizedSSDMixedMotiveCoinGame": +# ratio = reward_player_2 / reward_player_1 +# else: +# ratio = max( +# reward_player_1 / reward_player_2, +# reward_player_2 / reward_player_1, +# ) +# if ratio > 1.2: +# return UTILITARIAN +# return EGALITARIAN + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/simple_bargaining.py b/marltoolbox/experiments/tune_class_api/simple_bargaining.py new file mode 100644 index 0000000..983d3d7 --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/simple_bargaining.py @@ -0,0 +1,853 @@ +import numpy as np + +# import nashpy as nash +import torch +import matplotlib.pyplot as plt +from functools import partial + +# from prd import prd +from collections import Counter +import pdb +import argparse +import copy + +# from feasible_set_figure import optimize_welfare_discrete +# from 
meta_solver_exploration import sweep_and_plot_alpha +from open_spiel.python.algorithms import projected_replicator_dynamics +from open_spiel.python.egt import alpharank +from open_spiel.python.algorithms.psro_v2 import utils +import matplotlib.pyplot as plt +from marltoolbox.utils import miscellaneous + +np.set_printoptions(suppress=True) + +IPD_PAYOUT = torch.Tensor([[2, -3], [3, -1]]) +ASYM_IPD_PAYOUT_1 = torch.Tensor([[2, -3], [20, -1]]) +ASYM_IPD_PAYOUT_2 = torch.Tensor([[2, 3], [-2.5, -1]]) + +DISAGREEMENT = 0.0 +MISCOORDINATION = 0.8 +# CHICKEN_PAYOUT_1 = np.array([[1., -2.], [2, -10]]) +# CHICKEN_PAYOUT_2 = np.array([[1., 2.], [-2, -10]]) + +BOTS_PAYOUT_1 = torch.Tensor([[4.0, MISCOORDINATION], [MISCOORDINATION, 2.0]]) +BOTS_PAYOUT_2 = torch.Tensor([[1.5, MISCOORDINATION], [MISCOORDINATION, 2.0]]) +META_PAYOUT_1 = torch.Tensor( + [ + [ + 4.0, + MISCOORDINATION, + MISCOORDINATION, + 4.0, + 4.0, + MISCOORDINATION, + 4.0, + ], + [ + MISCOORDINATION, + 2.0, + MISCOORDINATION, + 2.0, + MISCOORDINATION, + 2.0, + 2.0, + ], + [ + MISCOORDINATION, + MISCOORDINATION, + 3.0, + MISCOORDINATION, + 3.0, + 3.0, + 3.0, + ], + [4.0, 2.0, MISCOORDINATION, 3.0, 4.0, 2.0, 3.0], + [4.0, MISCOORDINATION, 3.0, 4.0, 3.5, 3.0, 3.5], + [MISCOORDINATION, 2.0, 3.0, 2.0, 3.0, 2.5, 2.5], + [4.0, 2.0, 3.0, 3.0, 3.5, 2.5, 3.0], + ] +) +META_PAYOUT_2 = torch.Tensor( + [ + [1.0, MISCOORDINATION, MISCOORDINATION, 1.0, 1, MISCOORDINATION, 1.0], + [ + MISCOORDINATION, + 2.0, + MISCOORDINATION, + 2.0, + MISCOORDINATION, + 2.0, + 2.0, + ], + [ + MISCOORDINATION, + MISCOORDINATION, + 1.5, + MISCOORDINATION, + 1.5, + 1.5, + 1.5, + ], + [1.0, 2.0, MISCOORDINATION, 1.5, 1.0, 2.0, 1.5], + [1.0, MISCOORDINATION, 1.5, 1.0, 1.25, 1.5, 1.75], + [MISCOORDINATION, 2.0, 1.5, 2.0, 1.5, 1.75, 1.75], + [1.0, 2.0, 1.5, 1.5, 1.25, 1.75, 1.5], + ] +) + +# SMALL_META_PAYOUT_1 = np.array([[4., 0., 4.], [0., 2., 2.], [4., 2., 3.]]) +# SMALL_META_PAYOUT_2 = np.array([[1., 0., 1.], [0., 2., 2.], [1., 2., 1.5]]) +MULTIPLIER = 0.2 +P11, P12, P21, P22 = np.array([3, 9, 7, 2]) * MULTIPLIER +TEMPERATURE = 20 +G = 3 + + +def bots(gamma=0.96): + dims = [5, 5] + payout_mat_1 = BOTS_PAYOUT_1 + payout_mat_2 = BOTS_PAYOUT_2 + + def Ls(th): + p_1_0 = torch.sigmoid(th[0][0:1]) + p_2_0 = torch.sigmoid(th[1][0:1]) + p = torch.cat( + [ + p_1_0 * p_2_0, + p_1_0 * (1 - p_2_0), + (1 - p_1_0) * p_2_0, + (1 - p_1_0) * (1 - p_2_0), + ] + ) + p_1 = torch.reshape(torch.sigmoid(th[0][1:5]), (4, 1)) + p_2 = torch.reshape(torch.sigmoid(th[1][1:5]), (4, 1)) + P = torch.cat( + [ + p_1 * p_2, + p_1 * (1 - p_2), + (1 - p_1) * p_2, + (1 - p_1) * (1 - p_2), + ], + dim=1, + ) + M = -torch.matmul(p, torch.inverse(torch.eye(4) - gamma * P)) + L_1 = torch.matmul(M, torch.reshape(payout_mat_1, (4, 1))) + L_2 = torch.matmul(M, torch.reshape(payout_mat_2, (4, 1))) + return [L_1, L_2] + + return dims, Ls + + +def ipd(gamma=0.96, asymmetric=False): + dims = [5, 5] + if asymmetric: + payout_mat_1 = ASYM_IPD_PAYOUT_1 + payout_mat_2 = ASYM_IPD_PAYOUT_2 + else: + payout_mat_1 = IPD_PAYOUT + payout_mat_2 = payout_mat_1.T + + def Ls(th): + p_1_0 = torch.sigmoid(th[0][0:1]) + p_2_0 = torch.sigmoid(th[1][0:1]) + p = torch.cat( + [ + p_1_0 * p_2_0, + p_1_0 * (1 - p_2_0), + (1 - p_1_0) * p_2_0, + (1 - p_1_0) * (1 - p_2_0), + ] + ) + p_1 = torch.reshape(torch.sigmoid(th[0][1:5]), (4, 1)) + p_2 = torch.reshape(torch.sigmoid(th[1][1:5]), (4, 1)) + P = torch.cat( + [ + p_1 * p_2, + p_1 * (1 - p_2), + (1 - p_1) * p_2, + (1 - p_1) * (1 - p_2), + ], + dim=1, + ) + M = -torch.matmul(p, 
torch.inverse(torch.eye(4) - gamma * P)) + L_1 = torch.matmul(M, torch.reshape(payout_mat_1, (4, 1))) + L_2 = torch.matmul(M, torch.reshape(payout_mat_2, (4, 1))) + return [L_1, L_2] + + return dims, Ls + + +def meta(): + dims = [7, 7] + payout_mat_1 = META_PAYOUT_1 + payout_mat_2 = META_PAYOUT_2 + + fair1 = torch.Tensor([1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0]) / 6 + fair2 = torch.Tensor([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) / 6 + + def Ls(th): + p1 = torch.nn.functional.softmax(th[0]) + p2 = torch.nn.functional.softmax(th[1]) + L_1 = -torch.dot(p2, torch.matmul(payout_mat_1, p1)) + L_2 = -torch.dot(p1, torch.matmul(payout_mat_2, p2)) + return [L_1, L_2] + + def Ls2(th): + p1 = torch.nn.functional.softmax(th[0]) + p2 = torch.nn.functional.softmax(th[1]) + L_1 = -torch.dot(fair1, torch.matmul(payout_mat_1, p1)) + L_2 = -torch.dot(fair2, torch.matmul(payout_mat_2, p2)) + return [L_1, L_2] + + return dims, Ls, Ls2 + + +# def tandem(): +# dims = [1, 1] +# def Ls(th): +# x, y = th +# L_1 = (x+y)**2-2*x +# L_2 = (x+y)**2-2*y +# return [L_1, L_2] +# return dims, Ls + + +# def get_policies_as_matrix(th, nA): +# # Parameters for initial actions +# p_1_0 = torch.nn.functional.softmax(th[0][0:nA]) +# p_2_0 = torch.nn.functional.softmax(th[1][0:nA]) +# p = torch.flatten(torch.ger(p_1_0, p_2_0)) +# +# # Parameters for actions conditional on previous action profile +# p_1_lst = [] +# p_2_lst = [] +# joint_p_lst = [] +# for i in range(1, nA**2+1): # Loop through action profile indices +# p_1_i = torch.nn.functional.softmax(th[0][(nA*i):(nA*i+nA)]) +# p_2_i = torch.nn.functional.softmax(th[1][(nA*i):(nA*i+nA)]) +# p_1_lst.append(p_1_i) +# p_2_lst.append(p_2_i) +# +# # joint_p gives probability of each action profile conditional on previous profile (torch.ger is outer product) +# joint_p = torch.ger(p_1_i, p_2_i).reshape(nA**2, 1) +# joint_p_lst.append(joint_p.T) +# +# p_1 = torch.cat(p_1_lst) +# p_2 = torch.cat(p_2_lst) +# P = torch.cat(joint_p_lst, dim=0) +# return p, p_1, p_2, P + + +# def bipd_loss(th, u1, u2, gamma): +# """ +# Get closed form value functions for each player given policy profile parameter +# (negated since these algorithms act on losses). +# V_i = u_i + \gamma * PV_i +# => V_i = (I - \gamma P)^-1 u_i; +# L_i = -p.V_i, where p is distribution of initial action profiles +# :param th: Tensor of parameters for each policy +# :param u1: Player 1 payoff matrix +# :param u2: Player 2 payoff matrix +# """ +# p, p_1, p_2, P = get_bipd_policies_as_matrix(th) +# M = -torch.matmul(p, torch.inverse(torch.eye(9) - gamma * P)) +# L_1 = torch.matmul(M, torch.reshape(u1, (9, 1))) +# L_2 = torch.matmul(M, torch.reshape(u2, (9, 1))) +# return [L_1, L_2] +# +# +# def bipd(payout_mat_1, payout_mat_2, gamma=0.96): +# dims = [30, 30] +# # ToDo: fiddling with payoffs, may not be original bipd payoffs +# +# Ls = partial(bipd_loss, u1=payout_mat_1, u2=payout_mat_2, gamma=gamma) +# +# return dims, Ls + + +def continuous_bargaining(temperature=None): + """ + Utility functions of the form + 1[ \theta_1 > \tau_2 ] * 1[ \theta_2 < \tau_1] * [ theta_i^P_i1 + (1 - theta_i)*P_i2 ] + Policy params are allocations \theta_i and acceptance thresholds \theta_i. 
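+    (In the code the allocations are named allocation_i and the thresholds
+    cutoff_i, i.e. the \tau_i above.) Agreement requires
+    allocation_1 >= cutoff_2 and allocation_2 <= cutoff_1; both indicators
+    are smoothed with sigmoid(temperature * (.)) so the losses stay
+    differentiable. Example with illustrative values and temperature=20:
+    allocation_1=0.7, cutoff_2=0.4, allocation_2=0.3, cutoff_1=0.5 gives
+    sigmoid(20*0.3) * sigmoid(20*0.2) ~ 0.98, i.e. (almost) an agreement.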
+ """ + + dims = [2, 2] + if temperature is None: + temperature = TEMPERATURE + + def Ls(th_, soft_cutoff=True): + allocation_1 = torch.sigmoid(th_[0][0]) + allocation_2 = torch.sigmoid(th_[1][0]) + cutoff_1 = torch.sigmoid(th_[0][1]) + cutoff_2 = torch.sigmoid(th_[1][1]) + + soft_indicator_1 = torch.sigmoid( + temperature * (allocation_1 - cutoff_2) + ) + soft_indicator_2 = torch.sigmoid( + temperature * (cutoff_1 - allocation_2) + ) + soft_indicator = soft_indicator_1 * soft_indicator_2 + + if not soft_cutoff: + if (allocation_1 - cutoff_2) < 0 or (cutoff_1 - allocation_2) < 0: + return [torch.tensor(0.0), torch.tensor(0.0)] + else: + soft_indicator = 1.0 + + l11 = torch.log(torch.pow(allocation_1 + allocation_2, P11) + 1) + l12 = torch.log( + torch.pow(1 - allocation_1 + G * (1 - allocation_2), P12) + 1 + ) + l21 = torch.log(torch.pow(G * allocation_1 + allocation_2, P21) + 1) + l22 = torch.log(torch.pow(2 - (allocation_1 + allocation_2), P22) + 1) + + agreement_payoff_1 = l11 + l12 + agreement_payoff_2 = l21 + l22 + + L1 = -( + soft_indicator * agreement_payoff_1 + + (1 - soft_indicator) * DISAGREEMENT + ) + L2 = -( + soft_indicator * agreement_payoff_2 + + (1 - soft_indicator) * DISAGREEMENT + ) + + return [L1, L2] + + return dims, Ls + + +# @markdown Gradient computations for each algorithm. +def init_th(dims, std): + th = [] + for i in range(len(dims)): + if std > 0: + init = torch.nn.init.normal_( + torch.empty(dims[i], requires_grad=True), std=std + ) + else: + init = torch.zeros(dims[i], requires_grad=True) + th.append(init) + return th + + +def get_gradient(function, param): + grad = torch.autograd.grad(function, param, create_graph=True)[0] + return grad + + +def get_hessian(th, grad_L, diag=True, off_diag=True): + n = len(th) + H = [] + for i in range(n): + row_block = [] + for j in range(n): + if (i == j and diag) or (i != j and off_diag): + block = [ + torch.unsqueeze( + get_gradient(grad_L[i][i][k], th[j]), dim=0 + ) + for k in range(len(th[i])) + ] + row_block.append(torch.cat(block, dim=0)) + else: + row_block.append(torch.zeros(len(th[i]), len(th[j]))) + H.append(torch.cat(row_block, dim=1)) + return torch.cat(H, dim=0) + + +def update_th( + th, + Ls, + alpha, + algo, + Ls2=None, + a=0.5, + b=0.1, + gam=1, + ep=0.1, + lss_lam=0.1, + tradeoff=None, +): + n = len(th) + losses = Ls(th) + + # Compute gradients + grad_L = [ + [get_gradient(losses[j], th[i]) for j in range(n)] for i in range(n) + ] + if Ls2 is not None: + losses2 = Ls2(th) + grad_L2 = [get_gradient(losses2[i], th[i]) for i in range(n)] + if algo == "la": + terms = [ + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) + for i in range(n) + ] + elif algo == "lola": + terms = [ + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j]) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + grads = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) + for i in range(n) + ] + elif algo == "sos": + terms = [ + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + xi_0 = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) + for i in range(n) + ] + chi = [ + get_gradient( + sum( + [ + torch.dot(grad_L[j][i].detach(), grad_L[j][j]) + for j in range(n) + if j != i + ] + ), + th[i], + ) + for i in range(n) + ] + # Compute p + dot = torch.dot(-alpha * torch.cat(chi), torch.cat(xi_0)) + p1 = ( + 1 + if 
dot >= 0 + else min(1, -a * torch.norm(torch.cat(xi_0)) ** 2 / dot) + ) + xi = torch.cat([grad_L[i][i] for i in range(n)]) + xi_norm = torch.norm(xi) + p2 = xi_norm ** 2 if xi_norm < b else 1 + p = min(p1, p2) + grads = [xi_0[i] - p * alpha * chi[i] for i in range(n)] + elif algo == "sos_tradeoff": + terms = [ + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + if j != i + ] + ) + for i in range(n) + ] + xi_0 = [ + grad_L[i][i] - alpha * get_gradient(terms[i], th[i]) + for i in range(n) + ] + chi = [ + get_gradient( + sum( + [ + torch.dot(grad_L[j][i].detach(), grad_L[j][j]) + for j in range(n) + if j != i + ] + ), + th[i], + ) + for i in range(n) + ] + # Compute p + dot = torch.dot(-alpha * torch.cat(chi), torch.cat(xi_0)) + p1 = ( + 1 + if dot >= 0 + else min(1, -a * torch.norm(torch.cat(xi_0)) ** 2 / dot) + ) + xi = torch.cat([grad_L[i][i] for i in range(n)]) + xi_norm = torch.norm(xi) + p2 = xi_norm ** 2 if xi_norm < b else 1 + p = min(p1, p2) + grads = [ + tradeoff * (xi_0[i] - p * alpha * chi[i]) + + (1 - tradeoff) * grad_L2[i] + for i in range(n) + ] + elif algo == "sga": + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + H_t_xi = [get_gradient(ham, th[i]) for i in range(n)] + H_xi = [ + get_gradient( + sum( + [ + torch.dot(grad_L[j][i], grad_L[j][j].detach()) + for j in range(n) + ] + ), + th[i], + ) + for i in range(n) + ] + A_t_xi = [H_t_xi[i] / 2 - H_xi[i] / 2 for i in range(n)] + # Compute lambda (sga with alignment) + dot_xi = torch.dot(xi, torch.cat(H_t_xi)) + dot_A = torch.dot(torch.cat(A_t_xi), torch.cat(H_t_xi)) + d = sum([len(th[i]) for i in range(n)]) + lam = torch.sign(dot_xi * dot_A / d + ep) + grads = [grad_L[i][i] + lam * A_t_xi[i] for i in range(n)] + elif algo == "co": + xi = torch.cat([grad_L[i][i] for i in range(n)]) + ham = torch.dot(xi, xi.detach()) + grads = [ + grad_L[i][i] + gam * get_gradient(ham, th[i]) for i in range(n) + ] + elif algo == "eg": + th_eg = [ + th[i] - alpha * get_gradient(losses[i], th[i]) for i in range(n) + ] + losses_eg = Ls(th_eg) + grads = [get_gradient(losses_eg[i], th_eg[i]) for i in range(n)] + elif algo == "cgd": # Slow implementation (matrix inversion) + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H_o = get_hessian(th, grad_L, diag=False) + grad = torch.matmul( + torch.inverse(torch.eye(sum(dims)) + alpha * H_o), xi + ) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + elif algo == "lss": # Slow implementation (matrix inversion) + dims = [len(th[i]) for i in range(n)] + xi = torch.cat([grad_L[i][i] for i in range(n)]) + H = get_hessian(th, grad_L) + if torch.det(H) == 0: + inv = torch.inverse( + torch.matmul(H.T, H) + lss_lam * torch.eye(sum(dims)) + ) + H_inv = torch.matmul(inv, H.T) + else: + H_inv = torch.inverse(H) + grad = ( + torch.matmul(torch.eye(sum(dims)) + torch.matmul(H.T, H_inv), xi) + / 2 + ) + grads = [grad[sum(dims[:i]) : sum(dims[: i + 1])] for i in range(n)] + else: # Naive Learning + grads = [grad_L[i][i] for i in range(n)] + + # Update theta + with torch.no_grad(): + for i in range(n): + th[i] -= alpha * grads[i] + return th, losses + + +def continuous_bargaining_meta_penalty(th_): + penalty = 1 * torch.norm(th_[0][1:] - th_[1][1:]) + return penalty + + +def learn_bots_pd( + num_runs=20, + method="lola", + game="bots_pd", + save=False, + tradeoff=None, + U1=None, + U2=None, + std=1.0, + min_joint_perf=-1e9, + temperature=TEMPERATURE, + jitter=0.0, + num_epochs=200, + 
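+    # alpha below is the gradient step size; update_th() also reuses it as
+    # the lookahead scale inside the LOLA/SOS correction terms.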
alpha=1.0, + verbose=False, + lr_decay=False, + hard_cutoff_in_eval=False, +): + gamma = 0.96 + Ls2 = None + + if game == "pd": + payout_mat_1 = IPD_PAYOUT + payout_mat_2 = payout_mat_1.T + dims, Ls = ipd(gamma) + + elif game == "bos": + payout_mat_1 = BOTS_PAYOUT_1 + payout_mat_2 = BOTS_PAYOUT_2 + dims, Ls = bots(gamma) + + elif game == "meta": + dims, Ls, Ls2 = meta() + + elif game == "continuous": + + dims, Ls = continuous_bargaining(temperature) + + # elif game == 'continuous_unexploitable': + # dims, Ls = continuous_bargaining_unexploitable() + # + # elif game == 'continuous_meta': + # dims, Ls = continuous_bargaining_meta() + # + # elif game == 'matrix': + # U1_tensor = torch.tensor(U1).float() + # U2_tensor = torch.tensor(U2).float() + # dims, Ls = matrix(U1_tensor, U2_tensor) + # + # elif game == 'aipd': + # dims, Ls = ipd(gamma, asymmetric=True) + + # alpha = 1 + # std = 1 + + # List of policy profiles thetai generated by several independent runs by player i + theta1_dbn = [] + theta2_dbn = [] + + losses_out = np.zeros((num_runs, num_epochs)) + + def _get_rewards(losses, th): + if game in ["meta", "continuous", "continuous_unexploitable"]: + r_pl1 = -losses[0].detach().numpy() + r_pl2 = -losses[1].detach().numpy() + elif game in ["continuous_meta"]: + pen = continuous_bargaining_meta_penalty(th).detach().numpy() + losses_ = Ls(th) + r_pl1 = -losses_[0].detach().numpy() + pen + r_pl2 = -losses_[1].detach().numpy() + pen + else: + r_pl1 = -losses[0].detach().numpy() * (1 - gamma) + r_pl2 = -losses[1].detach().numpy() * (1 - gamma) + return r_pl1, r_pl2 + + # Lists of (negative) losses for self play + final_losses_1 = [] + final_losses_2 = [] + + # List of (negative) losses for cross play + cross_play_losses_1 = [] + cross_play_losses_2 = [] + th_prev = None + for i in range(num_runs): + print("run n°", i) + r_pl1, r_pl2 = min_joint_perf, min_joint_perf + trial = 0 + while r_pl1 <= min_joint_perf and r_pl2 <= min_joint_perf: + trial += 1 + print("trial", trial, "rewards", r_pl1, r_pl2) + + seed = miscellaneous.get_random_seeds(1)[0] + torch.manual_seed(seed) + np.random.seed(seed) + + th = init_th(dims, std) + for k in range(num_epochs): + if lr_decay: + th, losses = update_th( + th, + Ls, + alpha * (num_epochs - k) / num_epochs, + method, + Ls2=Ls2, + tradeoff=tradeoff, + ) + else: + th, losses = update_th( + th, Ls, alpha, method, Ls2=Ls2, tradeoff=tradeoff + ) + # Negating so that higher => better! 
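+                # For the iterated matrix games the loss is a discounted
+                # return, so multiplying by (1 - gamma) rescales it to an
+                # average per-step reward; for the one-shot games this is
+                # just a constant rescaling of the logged curve.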
+ losses_out[i, k] = -(1 - gamma) * losses[0].data.numpy() + + self_losses = Ls(th, soft_cutoff=hard_cutoff_in_eval) + r_pl1, r_pl2 = _get_rewards(self_losses, th) + + final_losses_1.append(r_pl1) + final_losses_2.append(r_pl2) + theta1_dbn.append(th[0]) + theta2_dbn.append(th[1]) + print("th", th) + + # Evaluate player 1 policy learned in previous run against player 2 policy learned in this run + if i > 0: + th_cross = [th_prev[0], th[1]] + cross_losses = Ls(th_cross, soft_cutoff=hard_cutoff_in_eval) + r_pl1_cross, r_pl2_cross = _get_rewards(cross_losses, th_cross) + cross_play_losses_1.append(r_pl1_cross) + cross_play_losses_2.append(r_pl2_cross) + + th_cross = [th[0], th_prev[1]] + cross_losses = Ls(th_cross, soft_cutoff=hard_cutoff_in_eval) + r_pl1_cross, r_pl2_cross = _get_rewards(cross_losses, th_cross) + cross_play_losses_1.append(r_pl1_cross) + cross_play_losses_2.append(r_pl2_cross) + + th_prev = th + + if verbose: + print("losses_out[i, :]", losses_out[i, :]) + + def _format_values(values): + if game not in [ + "meta", + "continuous", + "continuous_unexploitable", + "continuous_meta", + ]: + values = [v[0] for v in values] + values = np.array(values) + return values + + final_losses_1 = _format_values(final_losses_1) + final_losses_2 = _format_values(final_losses_2) + cross_play_losses_1 = _format_values(cross_play_losses_1) + cross_play_losses_2 = _format_values(cross_play_losses_2) + + print( + "self-play r_pl_1 mean std", + final_losses_1.mean(), + "std", + final_losses_1.std(), + ) + print( + "self-play r_pl_2 mean std", + final_losses_2.mean(), + "std", + final_losses_2.std(), + ) + print( + "cross-play r_pl_1 mean std", + cross_play_losses_1.mean(), + "std", + cross_play_losses_1.std(), + ) + print( + "cross-play r_pl_2 mean", + cross_play_losses_2.mean(), + "std", + cross_play_losses_2.std(), + ) + if save: + np.save("cross_play_losses_1.npy", np.array(cross_play_losses_1)) + np.save("cross_play_losses_2.npy", np.array(cross_play_losses_2)) + torch.save(torch.stack(theta1_dbn), "theta1_dbn.pt") + torch.save(torch.stack(theta2_dbn), "theta2_dbn.pt") + + def _add_jitter(values): + values += np.random.normal(0.0, jitter, values.shape) + return values + + final_losses_1 = _add_jitter(final_losses_1) + final_losses_2 = _add_jitter(final_losses_2) + cross_play_losses_1 = _add_jitter(cross_play_losses_1) + cross_play_losses_2 = _add_jitter(cross_play_losses_2) + + # plt.scatter(x=cross_play_losses_1, y=cross_play_losses_2) + # plt.scatter(x=final_losses_1, y=final_losses_2) + plt.plot( + cross_play_losses_1, + cross_play_losses_2, + markerfacecolor="none", + markeredgecolor="#1f77b4", + marker="o", + linestyle="None", + ) + plt.plot( + final_losses_1, + final_losses_2, + markerfacecolor="none", + markeredgecolor="#ff7f0e", + marker="o", + linestyle="None", + ) + plt.legend(["cross-play", "self-play"]) + title = ( + f"env({game}) " + f"algo({method}) " + f"hard_cutoff({hard_cutoff_in_eval}) " + f"fail({min_joint_perf}) " + f"T({temperature}) " + f"num_runs({num_runs})" + ) + plt.suptitle(title) + plt.xlim((-0.1, 3.0)) + plt.ylim((-0.1, 3.0)) + title = title.replace(" ", "_") + plt.savefig(title + ".png") + plt.show() + + return ( + final_losses_1, + final_losses_2, + cross_play_losses_1, + cross_play_losses_2, + theta1_dbn, + theta2_dbn, + ) + + +def main(debug): + + scaling_training = 1 + learn_bots_pd( + num_runs=40, + # + # method="ni", + method="lola", + # method="sos", + # + # game="pd", + # + # game="bos", + # std=3.0, + # + game="continuous", + # min_joint_perf=0.2, + # 
min_joint_perf=2.0, + min_joint_perf=2.25, + # + # temperature=1, + jitter=0.02, + # + save=False, + tradeoff=None, + U1=None, + U2=None, + # + # num_epochs=200 * scaling_training, + alpha=1.0 / scaling_training, + # verbose=True, + num_epochs=200 * scaling_training, + # 0.0581 -> 0.0597 + # lr_decay=True, # not helping + hard_cutoff_in_eval=True, + ) + + +if __name__ == "__main__": + debug_mode = True + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/sos_exact_official.py b/marltoolbox/experiments/tune_class_api/sos_exact_official.py new file mode 100644 index 0000000..5ef0534 --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/sos_exact_official.py @@ -0,0 +1,269 @@ +########## +# Additional dependencies are needed: +# Follow the LOLA installation described in the +# tune_class_api/lola_pg_official.py file +########## + +import copy +import os + +import ray +from ray import tune +from ray.rllib.agents.pg import PGTorchPolicy + +from marltoolbox.algos.sos import SOSTrainer +from marltoolbox.envs.matrix_sequential_social_dilemma import ( + IteratedPrisonersDilemma, + IteratedMatchingPennies, + IteratedAsymBoS, +) +from marltoolbox.experiments.tune_class_api import lola_exact_meta_game +from marltoolbox.experiments.tune_class_api import lola_pg_official +from marltoolbox.scripts import aggregate_and_plot_tensorboard_data +from marltoolbox.utils import policy, log, miscellaneous + + +def main(debug): + hparams = get_hyperparameters(debug) + + if hparams["load_plot_data"] is None: + ray.init( + num_cpus=os.cpu_count(), + num_gpus=0, + local_mode=debug, + ) + experiment_analysis_per_welfare = train(hparams) + else: + experiment_analysis_per_welfare = None + + evaluate(experiment_analysis_per_welfare, hparams) + ray.shutdown() + + +def get_hyperparameters(debug, train_n_replicates=None, env=None): + """Get hyperparameters for LOLA-Exact for matrix games""" + + if train_n_replicates is None: + train_n_replicates = 2 if debug else int(3 * 2) + seeds = miscellaneous.get_random_seeds(train_n_replicates) + + exp_name, _ = log.log_in_current_day_dir("SOS") + + hparams = { + "debug": debug, + "load_plot_data": None, + "exp_name": exp_name, + "classify_into_welfare_fn": True, + "train_n_replicates": train_n_replicates, + "wandb": { + "project": "SOS", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, + "env_name": "IteratedAsymBoS" if env is None else env, + "lr": 1.0 / 10, + "gamma": 0.96, + "num_epochs": 5 if debug else 100, + # "method": "lola", + "method": "sos", + "inital_weights_std": 1.0, + "seed": tune.grid_search(seeds), + "metric": "mean_reward_player_row", + "plot_keys": aggregate_and_plot_tensorboard_data.PLOT_KEYS + + ["mean_reward"], + "plot_assemblage_tags": aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS + + [("mean_reward",)], + "x_limits": (-0.1, 4.1), + "y_limits": (-0.1, 4.1), + "max_steps_in_eval": 100, + } + + return hparams + + +def train(hp): + tune_config, stop_config, _ = get_tune_config(hp) + # Train with the Tune Class API (not an RLLib Trainer) + experiment_analysis = tune.run( + SOSTrainer, + name=hp["exp_name"], + config=tune_config, + checkpoint_at_end=True, + stop=stop_config, + metric=hp["metric"], + mode="max", + ) + if hp["classify_into_welfare_fn"]: + experiment_analysis_per_welfare = _split_tune_results_wt_welfare( + experiment_analysis + ) + else: + experiment_analysis_per_welfare = {"": experiment_analysis} + + 
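+    # Aggregate the TensorBoard logs of every trial into summary plots
+    # stored under the experiment directory in ~/ray_results/.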
aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", tune_config["exp_name"]), + plot_keys=tune_config["plot_keys"], + plot_assemble_tags_in_one_plot=tune_config["plot_assemblage_tags"], + ) + + return experiment_analysis_per_welfare + + +def get_tune_config(hp: dict): + tune_config = copy.deepcopy(hp) + assert tune_config["env_name"] in ("IPD", "IteratedAsymBoS") + env_config = { + "players_ids": ["player_row", "player_col"], + "max_steps": hp["max_steps_in_eval"], + "get_additional_info": True, + } + tune_config["plot_axis_scale_multipliers"] = ( + ( + 1 / hp["max_steps_in_eval"], + 1 / hp["max_steps_in_eval"], + ), + ) + if "num_episodes" in tune_config: + stop_config = {"episodes_total": tune_config["num_episodes"]} + else: + stop_config = {"episodes_total": tune_config["num_epochs"]} + + return tune_config, stop_config, env_config + + +def evaluate(experiment_analysis_per_welfare, hp): + ( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop_config, + env_config, + ) = generate_eval_config(hp) + + lola_pg_official._evaluate_self_and_cross_perf( + rllib_hp, + rllib_config_eval, + policies_to_load, + trainable_class, + stop_config, + env_config, + experiment_analysis_per_welfare, + n_cross_play_per_checkpoint=min(15, hp["train_n_replicates"] - 1) + if hp["classify_into_welfare_fn"] + else None, + ) + + +def generate_eval_config(hp): + hp_eval = copy.deepcopy(hp) + + hp_eval["min_iter_time_s"] = 3.0 + hp_eval["seed"] = miscellaneous.get_random_seeds(1)[0] + hp_eval["batch_size"] = 1 + hp_eval["num_episodes"] = 100 + + tune_config, stop_config, env_config = get_tune_config(hp_eval) + tune_config["TuneTrainerClass"] = SOSTrainer + + hp_eval["group_names"] = ["lola"] + hp_eval["scale_multipliers"] = ( + 1 / hp_eval["max_steps_in_eval"], + 1 / hp_eval["max_steps_in_eval"], + ) + hp_eval["jitter"] = 0.05 + + if hp_eval["env_name"] == "IPD": + hp_eval["env_class"] = IteratedPrisonersDilemma + hp_eval["x_limits"] = (-3.5, 0.5) + hp_eval["y_limits"] = (-3.5, 0.5) + elif hp_eval["env_name"] == "IMP": + hp_eval["env_class"] = IteratedMatchingPennies + hp_eval["x_limits"] = (-1.0, 1.0) + hp_eval["y_limits"] = (-1.0, 1.0) + elif hp_eval["env_name"] == "IteratedAsymBoS": + hp_eval["env_class"] = IteratedAsymBoS + hp_eval["x_limits"] = (-0.1, 4.1) + hp_eval["y_limits"] = (-0.1, 4.1) + else: + raise NotImplementedError() + + rllib_config_eval = { + "env": hp_eval["env_class"], + "env_config": env_config, + "multiagent": { + "policies": { + env_config["players_ids"][0]: ( + policy.get_tune_policy_class(PGTorchPolicy), + hp_eval["env_class"](env_config).OBSERVATION_SPACE, + hp_eval["env_class"].ACTION_SPACE, + {"tune_config": copy.deepcopy(tune_config)}, + ), + env_config["players_ids"][1]: ( + policy.get_tune_policy_class(PGTorchPolicy), + hp_eval["env_class"](env_config).OBSERVATION_SPACE, + hp_eval["env_class"].ACTION_SPACE, + {"tune_config": copy.deepcopy(tune_config)}, + ), + }, + "policy_mapping_fn": lambda agent_id: agent_id, + "policies_to_train": ["None"], + }, + "seed": hp_eval["seed"], + "min_iter_time_s": hp_eval["min_iter_time_s"], + "num_workers": 0, + "num_envs_per_worker": 1, + } + + policies_to_load = copy.deepcopy(env_config["players_ids"]) + trainable_class = SOSTrainer + + return ( + hp_eval, + rllib_config_eval, + policies_to_load, + trainable_class, + stop_config, + env_config, + ) + + +def _split_tune_results_wt_welfare( + experiment_analysis, +): + experiment_analysis_per_welfare = {} + for trial in 
experiment_analysis.trials: + welfare_name = _get_trial_welfare(trial) + if welfare_name not in experiment_analysis_per_welfare.keys(): + _add_empty_experiment_analysis( + experiment_analysis_per_welfare, + welfare_name, + experiment_analysis, + ) + experiment_analysis_per_welfare[welfare_name].trials.append(trial) + return experiment_analysis_per_welfare + + +def _get_trial_welfare(trial): + reward_player_1 = trial.last_result["mean_reward_player_row"] + reward_player_2 = trial.last_result["mean_reward_player_col"] + welfare_name = lola_exact_meta_game.classify_into_welfare_based_on_rewards( + reward_player_1, reward_player_2 + ) + return welfare_name + + +def _add_empty_experiment_analysis( + experiment_analysis_per_welfare, welfare_name, tune_analysis +): + experiment_analysis_per_welfare[welfare_name] = copy.deepcopy( + tune_analysis + ) + experiment_analysis_per_welfare[welfare_name].trials = [] + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py b/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py new file mode 100644 index 0000000..c5a6c2b --- /dev/null +++ b/marltoolbox/experiments/tune_class_api/various_algo_meta_game.py @@ -0,0 +1,1779 @@ +import copy +import os +import pickle +import random + +import numpy as np +import pandas as pd +import ray +import torch +from ray import tune +from ray.rllib.agents import dqn +from ray.rllib.agents.pg import PGTrainer +from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy, pg_loss_stats + +from marltoolbox.algos import population, welfare_coordination +from marltoolbox.algos.lola.train_exact_tune_class_API import LOLAExactTrainer +from marltoolbox.algos.sos import SOSTrainer +from marltoolbox.algos.stochastic_population import StochasticPopulation +from marltoolbox.algos.welfare_coordination import MetaGameSolver +from marltoolbox.envs.matrix_sequential_social_dilemma import ( + TwoPlayersCustomizableMatrixGame, +) +from marltoolbox.examples.rllib_api import pg_ipd +from marltoolbox.experiments.rllib_api import amtft_meta_game +from marltoolbox.experiments.rllib_api import amtft_various_env +from marltoolbox.experiments.tune_class_api import ( + lola_exact_meta_game, + lola_exact_official, +) +from marltoolbox.experiments.tune_class_api import sos_exact_official +from marltoolbox.scripts import ( + aggregate_and_plot_tensorboard_data, + plot_meta_policies, +) +from marltoolbox.utils import log, miscellaneous, callbacks, exp_analysis, path + +EPSILON = 1e-6 +POLICY_ID_PL0 = "player_row" +POLICY_ID_PL1 = "player_col" + + +def main(debug, base_game_algo=None, meta_game_algo=None): + """Evaluate meta game performances""" + + train_n_replicates = 1 + seeds = miscellaneous.get_random_seeds(train_n_replicates) + exp_name, _ = log.log_in_current_day_dir("meta_game_compare") + + ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) + hparams = _get_hyperparameters( + debug, seeds, exp_name, base_game_algo, meta_game_algo + ) + ( + hparams["payoff_matrices"], + hparams["actions_possible"], + hparams["base_ckpt_per_replicat"], + _, + ) = _form_n_matrices_from_base_game_payoffs(hparams) + hparams["meta_game_policy_distributions"] = _train_meta_policies(hparams) + tune_analysis, hp_eval = _evaluate_in_base_game(hparams) + ray.shutdown() + + _extract_metric_and_log_and_plot( + tune_analysis, + hparams, + hp_eval, + title=f"BASE({base_game_algo}) META({meta_game_algo})", + ) + + +def _extract_metric_and_log_and_plot( + 
tune_analysis, hparams, hp_eval, title=None +): + ( + mean_player_1_payoffs, + mean_player_2_payoffs, + ) = _extract_metrics(tune_analysis, hparams) + results = [] + for player1_avg_r_one_replicate, player2_avg_r_one_replicate in zip( + mean_player_1_payoffs, mean_player_2_payoffs + ): + results.append( + (player1_avg_r_one_replicate, player2_avg_r_one_replicate) + ) + result_to_json = {"results": copy.deepcopy(results)} + coordination_success = _extract_coordination_metric(tune_analysis) + print("coordination_success", coordination_success) + result_to_json["coordination_success"] = np.array( + coordination_success + ).tolist() + result_to_json["mean_coordination_success"] = np.array( + coordination_success + ).mean() + + amtft_meta_game.save_to_json( + exp_name=hparams["exp_name"], object=result_to_json + ) + amtft_meta_game.plot_results( + exp_name=hparams["exp_name"], + results=results, + hp_eval=hp_eval, + format_fn=_format_result_for_plotting, + jitter=0.05, + title=title, + ) + + +BASE_AMTFT = "amTFT" +BASE_LOLA_EXACT = "base LOLA-Exact" +BASE_NEGOTIATION = "base negociation" +META_LOLA_EXACT = "meta LOLA-Exact" +META_PG = "PG" +META_SOS = "SOS" +META_APLHA_RANK = "alpha-rank" +META_APLHA_PURE = "alpha-rank pure strategy" +META_REPLICATOR_DYNAMIC = "replicator dynamic" +META_REPLICATOR_DYNAMIC_ZERO_INIT = "replicator dynamic with zero init" +META_RANDOM = "Random" +META_UNIFORM = "Robustness tau=0.0" +META_MINIMUM = "Minimum" + + +def _get_hyperparameters( + debug, seeds, exp_name, base_game_algo=None, meta_game_algo=None +): + hp = { + # "base_game_policies": BASE_AMTFT, + # "base_game_policies": BASE_LOLA_EXACT, + "base_game_policies": BASE_NEGOTIATION, + # + # "meta_game_policies": META_PG, + # "meta_game_policies": META_LOLA_EXACT, + # "meta_game_policies": META_APLHA_RANK, + # "meta_game_policies": META_APLHA_PURE, + # "meta_game_policies": META_REPLICATOR_DYNAMIC, + # "meta_game_policies": META_REPLICATOR_DYNAMIC_ZERO_INIT, + "meta_game_policies": META_RANDOM, + # "meta_game_policies": META_UNIFORM, + # + "apply_announcement_protocol": True, + "negotitation_process": 2, + # + "players_ids": ["player_row", "player_col"], + "use_r2d2": True, + } + + if base_game_algo is not None: + hp["base_game_policies"] = base_game_algo + if meta_game_algo is not None: + hp["meta_game_policies"] = meta_game_algo + + hp["load_meta_game_payoff_matrices"] = ( + hp["base_game_policies"] == BASE_NEGOTIATION + ) + hp["evaluate_meta_policies_reading_meta_game_payoff_matrices"] = ( + hp["base_game_policies"] == BASE_NEGOTIATION + ) + + if hp["load_meta_game_payoff_matrices"]: + assert hp["base_game_policies"] == BASE_NEGOTIATION + if hp["evaluate_meta_policies_reading_meta_game_payoff_matrices"]: + assert hp["base_game_policies"] == BASE_NEGOTIATION + + if hp["base_game_policies"] == BASE_AMTFT: + hp.update( + amtft_meta_game.get_hyperparameters( + debug=debug, use_r2d2=hp["use_r2d2"] + ) + ) + elif hp["base_game_policies"] == BASE_LOLA_EXACT: + hp.update(lola_exact_meta_game.get_hyperparameters(debug=debug)) + elif hp["base_game_policies"] == BASE_NEGOTIATION: + hp.update(_get_negociation_hyperparameters(debug=debug)) + assert hp["evaluate_meta_policies_reading_meta_game_payoff_matrices"] + else: + raise ValueError() + + hp.update( + { + "debug": debug, + "seeds": seeds, + "exp_name": exp_name, + "wandb": { + "project": "meta_game_compare", + "group": exp_name, + "api_key_file": os.path.join( + os.path.dirname(__file__), "../../../api_key_wandb" + ), + }, + } + ) + + players_ids = ["player_row", 
"player_col"] + if hp["base_game_policies"] == BASE_NEGOTIATION: + hp["x_axis_metric"] = f"policy_reward_mean.{players_ids[0]}" + hp["y_axis_metric"] = f"policy_reward_mean.{players_ids[1]}" + else: + hp["x_axis_metric"] = f"policy_reward_mean/{players_ids[0]}" + hp["y_axis_metric"] = f"policy_reward_mean/{players_ids[1]}" + + return hp + + +payoffs_per_groups = None + + +def _get_negociation_hyperparameters(debug): + hp = { + "n_replicates_over_full_exp": 2 if debug else 40, + "n_self_play_in_final_meta_game": 0, + "n_cross_play_in_final_meta_game": 1 if debug else 10, + "env_name": "Nogetiation", + "x_limits": (0.0, 1.0), + "y_limits": (0.0, 1.0), + "plot_axis_scale_multipliers": (1, 1), + "plot_keys": aggregate_and_plot_tensorboard_data.PLOT_KEYS, + "plot_assemblage_tags": aggregate_and_plot_tensorboard_data.PLOT_ASSEMBLAGE_TAGS, + "data_prefix": ( + "/home/maxime/ssd_space/CLR/marltoolbox/marltoolbox/experiments" + "/tune_class_api/" + ), + # "data_prefix": ( + # "/home/maxime-riche/marltoolbox/marltoolbox/experiments" + # "/tune_class_api/" + # ), + } + return hp + + +def _form_n_matrices_from_base_game_payoffs(hp): + global payoffs_per_groups + if hp["load_meta_game_payoff_matrices"]: + ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + payoffs_per_groups, + ) = _load_payoffs_matrices(hp) + return ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + ) + payoffs_per_groups = _get_payoffs_for_every_group_of_base_game_replicates( + hp + ) + # In eval + if hp["base_game_policies"] == BASE_AMTFT: + if hp["use_r2d2"]: + # I removed the change to 100 step in eval when I moved to R2D2 + n_steps_per_epi = 20 + else: + n_steps_per_epi = 100 + elif hp["base_game_policies"] == BASE_LOLA_EXACT: + if hp["debug"]: + n_steps_per_epi = 40 + else: + n_steps_per_epi = 1 # in fact 200 but the payoffs are already avg + else: + raise ValueError() + + if hp["apply_announcement_protocol"]: + ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + ) = _aggregate_payoffs_groups_into_matrices_wt_announcement_protocol( + payoffs_per_groups, n_steps_per_epi, hp + ) + else: + ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + ) = _aggregate_payoffs_groups_into_matrices( + payoffs_per_groups, n_steps_per_epi, hp + ) + + assert len(payoffs_matrices) == hp["n_replicates_over_full_exp"] + return ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + payoffs_per_groups, + ) + + +def _get_payoffs_for_every_group_of_base_game_replicates(hp): + if hp["base_game_policies"] == BASE_AMTFT: + module = amtft_meta_game + elif hp["base_game_policies"] == BASE_LOLA_EXACT: + module = lola_exact_meta_game + else: + raise ValueError(f'base_game_policies {hp["base_game_policies"]}') + + payoffs_per_groups = [] + for i in range(hp["n_replicates_over_full_exp"]): + hp_replicate_i = module._load_base_game_results( + copy.deepcopy(hp), load_base_replicate_i=i + ) + + all_welfare_pairs_wt_payoffs = ( + module._get_all_welfare_pairs_wt_cross_play_payoffs( + hp_replicate_i, hp_replicate_i["players_ids"] + ) + ) + payoffs_per_groups.append( + (all_welfare_pairs_wt_payoffs, hp_replicate_i) + ) + return payoffs_per_groups + + +def _aggregate_payoffs_groups_into_matrices_wt_announcement_protocol( + payoffs_per_groups, n_steps_per_epi, hp +): + payoffs_matrices = [] + ckpt_per_replicat = [] + previous_welfare_fn_sets = None + for i, (payoffs_for_one_group, hp_replicat_i) in enumerate( + payoffs_per_groups + ): + 
ckpt_per_replicat.append(hp_replicat_i["ckpt_per_welfare"]) + + announcement_protocol_solver_p1 = welfare_coordination.MetaGameSolver() + announcement_protocol_solver_p1.setup_meta_game( + payoffs_per_groups[i][0], + own_player_idx=0, + opp_player_idx=1, + own_default_welfare_fn="utilitarian", + opp_default_welfare_fn="inequity aversion" + if hp["base_game_policies"] == BASE_AMTFT + else "egalitarian", + ) + + welfare_fn_sets = announcement_protocol_solver_p1.welfare_fn_sets + if previous_welfare_fn_sets is not None: + assert welfare_fn_sets == previous_welfare_fn_sets + previous_welfare_fn_sets = welfare_fn_sets + print("\nwelfare_fn_sets", welfare_fn_sets) + n_set_of_welfare_sets = len(welfare_fn_sets) + payoff_matrix = np.empty( + shape=(n_set_of_welfare_sets, n_set_of_welfare_sets, 2), + dtype=np.float, + ) + for own_welfare_set_idx, own_welfare_set_announced in enumerate( + welfare_fn_sets + ): + for opp_welfare_set_idx, opp_wefare_set in enumerate( + welfare_fn_sets + ): + cell_payoffs = ( + announcement_protocol_solver_p1._compute_meta_payoff( + own_welfare_set_announced, opp_wefare_set + ) + ) + payoff_matrix[own_welfare_set_idx, opp_welfare_set_idx, 0] = ( + cell_payoffs[0] / n_steps_per_epi + ) + payoff_matrix[own_welfare_set_idx, opp_welfare_set_idx, 1] = ( + cell_payoffs[1] / n_steps_per_epi + ) + amtft_meta_game.save_to_json( + exp_name=hp["exp_name"], + object={ + "welfare_fn_sets": str(welfare_fn_sets), + "payoff_matrix": payoff_matrix.tolist(), + }, + filename=f"payoffs_matrices_{i}.json", + ) + payoffs_matrices.append(payoff_matrix) + return payoffs_matrices, welfare_fn_sets, ckpt_per_replicat + + +def _aggregate_payoffs_groups_into_matrices( + payoffs_per_groups, n_steps_per_epi, hp +): + payoff_matrices = [] + ckpt_per_replicat = [] + all_welfares_fn = None + for i, (payoffs_for_one_group, hp_replicat_i) in enumerate( + payoffs_per_groups + ): + ( + one_payoff_matrice, + tmp_all_welfares_fn, + ) = _aggregate_payoffs_in_one_matrix( + payoffs_for_one_group, n_steps_per_epi + ) + amtft_meta_game.save_to_json( + exp_name=hp["exp_name"], + object=one_payoff_matrice.tolist(), + filename=f"payoffs_matrices_{i}.json", + ) + payoff_matrices.append(one_payoff_matrice) + ckpt_per_replicat.append(hp_replicat_i["ckpt_per_welfare"]) + if all_welfares_fn is None: + all_welfares_fn = tmp_all_welfares_fn + assert len(all_welfares_fn) == len( + tmp_all_welfares_fn + ), f"{len(all_welfares_fn)} == {len(tmp_all_welfares_fn)}" + return payoff_matrices, all_welfares_fn, ckpt_per_replicat + + +def _aggregate_payoffs_in_one_matrix(payoffs_for_one_group, n_steps_per_epi): + all_welfares_fn = MetaGameSolver.list_all_welfares_fn( + payoffs_for_one_group + ) + all_welfares_fn = sorted(tuple(all_welfares_fn)) + n_welfare_fn = len(all_welfares_fn) + payoff_matrix = np.empty( + shape=(n_welfare_fn, n_welfare_fn, 2), dtype=np.float + ) + for row_i, welfare_player_1 in enumerate(all_welfares_fn): + for col_i, welfare_player_2 in enumerate(all_welfares_fn): + welfare_pair_name = ( + MetaGameSolver.from_pair_of_welfare_names_to_key( + welfare_player_1, welfare_player_2 + ) + ) + payoff_matrix[row_i, col_i, 0] = ( + payoffs_for_one_group[welfare_pair_name][0] / n_steps_per_epi + ) + payoff_matrix[row_i, col_i, 1] = ( + payoffs_for_one_group[welfare_pair_name][1] / n_steps_per_epi + ) + return payoff_matrix, all_welfares_fn + + +def _load_payoffs_matrices(hp): + if hp["base_game_policies"] == BASE_NEGOTIATION: + if hp["negotitation_process"] == 1: + return _load_payoffs_matrices_negotiation_process1(hp) + 
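+        # Process 2 (below) reads bootstrapped payoff matrices produced by
+        # two different principals, instead of the single principal's
+        # replicate list used by process 1.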
elif hp["negotitation_process"] == 2: + return _load_payoffs_matrices_negotiation_process2(hp) + else: + raise NotImplementedError() + + +def _load_payoffs_matrices_negotiation_process1(hp): + file_path = os.path.join( + hp["data_prefix"], "negociation_game_replicates_list.pickle" + ) + with open(file_path, "rb") as f: + content = pickle.load(f) + + principal0_replicates = content[:20] + hp["principal0_replicates"] = principal0_replicates + payoffs_matrices = [] + n_actions = None + for i, (mat_pl0, mat_pl1, _, _) in enumerate(principal0_replicates): + if n_actions is None: + n_actions = mat_pl0.shape[0] + assert n_actions == mat_pl0.shape[0] == mat_pl0.shape[1] + assert n_actions == mat_pl1.shape[0] == mat_pl1.shape[1] + payoffs_matrices.append(np.stack([mat_pl0, mat_pl1], axis=-1)) + print("payoffs_matrices[0].shape", payoffs_matrices[0].shape) + + actions_possible = [] + for i in range(n_actions): + actions_possible.append(f"meta_action_{str(i)}") + + base_ckpt_per_replicat = None + payoffs_per_groups = None + + return ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + payoffs_per_groups, + ) + + +def _load_payoffs_matrices_negotiation_process2(hp): + file_path = os.path.join( + hp["data_prefix"], "bootstrapped_replicates_prosociality_coeff_0.3" + ) + with open(file_path, "rb") as f: + content = pickle.load(f) + load_from_n_principals = 2 + load_n_bootstrapped_by_principals = ( + hp["n_replicates_over_full_exp"] // load_from_n_principals + ) + print("load_from_n_principals", load_from_n_principals) + print( + "load_n_bootstrapped_by_principals", load_n_bootstrapped_by_principals + ) + all_principal_replicates = [] + meta_game_matrix_pos_to_principal_idx = [] + for principal_i in range(load_from_n_principals): + principal_i_replicates = content[principal_i][ + :load_n_bootstrapped_by_principals + ] + hp[f"principal{principal_i}_replicates"] = principal_i_replicates + all_principal_replicates.extend(principal_i_replicates) + meta_game_matrix_pos_to_principal_idx.extend( + [principal_i] * len(principal_i_replicates) + ) + hp[ + "meta_game_matrix_pos_to_principal_idx" + ] = meta_game_matrix_pos_to_principal_idx + hp["load_n_bootstrapped_by_principals"] = load_n_bootstrapped_by_principals + payoffs_matrices = [] + n_actions = None + for i, (mat_pl0, _, _) in enumerate(all_principal_replicates): + if n_actions is None: + n_actions = mat_pl0.shape[0] + assert n_actions == mat_pl0.shape[0] == mat_pl0.shape[1] + payoffs_matrices.append(mat_pl0) + # print("payoffs_matrices[0].shape", payoffs_matrices[0].shape) + + actions_possible = [] + for i in range(n_actions): + actions_possible.append(f"meta_action_{str(i)}") + + base_ckpt_per_replicat = None + payoffs_per_groups = None + + return ( + payoffs_matrices, + actions_possible, + base_ckpt_per_replicat, + payoffs_per_groups, + ) + + +def _train_meta_policies(hp): + hp["exp_name"] = os.path.join(hp["exp_name"], "meta_game") + + if hp["meta_game_policies"] == META_PG: + meta_policies = _train_meta_policy_using_pg(hp) + elif hp["meta_game_policies"] == META_LOLA_EXACT: + meta_policies = _train_meta_policy_using_lola_exact(hp) + elif hp["meta_game_policies"] == META_APLHA_RANK: + meta_policies = _train_meta_policy_using_alpha_rank(hp) + elif hp["meta_game_policies"] == META_APLHA_PURE: + meta_policies = _train_meta_policy_using_alpha_rank( + hp, pure_strategy=True + ) + elif ( + hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC + or hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC_ZERO_INIT + ): + meta_policies = 
_train_meta_policy_using_replicator_dynamic(hp) + elif hp["meta_game_policies"] == META_RANDOM: + meta_policies = _get_random_meta_policy(hp) + elif hp["meta_game_policies"] == META_SOS: + meta_policies = _train_meta_policy_using_sos_exact(hp) + elif hp["meta_game_policies"] == META_UNIFORM: + meta_policies = _train_meta_policy_using_robustness(hp) + elif hp["meta_game_policies"] == META_MINIMUM: + meta_policies = _train_meta_policy_using_minimum(hp) + else: + raise ValueError(hp["meta_game_policies"]) + + if hp["base_game_policies"] == BASE_NEGOTIATION: + if hp["meta_game_policies"] != META_RANDOM: + meta_policies = ( + _convert_negociation_meta_policies_to_original_space( + meta_policies, hp + ) + ) + + clamped_meta_policies = _clamp_policies_normalize(meta_policies) + + amtft_meta_game.save_to_json( + exp_name=hp["exp_name"], + object={ + "clamped_meta_policies": _convert_to_list(clamped_meta_policies), + "meta_policies": _convert_to_list(meta_policies), + }, + filename=f"meta_policies.json", + ) + + exp_dir_path = path.get_exp_dir_from_exp_name(hp["exp_name"]) + plot_meta_policies.plot_policies( + _convert_to_list(meta_policies), + hp["actions_possible"], + title=f'META({hp["meta_game_policies"]}) BASE(' + f'{hp["base_game_policies"]})', + path_prefix=exp_dir_path + "/", + announcement_protocol=hp["apply_announcement_protocol"] + and hp["base_game_policies"] != BASE_NEGOTIATION, + ) + + return clamped_meta_policies + + +def _convert_negociation_meta_policies_to_original_space(meta_policies, hp): + meta_policies_in_original_space = [] + for meta_pi_idx, meta_policy in enumerate(meta_policies): + policy_player_1 = _order_meta_game_policy( + meta_policy["player_row"], hp, meta_pi_idx, POLICY_ID_PL0 + ) + policy_player_2 = _order_meta_game_policy( + meta_policy["player_col"], hp, meta_pi_idx, POLICY_ID_PL1 + ) + policy_player_1 = torch.tensor(policy_player_1) + policy_player_2 = torch.tensor(policy_player_2) + meta_policy_in_original_space = { + "player_row": policy_player_1, + "player_col": policy_player_2, + } + meta_policies_in_original_space.append(meta_policy_in_original_space) + return meta_policies_in_original_space + + +def _order_meta_game_policy(meta_pi, hp, meta_policy_idx, player_id): + """Assemble policies to fit the order in the original pmeta game payoff + matrix (not the bootstrapepd one)""" + if hp["negotitation_process"] == 1: + return _order_meta_game_policy_process1( + meta_pi, hp, meta_policy_idx, player_id + ) + elif hp["negotitation_process"] == 2: + return _order_meta_game_policy_process2( + meta_pi, hp, meta_policy_idx, player_id + ) + + +def _order_meta_game_policy_process1(meta_pi, hp, meta_policy_idx, player_id): + _, x_indices, y_indices = hp["principal0_replicates"][meta_policy_idx] + if player_id == "player_row": + indices = x_indices + elif player_id == "player_col": + indices = y_indices + else: + raise ValueError() + meta_pi_original_space = np.zeros_like(meta_pi) + for i, val in enumerate(meta_pi): + original_index = indices[i] + meta_pi_original_space[original_index] += val + + return meta_pi_original_space + + +def _order_meta_game_policy_process2(meta_pi, hp, meta_policy_idx, player_id): + principal_idx = hp["meta_game_matrix_pos_to_principal_idx"][ + meta_policy_idx + ] + idx_in_principal = ( + meta_policy_idx % hp["load_n_bootstrapped_by_principals"] + ) + _, x_indices, y_indices = hp[f"principal{principal_idx}_replicates"][ + idx_in_principal + ] + print("principal_idx", principal_idx) + print("idx_in_principal", idx_in_principal) + if player_id == 
"player_row": + indices = x_indices + elif player_id == "player_col": + indices = y_indices + else: + raise ValueError() + meta_pi_original_space = np.zeros_like(meta_pi) + for i, val in enumerate(meta_pi): + original_index = indices[i] + meta_pi_original_space[original_index] += val + + return meta_pi_original_space + + +def _convert_to_list(list_dict_tensors): + return [ + {k: v.tolist() for k, v in dict_.items()} + for dict_ in list_dict_tensors + ] + + +def _clamp_policies_normalize(meta_policies): + for i in range(len(meta_policies)): + for player_key, player_meta_pi in meta_policies[i].items(): + assert not ( + any(player_meta_pi > 1.01) or any(player_meta_pi < -0.01) + ), f"player_meta_pi {player_meta_pi}" + player_meta_pi = player_meta_pi / player_meta_pi.sum() + meta_policies[i][player_key] = player_meta_pi.clamp( + min=0.0, max=1.0 + ) + return meta_policies + + +def _train_meta_policy_using_pg(hp): + rllib_config, stop_config = pg_ipd.get_rllib_config( + hp["seeds"], hp["debug"] + ) + rllib_config, stop_config = _modify_rllib_config_for_meta_pg_policy( + rllib_config, stop_config, hp + ) + + tune_analysis = _train_with_tune(rllib_config, stop_config, hp, PGTrainer) + + return _extract_policy_pg(tune_analysis) + + +def _extract_policy_pg(tune_analysis): + policies = [] + for trial in tune_analysis.trials: + next_act_distrib_idx = 0 + p1_act_distrib = [] + p2_act_distrib = [] + p1_info = trial.last_result["info"]["learner"]["player_row"] + p2_info = trial.last_result["info"]["learner"]["player_col"] + prefix = "act_dist_inputs_single_act" + while True: + p1_act_distrib.append(p1_info[f"{prefix}{next_act_distrib_idx}"]) + p2_act_distrib.append(p2_info[f"{prefix}{next_act_distrib_idx}"]) + next_act_distrib_idx += 1 + if f"{prefix}{next_act_distrib_idx}" not in p1_info.keys(): + break + policy_player_1 = torch.softmax(torch.tensor(p1_act_distrib), dim=0) + policy_player_2 = torch.softmax(torch.tensor(p2_act_distrib), dim=0) + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("PG meta policy extracted") + print("policy_player_1 ", policy_player_1) + print("policy_player_2 ", policy_player_2) + return policies + + +def _modify_rllib_config_for_meta_pg_policy(rllib_config, stop_config, hp): + rllib_config["env"] = TwoPlayersCustomizableMatrixGame + rllib_config["env_config"]["NUM_ACTIONS"] = len(hp["actions_possible"]) + rllib_config["env_config"]["max_steps"] = 1 + rllib_config["model"] = { + # Number of hidden layers for fully connected net + "fcnet_hiddens": [64], + # Nonlinearity for fully connected net (tanh, relu) + "fcnet_activation": "relu", + } + rllib_config["lr"] = 0.003 + stop_config["episodes_total"] = 10 if hp["debug"] else 8000 + + rllib_config["env_config"]["linked_data"] = _get_payoff_matrix_grid_search( + hp + ) + rllib_config["seed"] = tune.sample_from( + lambda spec: spec.config["env_config"]["linked_data"][0] + ) + rllib_config["env_config"]["PAYOFF_MATRIX"] = tune.sample_from( + lambda spec: spec.config["env_config"]["linked_data"][1] + ) + + rllib_config = _dynamicaly_change_policies_spaces(hp, rllib_config) + + return rllib_config, stop_config + + +def _dynamicaly_change_policies_spaces(hp, rllib_config): + MyPGTorchPolicy = PGTorchPolicy.with_updates( + stats_fn=log.augment_stats_fn_wt_additionnal_logs( + stats_function=pg_loss_stats + ) + ) + + tmp_env_config = copy.deepcopy(rllib_config["env_config"]) + tmp_env_config["PAYOFF_MATRIX"] = hp["payoff_matrices"][0] + tmp_env = rllib_config["env"](tmp_env_config) + for 
policy_id, policy_config in rllib_config["multiagent"][ + "policies" + ].items(): + policy_config = list(policy_config) + policy_config[0] = MyPGTorchPolicy + policy_config[1] = tmp_env.OBSERVATION_SPACE + policy_config[2] = tmp_env.ACTION_SPACE + rllib_config["multiagent"]["policies"][policy_id] = tuple( + policy_config + ) + return rllib_config + + +def _train_meta_policy_using_lola_exact(hp): + lola_exact_hp = lola_exact_meta_game.get_hyperparameters(hp["debug"]) + + tune_config, stop_config, _ = lola_exact_official.get_tune_config( + lola_exact_hp + ) + + tune_config, stop_config = _modify_tune_config_for_meta_lola_exact( + hp, tune_config, stop_config + ) + + tune_analysis = _train_with_tune( + tune_config, stop_config, hp, LOLAExactTrainer + ) + return _extract_policy_lola_exact(tune_analysis) + + +def _extract_policy_lola_exact(tune_analysis): + policies = [] + for trial in tune_analysis.trials: + policy_player_1 = trial.last_result["policy1"][-1, :] + policy_player_2 = trial.last_result["policy2"][-1, :] + policy_player_1 = torch.tensor(policy_player_1) + policy_player_2 = torch.tensor(policy_player_2) + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("LOLA-Exact meta policy extracted") + print("policy_player_1 ", policy_player_1) + print("policy_player_2 ", policy_player_2) + return policies + + +def _train_meta_policy_using_sos_exact(hp): + lola_exact_hp = sos_exact_official.get_hyperparameters(hp["debug"]) + + tune_config, stop_config, _ = sos_exact_official.get_tune_config( + lola_exact_hp + ) + + tune_config, stop_config = _modify_tune_config_for_meta_sos_exact( + hp, tune_config, stop_config + ) + + tune_analysis = _train_with_tune(tune_config, stop_config, hp, SOSTrainer) + return _extract_policy_sos_exact(tune_analysis) + + +def _modify_tune_config_for_meta_sos_exact(hp, tune_config, stop_config): + tune_config["env_name"] = None + tune_config["method"] = "sos" + tune_config["linked_data"] = _get_payoff_matrix_grid_search(hp) + tune_config["seed"] = tune.sample_from( + lambda spec: spec.config["linked_data"][0] + ) + tune_config["custom_payoff_matrix"] = tune.sample_from( + lambda spec: spec.config["linked_data"][1] + ) + + return tune_config, stop_config + + +def _extract_policy_sos_exact(tune_analysis): + policies = [] + for trial in tune_analysis.trials: + policy_player_1 = trial.last_result["policy1"][0, :] + policy_player_2 = trial.last_result["policy2"][0, :] + policy_player_1 = torch.tensor(policy_player_1) + policy_player_2 = torch.tensor(policy_player_2) + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("SOS-Exact meta policy extracted") + print("policy_player_1 ", policy_player_1) + print("policy_player_2 ", policy_player_2) + return policies + + +def _train_with_tune( + rllib_config, + stop_config, + hp, + trainer, + plot_aggregates=True, +): + tune_analysis = tune.run( + trainer, + config=rllib_config, + stop=stop_config, + name=hp["exp_name"], + ) + + if not hp["debug"] and plot_aggregates: + aggregate_and_plot_tensorboard_data.add_summary_plots( + main_path=os.path.join("~/ray_results/", hp["exp_name"]), + plot_keys=hp["plot_keys"], + plot_assemble_tags_in_one_plot=hp["plot_assemblage_tags"], + ) + return tune_analysis + + +def _modify_tune_config_for_meta_lola_exact(hp, tune_config, stop_config): + stop_config["episodes_total"] *= tune_config["trace_length"] + tune_config["re_init_every_n_epi"] *= tune_config["trace_length"] + tune_config["trace_length"] = 1 + 
tune_config["env_name"] = "custom_payoff_matrix" + tune_config["linked_data"] = _get_payoff_matrix_grid_search(hp) + tune_config["seed"] = tune.sample_from( + lambda spec: spec.config["linked_data"][0] + ) + tune_config["custom_payoff_matrix"] = tune.sample_from( + lambda spec: spec.config["linked_data"][1] + ) + + return tune_config, stop_config + + +def _get_payoff_matrix_grid_search(hp): + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + seeds = miscellaneous.get_random_seeds(len(payoff_matrices)) + linked_data = [ + (seed, matrix) for seed, matrix in zip(seeds, payoff_matrices) + ] + return tune.grid_search(linked_data) + + +def _evaluate_in_base_game(hp): + hp["exp_name"] = os.path.join(hp["exp_name"], "final_base_game") + assert hp["n_replicates_over_full_exp"] > 0 + + if hp["evaluate_meta_policies_reading_meta_game_payoff_matrices"]: + if hp["base_game_policies"] == BASE_NEGOTIATION: + if hp["negotitation_process"] == 1: + return _evaluate_by_reading_meta_payoff_matrices_process1(hp) + elif hp["negotitation_process"] == 2: + return _evaluate_by_reading_meta_payoff_matrices_process2(hp) + else: + raise NotImplementedError() + else: + return _evaluate_by_playing_in_base_game(hp) + + +ORIGNAL_NEGOTIATION_PAYOFFS = np.stack( + [ + np.array( + [ + [ + 0.73076171, + 0.72901064, + 0.73216635, + 0.66135818, + 0.56392801, + 0.46451038, + 0.45376301, + ], + [ + 0.38869923, + 0.38534233, + 0.37598014, + 0.73307645, + 0.72625536, + 0.73916543, + 0.72769892, + ], + [ + 0.7302742, + 0.73917222, + 0.72813082, + 0.66200578, + 0.66214889, + 0.56320196, + 0.56070906, + ], + [ + 0.46352547, + 0.45852849, + 0.45295322, + 0.39153525, + 0.38613006, + 0.38181722, + 0.37838927, + ], + [ + 0.64827693, + 0.65012288, + 0.6471073, + 0.6198647, + 0.57871825, + 0.51521456, + 0.51886076, + ], + [ + 0.47413245, + 0.59999102, + 0.59605861, + 0.59867841, + 0.59904635, + 0.57479459, + 0.54586047, + ], + [ + 0.50140649, + 0.47044769, + 0.46821448, + 0.57101083, + 0.56979507, + 0.57035708, + 0.57534903, + ], + ] + ), + np.array( + [ + [ + 0.30082142, + 0.29890773, + 0.3002491, + 0.30846024, + 0.30522716, + 0.29359573, + 0.28792399, + ], + [ + 0.27443337, + 0.27473691, + 0.27099788, + 0.29374766, + 0.29737911, + 0.29684028, + 0.30397776, + ], + [ + 0.30072445, + 0.29653731, + 0.3035537, + 0.3089233, + 0.31072533, + 0.30742073, + 0.30733863, + ], + [ + 0.28808233, + 0.29138255, + 0.29195258, + 0.27400267, + 0.27413243, + 0.26937005, + 0.26969171, + ], + [ + 0.38675746, + 0.38473225, + 0.39369282, + 0.40118337, + 0.4028841, + 0.39932922, + 0.40290347, + ], + [ + 0.3954556, + 0.39829046, + 0.39688903, + 0.39803576, + 0.39812419, + 0.41143295, + 0.41318333, + ], + [ + 0.41149494, + 0.41419694, + 0.41467997, + 0.4051294, + 0.39407712, + 0.40370235, + 0.40058476, + ], + ] + ), + ], + axis=-1, +) + + +def _evaluate_by_reading_meta_payoff_matrices_process1(hp): + all_meta_games_idx = list(range(hp["n_replicates_over_full_exp"])) + trials_results = [] + for _ in all_meta_games_idx: + # TODO don't use this ppayoff table + payoff_matrix = copy.deepcopy(ORIGNAL_NEGOTIATION_PAYOFFS) + meta_games_idx_available = copy.deepcopy(all_meta_games_idx) + + for cross_play_idx in range(hp["n_cross_play_in_final_meta_game"]): + ( + meta_pi_pl0, + meta_pi_pl1, + meta_pi_pl0_idx, + meta_pi_pl1_idx, + ) = _select_cross_play_meta_policies_for_payoff_reading( + hp, meta_games_idx_available + ) + + min_, max_ = payoff_matrix.min(), payoff_matrix.max() + joint_proba = _compute_joint_meta_pi_proba( + meta_pi_pl0, meta_pi_pl1 + ) + weighted_avg 
= payoff_matrix * joint_proba + payoff_per_player = np.sum(np.sum(weighted_avg, axis=0), axis=0) + assert np.all( + payoff_per_player - max_ < EPSILON + ), f"{payoff_per_player - max_}" + assert np.all( + payoff_per_player - min_ > -EPSILON + ), f"{payoff_per_player - min_}" + trials_results.append( + { + "policy_reward_mean.player_row": payoff_per_player[0], + "policy_reward_mean.player_col": payoff_per_player[1], + } + ) + + fake_experiment_analysis = ( + exp_analysis.create_fake_experiment_analysis_wt_metrics_only( + trials_results + ) + ) + return fake_experiment_analysis, hp + + +def _evaluate_by_reading_meta_payoff_matrices_process2(hp): + file_path = os.path.join( + hp["data_prefix"], "empirical_game_matrices_prosociality_coeff_0.3" + ) + with open(file_path, "rb") as f: + payoff_matrix_35x35 = pickle.load(f) + + all_meta_games_idx = list(range(hp["n_replicates_over_full_exp"])) + trials_results = [] + for meta_games_idx in all_meta_games_idx: + print("meta_games_idx", meta_games_idx) + ( + meta_pi_pl0, + meta_games_idx_available_pl1, + pl0_principal_i, + ) = _load_pi_pl_0_and_available_pi_pl1( + hp, meta_games_idx, all_meta_games_idx + ) + + for cross_play_idx in range(hp["n_cross_play_in_final_meta_game"]): + payoff_matrix, meta_pi_pl1 = _load_pi_pl1_and_payoff_matrix( + hp, + meta_games_idx_available_pl1, + pl0_principal_i, + payoff_matrix_35x35, + ) + + min_, max_ = payoff_matrix.min(), payoff_matrix.max() + joint_proba = _compute_joint_meta_pi_proba( + meta_pi_pl0, meta_pi_pl1 + ) + weighted_avg = payoff_matrix * joint_proba + payoff_per_player = np.sum(np.sum(weighted_avg, axis=0), axis=0) + assert np.all( + payoff_per_player - max_ < EPSILON + ), f"{payoff_per_player - max_}" + assert np.all( + payoff_per_player - min_ > -EPSILON + ), f"{payoff_per_player - min_}" + trials_results.append( + { + "policy_reward_mean.player_row": payoff_per_player[0], + "policy_reward_mean.player_col": payoff_per_player[1], + } + ) + + fake_experiment_analysis = ( + exp_analysis.create_fake_experiment_analysis_wt_metrics_only( + trials_results + ) + ) + return fake_experiment_analysis, hp + + +def _load_pi_pl_0_and_available_pi_pl1(hp, meta_games_idx, all_meta_games_idx): + meta_pi_pl0_idx = meta_games_idx + pl0_principal_i = hp["meta_game_matrix_pos_to_principal_idx"][ + meta_pi_pl0_idx + ] + meta_pi_pl0 = hp["meta_game_policy_distributions"][meta_pi_pl0_idx][ + POLICY_ID_PL0 + ] + print( + "meta_games_idx", + meta_games_idx, + "pl0_principal_i", + pl0_principal_i, + ) + assert pl0_principal_i == ( + meta_games_idx // hp["load_n_bootstrapped_by_principals"] + ) + meta_games_idx_available_pl1 = copy.deepcopy(all_meta_games_idx) + for meta_policy_i, principal_j in enumerate( + hp["meta_game_matrix_pos_to_principal_idx"] + ): + if principal_j == pl0_principal_i: + meta_games_idx_available_pl1.remove(meta_policy_i) + print("meta_games_idx_available_pl1", meta_games_idx_available_pl1) + assert ( + len(meta_games_idx_available_pl1) + == len(all_meta_games_idx) - hp["load_n_bootstrapped_by_principals"] + ) + return meta_pi_pl0, meta_games_idx_available_pl1, pl0_principal_i + + +def _load_pi_pl1_and_payoff_matrix( + hp, meta_games_idx_available_pl1, pl0_principal_i, payoff_matrix_35x35 +): + meta_pi_pl1_idx = random.choice(meta_games_idx_available_pl1) + meta_pi_pl1 = hp["meta_game_policy_distributions"][meta_pi_pl1_idx][ + POLICY_ID_PL1 + ] + pl1_principal_i = hp["meta_game_matrix_pos_to_principal_idx"][ + meta_pi_pl1_idx + ] + + n_meta_actions = 7 + payoff_matrix = copy.deepcopy( + 
payoff_matrix_35x35[ + pl0_principal_i + * n_meta_actions : (pl0_principal_i + 1) + * n_meta_actions, + pl1_principal_i + * n_meta_actions : (pl1_principal_i + 1) + * n_meta_actions, + :, + ] + ) + print("cross_play_payoff_matrix", payoff_matrix) + + return payoff_matrix, meta_pi_pl1 + + +def _select_cross_play_meta_policies_for_payoff_reading( + hp, meta_games_idx_available +): + meta_pi_pl0_idx = random.choice(meta_games_idx_available) + meta_pi_pl0 = hp["meta_game_policy_distributions"][meta_pi_pl0_idx][ + POLICY_ID_PL0 + ] + + meta_games_idx_available_pl1 = copy.deepcopy(meta_games_idx_available) + meta_games_idx_available_pl1.remove(meta_pi_pl0_idx) + + meta_pi_pl1_idx = random.choice(meta_games_idx_available_pl1) + meta_pi_pl1 = hp["meta_game_policy_distributions"][meta_pi_pl1_idx][ + POLICY_ID_PL1 + ] + + return meta_pi_pl0, meta_pi_pl1, meta_pi_pl0_idx, meta_pi_pl1_idx + + +def _compute_joint_meta_pi_proba(meta_pi_pl0, meta_pi_pl1): + meta_pi_pl0 = np.expand_dims(meta_pi_pl0, axis=-1) + meta_pi_pl0 = np.expand_dims(meta_pi_pl0, axis=-1) + meta_pi_pl0_ext = np.tile(meta_pi_pl0, (1, len(meta_pi_pl1), 2)) + meta_pi_pl1 = np.expand_dims(meta_pi_pl1, axis=0) + meta_pi_pl1 = np.expand_dims(meta_pi_pl1, axis=-1) + meta_pi_pl1_ext = np.tile(meta_pi_pl1, (len(meta_pi_pl0), 1, 2)) + joint_proba = meta_pi_pl0_ext * meta_pi_pl1_ext + n_players = 2 + assert np.abs(joint_proba.sum() - n_players) < EPSILON, ( + f"joint_proba.sum()" f" {joint_proba.sum()}" + ) + assert np.all(joint_proba >= 0.0 - EPSILON), f"joint_proba {joint_proba}" + assert np.all(joint_proba <= 1.0 + EPSILON), f"joint_proba {joint_proba}" + return joint_proba + + +def _evaluate_by_playing_in_base_game(hp): + all_rllib_configs = [] + for meta_game_idx in range(hp["n_replicates_over_full_exp"]): + ( + rllib_config, + stop_config, + trainer, + hp_eval, + ) = _get_final_base_game_rllib_config(copy.deepcopy(hp), meta_game_idx) + + all_rllib_configs.append(rllib_config) + + master_rllib_config = amtft_meta_game._mix_rllib_config( + all_rllib_configs, hp_eval=hp + ) + tune_analysis = _train_with_tune( + master_rllib_config, + stop_config, + hp, + trainer, + plot_aggregates=False, + ) + return tune_analysis, hp_eval + + +def _get_final_base_game_rllib_config(hp, meta_game_idx): + if hp["base_game_policies"] == BASE_AMTFT: + ( + stop_config, + env_config, + rllib_config, + trainer, + hp_eval, + ) = _get_rllib_config_for_base_amTFT_policy(hp) + elif hp["base_game_policies"] == BASE_LOLA_EXACT: + ( + stop_config, + env_config, + rllib_config, + trainer, + hp_eval, + ) = _get_rllib_config_for_base_lola_exact_policy(hp) + elif hp["base_game_policies"] == BASE_NEGOTIATION: + ( + stop_config, + env_config, + rllib_config, + trainer, + hp_eval, + ) = _get_rllib_config_for_base_negociation_policy(hp) + else: + raise ValueError() + + ( + rllib_config, + stop_config, + ) = _change_simple_rllib_config_for_final_base_game_eval( + hp, rllib_config, stop_config + ) + + if hp["apply_announcement_protocol"]: + rllib_config = _change_rllib_config_to_use_welfare_coordination( + hp, rllib_config, meta_game_idx, env_config + ) + else: + rllib_config = _change_rllib_config_to_use_stochastic_populations( + hp, rllib_config, meta_game_idx + ) + + return rllib_config, stop_config, trainer, hp_eval + + +def _get_rllib_config_for_base_amTFT_policy(hp): + hp_eval = amtft_various_env.get_hyperparameters( + hp["debug"], + train_n_replicates=1, + env="IteratedAsymBoS", + use_r2d2=hp["use_r2d2"], + ) + hp_eval = 
amtft_various_env.modify_hyperparams_for_the_selected_env( + hp_eval + ) + + ( + rllib_config, + env_config, + stop_config, + hp_eval, + ) = amtft_various_env._generate_eval_config(hp_eval) + + if hp["use_r2d2"]: + trainer = dqn.r2d2.R2D2Trainer + else: + trainer = dqn.dqn.DQNTrainer + + return stop_config, env_config, rllib_config, trainer, hp_eval + + +def _get_rllib_config_for_base_lola_exact_policy(hp): + lola_exact_hp = lola_exact_official.get_hyperparameters( + debug=hp["debug"], env="IteratedAsymBoS", train_n_replicates=1 + ) + ( + hp_eval, + rllib_config, + policies_to_load, + trainable_class, + stop_config, + env_config, + ) = lola_exact_official.generate_eval_config(lola_exact_hp) + + trainer = PGTrainer + + return stop_config, env_config, rllib_config, trainer, lola_exact_hp + + +def _get_rllib_config_for_base_negociation_policy(hp): + raise NotImplementedError() + from marltoolbox.algos.alternating_offers import alt_offers_training + + lola_exact_hp = alt_offers_training.get_hyperparameters() + ( + hp_eval, + rllib_config, + policies_to_load, + trainable_class, + stop_config, + env_config, + ) = alt_offers_training.generate_eval_config(lola_exact_hp) + + trainer = alt_offers_training.AltOffersTraining + + return stop_config, env_config, rllib_config, trainer, lola_exact_hp + + +def _change_simple_rllib_config_for_final_base_game_eval( + hp, rllib_config, stop_config +): + rllib_config["min_iter_time_s"] = 0.0 + rllib_config["timesteps_per_iteration"] = 0 + rllib_config["metrics_smoothing_episodes"] = 1 + if "max_steps" in rllib_config["env_config"].keys(): + rllib_config["rollout_fragment_length"] = rllib_config["env_config"][ + "max_steps" + ] + + rllib_config["multiagent"]["policies_to_train"] = ["None"] + rllib_config["callbacks"] = callbacks.merge_callbacks( + callbacks.PolicyCallbacks, + log.get_logging_callbacks_class( + log_full_epi=True, + # log_full_epi_interval=1, + log_from_policy_in_evaluation=True, + ), + ) + rllib_config["seed"] = tune.sample_from( + lambda spec: miscellaneous.get_random_seeds(1)[0] + ) + if not hp["debug"]: + stop_config["episodes_total"] = 100 + return rllib_config, stop_config + + +def _change_rllib_config_to_use_welfare_coordination( + hp, rllib_config, meta_game_idx, env_config +): + global payoffs_per_groups + all_welfare_pairs_wt_payoffs = payoffs_per_groups[meta_game_idx][0] + + rllib_config["multiagent"]["policies_to_train"] = ["None"] + policies = rllib_config["multiagent"]["policies"] + for policy_idx, policy_id in enumerate(env_config["players_ids"]): + policy_config_items = list(policies[policy_id]) + opp_policy_idx = (policy_idx + 1) % 2 + + egalitarian_welfare_name = ( + "inequity aversion" + if hp["base_game_policies"] == BASE_AMTFT + else "egalitarian" + ) + meta_policy_config = copy.deepcopy(welfare_coordination.DEFAULT_CONFIG) + meta_policy_config.update( + { + "nested_policies": [ + { + "Policy_class": copy.deepcopy(policy_config_items[0]), + "config_update": copy.deepcopy(policy_config_items[3]), + }, + ], + "all_welfare_pairs_wt_payoffs": all_welfare_pairs_wt_payoffs, + "solve_meta_game_after_init": False, + "own_player_idx": policy_idx, + "opp_player_idx": opp_policy_idx, + "own_default_welfare_fn": egalitarian_welfare_name + if policy_idx == 1 + else "utilitarian", + "opp_default_welfare_fn": egalitarian_welfare_name + if opp_policy_idx == 1 + else "utilitarian", + "policy_id_to_load": policy_id, + "policy_checkpoints": hp["base_ckpt_per_replicat"][ + meta_game_idx + ], + "distrib_over_welfare_sets_to_annonce": hp[ + 
"meta_game_policy_distributions" + ][meta_game_idx][policy_id], + } + ) + policy_config_items[ + 0 + ] = welfare_coordination.WelfareCoordinationTorchPolicy + policy_config_items[3] = meta_policy_config + policies[policy_id] = tuple(policy_config_items) + + return rllib_config + + +def _change_rllib_config_to_use_stochastic_populations( + hp, rllib_config, meta_game_idx +): + tmp_env = rllib_config["env"](rllib_config["env_config"]) + policies = rllib_config["multiagent"]["policies"] + for policy_id, policy_config in policies.items(): + policy_config = list(policy_config) + + stochastic_population_policy_config = ( + _create_one_stochastic_population_config( + hp, meta_game_idx, policy_id, policy_config + ) + ) + + policy_config[0] = StochasticPopulation + policy_config[1] = tmp_env.OBSERVATION_SPACE + policy_config[2] = tmp_env.ACTION_SPACE + policy_config[3] = stochastic_population_policy_config + + rllib_config["multiagent"]["policies"][policy_id] = tuple( + policy_config + ) + return rllib_config + + +def _create_one_stochastic_population_config( + hp, meta_game_idx, policy_id, policy_config +): + """ + This policy config is composed of 3 levels: + The top level: one stochastic population policies per player. This + policies stochasticly select (given some proba distribution) + which nested policy to use. + The intermediary(nested) level: one population (of identical policies) per + welfare function. This policy selects randomly which policy from its + population to use. + The bottom(base) level: amTFT or LOLA-Exact policies used by the + intermediary level. (amTFT contains another nested level) + """ + stochastic_population_policy_config = { + "nested_policies": [], + "sampling_policy_distribution": hp["meta_game_policy_distributions"][ + meta_game_idx + ][policy_id], + } + + print('hp["base_ckpt_per_replicat"]', hp["base_ckpt_per_replicat"]) + print('hp["actions_possible"]', hp["actions_possible"]) + for welfare_i in hp["actions_possible"]: + one_nested_population_config = _create_one_vanilla_population_config( + hp, + policy_id, + copy.deepcopy(policy_config), + meta_game_idx, + welfare_i, + ) + + stochastic_population_policy_config["nested_policies"].append( + one_nested_population_config + ) + + return stochastic_population_policy_config + + +def _create_one_vanilla_population_config( + hp, + policy_id, + policy_config, + meta_game_idx, + welfare_i, +): + base_policy_class = copy.deepcopy(policy_config[0]) + base_policy_config = copy.deepcopy(policy_config[3]) + + nested_population_config = copy.deepcopy(population.DEFAULT_CONFIG) + nested_population_config.update( + { + "policy_checkpoints": hp["base_ckpt_per_replicat"][meta_game_idx][ + welfare_i + ], + "nested_policies": [ + { + "Policy_class": base_policy_class, + "config_update": base_policy_config, + } + ], + "policy_id_to_load": policy_id, + } + ) + + intermediary_config = { + "Policy_class": population.PopulationOfIdenticalAlgo, + "config_update": nested_population_config, + } + + return intermediary_config + + +def _extract_metrics(tune_analysis, hp_eval): + player_1_payoffs = exp_analysis.extract_metrics_for_each_trials( + tune_analysis, metric=hp_eval["x_axis_metric"] + ) + player_2_payoffs = exp_analysis.extract_metrics_for_each_trials( + tune_analysis, metric=hp_eval["y_axis_metric"] + ) + print("player_1_payoffs", player_1_payoffs) + print("player_2_payoffs", player_2_payoffs) + return player_1_payoffs, player_2_payoffs + + +def _extract_coordination_metric(tune_analysis): + coordination_success = 
exp_analysis.extract_metrics_for_each_trials( + tune_analysis, metric="custom_metrics/coordination_success_mean" + ) + coordination_success = [float(el) for el in coordination_success] + + # coordination_success = path. + # tune_analysis, metric="custom_metrics/coordination_success_mean" + # ) + return coordination_success + + +def _format_result_for_plotting(results): + data_groups_per_mode = {} + df_rows = [] + for player1_avg_r_one_replicate, player2_avg_r_one_replicate in results: + df_row_dict = { + "": ( + player1_avg_r_one_replicate, + player2_avg_r_one_replicate, + ) + } + df_rows.append(df_row_dict) + data_groups_per_mode["cross-play"] = pd.DataFrame(df_rows) + return data_groups_per_mode + + +def _train_meta_policy_using_alpha_rank(hp, pure_strategy=False): + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + + policies = [] + for payoff_matrix in payoff_matrices: + + payoff_tables_per_player = [ + payoff_matrix[:, :, 0], + payoff_matrix[:, :, 1], + ] + policy_player_1, policy_player_2 = _compute_policy_wt_alpha_rank( + payoff_tables_per_player + ) + + if pure_strategy: + policy_player_1 = policy_player_1 == policy_player_1.max() + policy_player_2 = policy_player_2 == policy_player_2.max() + policy_player_1 = policy_player_1.float() + policy_player_2 = policy_player_2.float() + + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("alpha rank meta policies", policies) + return policies + + +def _compute_policy_wt_alpha_rank(payoff_tables_per_player): + from open_spiel.python.egt import alpharank + from open_spiel.python.algorithms.psro_v2 import utils as psro_v2_utils + + joint_arank, alpha = alpharank.sweep_pi_vs_alpha( + payoff_tables_per_player, return_alpha=True + ) + print("alpha selected", alpha) + ( + policy_player_1, + policy_player_2, + ) = psro_v2_utils.get_alpharank_marginals( + payoff_tables_per_player, joint_arank + ) + print("meta policy_player_1", policy_player_1) + print("meta policy_player_2", policy_player_2) + policy_player_1 = torch.tensor(policy_player_1) + policy_player_2 = torch.tensor(policy_player_2) + return policy_player_1, policy_player_2 + + +def _train_meta_policy_using_replicator_dynamic(hp): + from open_spiel.python.algorithms.projected_replicator_dynamics import ( + projected_replicator_dynamics, + ) + + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + policies = [] + for payoff_matrix in payoff_matrices: + payoff_tables_per_player = [ + payoff_matrix[:, :, 0], + payoff_matrix[:, :, 1], + ] + num_actions = payoff_matrix.shape[0] + prd_initial_strategies = [ + np.random.dirichlet(np.ones(num_actions) * 1.5), + np.random.dirichlet(np.ones(num_actions) * 1.5), + ] + if hp["meta_game_policies"] == META_REPLICATOR_DYNAMIC_ZERO_INIT: + policy_player_1, policy_player_2 = projected_replicator_dynamics( + payoff_tables_per_player, + prd_gamma=0.0, + ) + else: + print("prd_initial_strategies", prd_initial_strategies) + policy_player_1, policy_player_2 = projected_replicator_dynamics( + payoff_tables_per_player, + prd_gamma=0.0, + prd_initial_strategies=prd_initial_strategies, + ) + + policy_player_1 = torch.tensor(policy_player_1) + policy_player_2 = torch.tensor(policy_player_2) + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("replicator dynamic meta policies", policies) + return policies + + +def _get_random_meta_policy(hp): + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + policies = [] + for payoff_matrix in payoff_matrices: + 
num_actions_player_0 = payoff_matrix.shape[0] + num_actions_player_1 = payoff_matrix.shape[1] + + policy_player_1 = ( + torch.ones(size=(num_actions_player_0,)) / num_actions_player_0 + ) + policy_player_2 = ( + torch.ones(size=(num_actions_player_1,)) / num_actions_player_1 + ) + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("random meta policies", policies) + return policies + + +def _train_meta_policy_using_robustness(hp): + + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + policies = [] + for payoff_matrix in payoff_matrices: + robustness_score_pl0 = payoff_matrix[:, :, 0].mean(axis=1) + robustness_score_pl1 = payoff_matrix[:, :, 1].mean(axis=0) + + pl0_action = np.argmax(robustness_score_pl0, axis=0) + pl1_action = np.argmax(robustness_score_pl1, axis=0) + + policy_player_1 = torch.zeros((payoff_matrix.shape[0],)) + policy_player_2 = torch.zeros((payoff_matrix.shape[1],)) + policy_player_1[pl0_action] = 1.0 + policy_player_2[pl1_action] = 1.0 + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("Robustness meta policies", policies) + return policies + + +def _train_meta_policy_using_minimum(hp): + + payoff_matrices = copy.deepcopy(hp["payoff_matrices"]) + policies = [] + for payoff_matrix in payoff_matrices: + robustness_score_pl0 = payoff_matrix[:, :, 0].min(axis=1) + robustness_score_pl1 = payoff_matrix[:, :, 1].min(axis=0) + + pl0_action = np.argmax(robustness_score_pl0, axis=0) + pl1_action = np.argmax(robustness_score_pl1, axis=0) + + policy_player_1 = torch.zeros((payoff_matrix.shape[0],)) + policy_player_2 = torch.zeros((payoff_matrix.shape[1],)) + policy_player_1[pl0_action] = 1.0 + policy_player_2[pl1_action] = 1.0 + policies.append( + {"player_row": policy_player_1, "player_col": policy_player_2} + ) + print("Minimum meta policies", policies) + return policies + + +if __name__ == "__main__": + debug_mode = True + loop_over_main = True + + if loop_over_main: + base_game_algo_to_eval = ( + BASE_LOLA_EXACT, + # BASE_NEGOTIATION, + ) + meta_game_algo_to_eval = ( + # META_APLHA_RANK, + # META_APLHA_PURE, + # META_REPLICATOR_DYNAMIC, + # META_REPLICATOR_DYNAMIC_ZERO_INIT, + # META_RANDOM, + # META_PG, + # META_LOLA_EXACT, + # META_SOS, + # META_UNIFORM, + META_MINIMUM, + ) + for base_game_algo_ in base_game_algo_to_eval: + for meta_game_algo_ in meta_game_algo_to_eval: + main(debug_mode, base_game_algo_, meta_game_algo_) + else: + main(debug_mode) diff --git a/marltoolbox/experiments/tune_function_api/lola_dice_official.py b/marltoolbox/experiments/tune_function_api/lola_dice_official.py index b8f4d2c..2597848 100644 --- a/marltoolbox/experiments/tune_function_api/lola_dice_official.py +++ b/marltoolbox/experiments/tune_function_api/lola_dice_official.py @@ -224,12 +224,12 @@ def main(debug): tune_config = get_tune_config(tune_hparams) ray.init(num_cpus=os.cpu_count(), num_gpus=0) - tune_analysis = tune.run( + experiment_analysis = tune.run( lola_training, name=tune_hparams["exp_name"], config=tune_config ) ray.shutdown() - return tune_analysis + return experiment_analysis if __name__ == "__main__": diff --git a/marltoolbox/experiments/tune_function_api/lola_pg_official.py b/marltoolbox/experiments/tune_function_api/lola_pg_official.py index 56df046..f6299b2 100644 --- a/marltoolbox/experiments/tune_function_api/lola_pg_official.py +++ b/marltoolbox/experiments/tune_function_api/lola_pg_official.py @@ -10,8 +10,10 @@ import marltoolbox.algos.lola_dice.envs as lola_dice_envs from 
marltoolbox.algos.lola import train_cg, train_pg -from marltoolbox.envs.vectorized_coin_game import \ - VectorizedCoinGame, AsymVectorizedCoinGame +from marltoolbox.envs.vectorized_coin_game import ( + VectorizedCoinGame, + AsymVectorizedCoinGame, +) from marltoolbox.utils import log @@ -20,19 +22,16 @@ def main(debug): tune_hparams = { "exp_name": exp_name, - # Dynamically set "num_episodes": 3 if debug else None, "trace_length": 6 if debug else None, "lr": None, "gamma": None, "batch_size": 12 if debug else None, - # "exp_name": "IPD", # "exp_name": "IMP", "exp_name": "CoinGame", # "exp_name": "AsymCoinGame", - "pseudo": False, "grid_size": 3, "lola_update": True, @@ -44,47 +43,64 @@ def main(debug): "hidden": 32, "reg": 0, "set_zero": 0, - "exact": False, - "warmup": 1, - "seed": 1, - "changed_config": False, "ac_lr": 1.0, "summary_len": 1, "use_MAE": False, - "use_toolbox_env": True, - "clip_loss_norm": False, "clip_lola_update_norm": False, "clip_lola_correction_norm": 3.0, "clip_lola_actor_norm": 10.0, - "entropy_coeff": 0.001, - "weigth_decay": 0.03, } tune_config = get_tune_config(tune_hparams) ray.init(num_cpus=os.cpu_count(), num_gpus=0) - tune_analysis = tune.run(lola_training, - name=tune_hparams["exp_name"], - config=tune_config) + experiment_analysis = tune.run( + lola_training, name=tune_hparams["exp_name"], config=tune_config + ) ray.shutdown() - return tune_analysis - - -def trainer_fn(exp_name, num_episodes, trace_length, exact, pseudo, grid_size, - lr, lr_correction, batch_size, bs_mul, simple_net, hidden, reg, - gamma, lola_update, opp_model, mem_efficient, seed, set_zero, - warmup, changed_config, ac_lr, summary_len, use_MAE, - use_toolbox_env, clip_lola_update_norm, clip_loss_norm, - entropy_coeff, - weigth_decay, **kwargs): + return experiment_analysis + + +def trainer_fn( + exp_name, + num_episodes, + trace_length, + exact, + pseudo, + grid_size, + lr, + lr_correction, + batch_size, + bs_mul, + simple_net, + hidden, + reg, + gamma, + lola_update, + opp_model, + mem_efficient, + seed, + set_zero, + warmup, + changed_config, + ac_lr, + summary_len, + use_MAE, + use_toolbox_env, + clip_lola_update_norm, + clip_loss_norm, + entropy_coeff, + weigth_decay, + **kwargs, +): # Instantiate the environment if exp_name == "IPD": raise NotImplementedError() @@ -92,25 +108,29 @@ def trainer_fn(exp_name, num_episodes, trace_length, exact, pseudo, grid_size, raise NotImplementedError() elif exp_name == "CoinGame": if use_toolbox_env: - env = VectorizedCoinGame(config={ - "batch_size": batch_size, - "max_steps": trace_length, - "grid_size": grid_size, - "get_additional_info": True, - "add_position_in_epi": False, - }) + env = VectorizedCoinGame( + config={ + "batch_size": batch_size, + "max_steps": trace_length, + "grid_size": grid_size, + "get_additional_info": True, + "add_position_in_epi": False, + } + ) else: env = lola_dice_envs.CG(trace_length, batch_size, grid_size) env.seed(seed) elif exp_name == "AsymCoinGame": if use_toolbox_env: - env = AsymVectorizedCoinGame(config={ - "batch_size": batch_size, - "max_steps": trace_length, - "grid_size": grid_size, - "get_additional_info": True, - "add_position_in_epi": False, - }) + env = AsymVectorizedCoinGame( + config={ + "batch_size": batch_size, + "max_steps": trace_length, + "grid_size": grid_size, + "get_additional_info": True, + "add_position_in_epi": False, + } + ) else: env = lola_dice_envs.AsymCG(trace_length, batch_size, grid_size) env.seed(seed) @@ -121,42 +141,45 @@ def trainer_fn(exp_name, num_episodes, trace_length, 
exact, pseudo, grid_size, if exact: raise NotImplementedError() elif exp_name in ("IPD", "IMP"): - train_pg.train(env, - num_episodes=num_episodes, - trace_length=trace_length, - batch_size=batch_size, - gamma=gamma, - set_zero=set_zero, - lr=lr, - corrections=lola_update, - simple_net=simple_net, - hidden=hidden, - mem_efficient=mem_efficient) + train_pg.train( + env, + num_episodes=num_episodes, + trace_length=trace_length, + batch_size=batch_size, + gamma=gamma, + set_zero=set_zero, + lr=lr, + corrections=lola_update, + simple_net=simple_net, + hidden=hidden, + mem_efficient=mem_efficient, + ) elif exp_name in ("CoinGame", "AsymCoinGame"): - train_cg.train(env, - num_episodes=num_episodes, - trace_length=trace_length, - batch_size=batch_size, - bs_mul=bs_mul, - gamma=gamma, - grid_size=grid_size, - lr=lr, - corrections=lola_update, - opp_model=opp_model, - hidden=hidden, - mem_efficient=mem_efficient, - asymmetry=exp_name == "AsymCoinGame", - warmup=warmup, - changed_config=changed_config, - ac_lr=ac_lr, - summary_len=summary_len, - use_MAE=use_MAE, - use_toolbox_env=use_toolbox_env, - clip_lola_update_norm=clip_lola_update_norm, - clip_loss_norm=clip_loss_norm, - entropy_coeff=entropy_coeff, - weigth_decay=weigth_decay, - ) + train_cg.train( + env, + num_episodes=num_episodes, + trace_length=trace_length, + batch_size=batch_size, + bs_mul=bs_mul, + gamma=gamma, + grid_size=grid_size, + lr=lr, + corrections=lola_update, + opp_model=opp_model, + hidden=hidden, + mem_efficient=mem_efficient, + asymmetry=exp_name == "AsymCoinGame", + warmup=warmup, + changed_config=changed_config, + ac_lr=ac_lr, + summary_len=summary_len, + use_MAE=use_MAE, + use_toolbox_env=use_toolbox_env, + clip_lola_update_norm=clip_lola_update_norm, + clip_loss_norm=clip_loss_norm, + entropy_coeff=entropy_coeff, + weigth_decay=weigth_decay, + ) else: raise ValueError(f"exp_name: {exp_name}") @@ -167,42 +190,49 @@ def lola_training(config): def get_tune_config(hp: dict) -> dict: # Sanity - assert hp['exp_name'] in {"CoinGame", "IPD", "IMP", "AsymCoinGame"} - if hp['exact']: - assert hp['exp_name'] != "CoinGame", \ - "Can't run CoinGame with --exact." - assert hp['exp_name'] != "AsymCoinGame", \ - "Can't run AsymCoinGame with --exact." + assert hp["exp_name"] in {"CoinGame", "IPD", "IMP", "AsymCoinGame"} + if hp["exact"]: + assert hp["exp_name"] != "CoinGame", "Can't run CoinGame with --exact." + assert ( + hp["exp_name"] != "AsymCoinGame" + ), "Can't run AsymCoinGame with --exact." # Resolve default parameters - if hp['exact']: - hp['num_episodes'] = \ - 50 if hp['num_episodes'] is None else hp['num_episodes'] - hp['trace_length'] = \ - 200 if hp['trace_length'] is None else hp['trace_length'] - hp['lr'] = \ - 1. if hp['lr'] is None else hp['lr'] - elif hp['exp_name'] in {"IPD", "IMP"}: - hp['num_episodes'] = \ - 600000 if hp['num_episodes'] is None else hp['num_episodes'] - hp['trace_length'] = \ - 150 if hp['trace_length'] is None else hp['trace_length'] - hp['batch_size'] = \ - 4000 if hp['batch_size'] is None else hp['batch_size'] - hp['lr'] = 1. 
if hp['lr'] is None else hp['lr'] - elif hp['exp_name'] == "CoinGame" or hp['exp_name'] == "AsymCoinGame": - hp['num_episodes'] = \ - 100000 if hp['num_episodes'] is None else hp['num_episodes'] - hp['trace_length'] = \ - 150 if hp['trace_length'] is None else hp['trace_length'] - hp['batch_size'] = \ - 4000 if hp['batch_size'] is None else hp['batch_size'] - hp['lr'] = 0.005 if hp['lr'] is None else hp['lr'] - - if hp['exp_name'] in ("IPD", "CoinGame", "AsymCoinGame"): - hp['gamma'] = 0.96 if hp['gamma'] is None else hp['gamma'] - elif hp['exp_name'] == "IMP": - hp['gamma'] = 0.9 if hp['gamma'] is None else hp['gamma'] + if hp["exact"]: + hp["num_episodes"] = ( + 50 if hp["num_episodes"] is None else hp["num_episodes"] + ) + hp["trace_length"] = ( + 200 if hp["trace_length"] is None else hp["trace_length"] + ) + hp["lr"] = 1.0 if hp["lr"] is None else hp["lr"] + elif hp["exp_name"] in {"IPD", "IMP"}: + hp["num_episodes"] = ( + 600000 if hp["num_episodes"] is None else hp["num_episodes"] + ) + hp["trace_length"] = ( + 150 if hp["trace_length"] is None else hp["trace_length"] + ) + hp["batch_size"] = ( + 4000 if hp["batch_size"] is None else hp["batch_size"] + ) + hp["lr"] = 1.0 if hp["lr"] is None else hp["lr"] + elif hp["exp_name"] == "CoinGame" or hp["exp_name"] == "AsymCoinGame": + hp["num_episodes"] = ( + 100000 if hp["num_episodes"] is None else hp["num_episodes"] + ) + hp["trace_length"] = ( + 150 if hp["trace_length"] is None else hp["trace_length"] + ) + hp["batch_size"] = ( + 4000 if hp["batch_size"] is None else hp["batch_size"] + ) + hp["lr"] = 0.005 if hp["lr"] is None else hp["lr"] + + if hp["exp_name"] in ("IPD", "CoinGame", "AsymCoinGame"): + hp["gamma"] = 0.96 if hp["gamma"] is None else hp["gamma"] + elif hp["exp_name"] == "IMP": + hp["gamma"] = 0.9 if hp["gamma"] is None else hp["gamma"] return hp diff --git a/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py b/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py index 440db1c..4acb0ba 100644 --- a/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py +++ b/marltoolbox/scripts/aggregate_and_plot_tensorboard_data.py @@ -14,45 +14,58 @@ import matplotlib.colors as mcolors import numpy as np import pandas as pd -from tensorboard.backend.event_processing.event_accumulator import \ - EventAccumulator - -from marltoolbox.utils.miscellaneous import \ - list_all_files_in_one_dir_tree, ignore_str_containing_keys, \ - separate_str_in_group_containing_keys, GROUP_KEY_NONE, \ - keep_strs_containing_keys, fing_longer_substr -from marltoolbox.utils.plot import \ - LOWER_ENVELOPE_SUFFIX, UPPER_ENVELOPE_SUFFIX, PlotHelper, PlotConfig - -FOLDER_NAME = 'aggregates' -AGGREGATION_OPS = {"mean": np.mean, - "min": np.min, - "max": np.max, - "median": np.median, - "std": np.std, - "var": np.var} +from tensorboard.backend.event_processing.event_accumulator import ( + EventAccumulator, +) + +from marltoolbox.utils.miscellaneous import ( + ignore_str_containing_keys, + separate_str_in_group_containing_keys, + GROUP_KEY_NONE, + keep_strs_containing_keys, + fing_longer_substr, +) +from marltoolbox.utils.path import list_all_files_in_one_dir_tree +from marltoolbox.utils.plot import ( + LOWER_ENVELOPE_SUFFIX, + UPPER_ENVELOPE_SUFFIX, + PlotHelper, + PlotConfig, +) + +FOLDER_NAME = "aggregates" +AGGREGATION_OPS = { + "mean": np.mean, + "min": np.min, + "max": np.max, + "median": np.median, + "std": np.std, + "var": np.var, +} COLORS = list(mcolors.TABLEAU_COLORS) -PLOT_KEYS = ["grad_gnorm", - "reward", - "loss", - "entropy", - 
"entropy_avg", - "td_error", - "error", - "act_dist_inputs_avg", - "act_dist_inputs_single", - "q_values_avg", - "action_prob", - "q_values_single", - "_lr", - "max_q_values", - "min_q_values", - "learn_on_batch", - "timers", - "ms", - "throughput", - ] +PLOT_KEYS = [ + "grad_gnorm", + "reward", + "loss", + "entropy", + "entropy_avg", + "td_error", + "error", + "act_dist_inputs_avg", + "act_dist_inputs_single", + "q_values_avg", + "action_prob", + "q_values_single", + "_lr", + "max_q_values", + "min_q_values", + "learn_on_batch", + "timers", + "ms", + "throughput", + "temperature", +] PLOT_ASSEMBLAGE_TAGS = [ ("policy_reward_mean",), @@ -82,22 +95,20 @@ ("ms",), ("throughput",), ("_lr",), + ("temperature",), ] -class TensorBoardDataExtractor(): - +class TensorBoardDataExtractor: def __init__(self, main_path): self.main_path = main_path now = datetime.datetime.now() self.date_hour_str = now.strftime("%Y_%m_%d_%H_%M_%S") - save_dir = os.path.join(self.main_path, - FOLDER_NAME) + save_dir = os.path.join(self.main_path, FOLDER_NAME) if not os.path.exists(save_dir): os.mkdir(save_dir) - save_dir = os.path.join(save_dir, - self.date_hour_str) + save_dir = os.path.join(save_dir, self.date_hour_str) if not os.path.exists(save_dir): os.mkdir(save_dir) @@ -106,10 +117,10 @@ def __init__(self, main_path): def extract_data(self, ignore_keys, group_keys, output): print("\n===== Extract data =====") file_list = list_all_files_in_one_dir_tree(self.main_path) - file_list_filtered = ignore_str_containing_keys( - file_list, ignore_keys) + file_list_filtered = ignore_str_containing_keys(file_list, ignore_keys) file_list_dict = separate_str_in_group_containing_keys( - file_list_filtered, group_keys) + file_list_filtered, group_keys + ) self._aggregate(self.main_path, output, file_list_dict) return self.save_dir @@ -120,8 +131,7 @@ def _aggregate(self, main_path, output, file_list_dict): extracts_per_group = { group_key: self._extract_x_y_per_keys(main_path, file_list) - for group_key, file_list in - file_list_dict.items() + for group_key, file_list in file_list_dict.items() } if output == "summary": @@ -130,25 +140,35 @@ def _aggregate(self, main_path, output, file_list_dict): # https://github.com/Spenhouet/tensorboard-aggregator elif output == "csv": self._aggregate_to_csv( - main_path, AGGREGATION_OPS, extracts_per_group) + main_path, AGGREGATION_OPS, extracts_per_group + ) print(f"End of aggregation {main_path}") def _extract_x_y_per_keys(self, main_path, file_list): - print("Going to extract", main_path, - "with len(file_list)", len(file_list)) - - event_readers = \ - self._create_event_reader_for_each_log_files(file_list) + print( + "Going to extract", + main_path, + "with len(file_list)", + len(file_list), + ) + + event_readers = self._create_event_reader_for_each_log_files(file_list) if len(event_readers) == 0: return None - all_scalar_events_per_key, keys = \ - self._get_and_validate_all_scalar_keys(event_readers) - steps_per_key, all_scalar_events_per_key = \ - self._get_and_validate_all_steps_per_key( - all_scalar_events_per_key, keys) - values_per_key = \ - self._get_values_per_step_per_key(all_scalar_events_per_key) + ( + all_scalar_events_per_key, + keys, + ) = self._get_and_validate_all_scalar_keys(event_readers) + ( + steps_per_key, + all_scalar_events_per_key, + ) = self._get_and_validate_all_steps_per_key( + all_scalar_events_per_key, keys + ) + values_per_key = self._get_values_per_step_per_key( + all_scalar_events_per_key + ) keys = [key.replace("/", "_") for key in keys] all_per_key = 
dict(zip(keys, zip(steps_per_key, values_per_key))) @@ -156,27 +176,32 @@ def _extract_x_y_per_keys(self, main_path, file_list): return all_per_key def _create_event_reader_for_each_log_files(self, file_list): - event_readers = [EventAccumulator(file_path).Reload( - ).scalars for file_path in file_list] + event_readers = [ + EventAccumulator(file_path).Reload().scalars + for file_path in file_list + ] # Filter non event files - event_readers = [one_event_reader for one_event_reader in event_readers - if - one_event_reader.Keys()] + event_readers = [ + one_event_reader + for one_event_reader in event_readers + if one_event_reader.Keys() + ] print(f"found {len(event_readers)} event_readers") return event_readers def _get_and_validate_all_scalar_keys(self, event_readers): - all_keys = [tuple(one_event_reader.Keys()) - for one_event_reader in event_readers] + all_keys = [ + tuple(one_event_reader.Keys()) + for one_event_reader in event_readers + ] self._print_discrepencies_in_keys(all_keys) keys = self._get_common_keys(all_keys) all_scalar_events_per_key = [ - [one_event_reader.Items(key) - for one_event_reader in event_readers] + [one_event_reader.Items(key) for one_event_reader in event_readers] for key in keys ] return all_scalar_events_per_key, keys @@ -189,8 +214,10 @@ def _print_discrepencies_in_keys(self, all_keys): for k in keys_1: if k not in keys_2: if k not in missing_k_detected: - print(f"key {k} is not present in all " - f"event_readers") + print( + f"key {k} is not present in all " + f"event_readers" + ) missing_k_detected.append(k) def _get_common_keys(self, all_keys): @@ -207,21 +234,27 @@ def _get_common_keys(self, all_keys): return common_keys def _get_and_validate_all_steps_per_key( - self, all_scalar_events_per_key, keys): + self, all_scalar_events_per_key, keys + ): all_steps_per_key = [ - [tuple(scalar_event.step - for scalar_event in scalar_events) - for scalar_events in all_scalar_events] - for all_scalar_events in all_scalar_events_per_key] + [ + tuple(scalar_event.step for scalar_event in scalar_events) + for scalar_events in all_scalar_events + ] + for all_scalar_events in all_scalar_events_per_key + ] steps_per_key = [] - for key_idx, (all_steps_for_one_key, key) in enumerate(zip( - all_steps_per_key, keys)): + for key_idx, (all_steps_for_one_key, key) in enumerate( + zip(all_steps_per_key, keys) + ): self._print_discrepencies_in_steps(all_steps_for_one_key, key) common_steps = self._keep_common_steps(all_steps_for_one_key) - all_scalar_events_per_key = \ + all_scalar_events_per_key = ( self._remove_events_if_step_missing_somewhere( - common_steps, all_scalar_events_per_key, key_idx) + common_steps, all_scalar_events_per_key, key_idx + ) + ) steps_per_key.append(common_steps) return steps_per_key, all_scalar_events_per_key @@ -229,12 +262,14 @@ def _get_and_validate_all_steps_per_key( def _print_discrepencies_in_steps(self, all_steps_for_one_key, key): for steps_1 in all_steps_for_one_key: for steps_2 in all_steps_for_one_key: - missing_steps = [step - for step in steps_1 - if step not in steps_2] + missing_steps = [ + step for step in steps_1 if step not in steps_2 + ] if len(missing_steps) > 0: - print(f"discrepency in steps logged for key {key}:" - f"{missing_steps} missing") + print( + f"discrepency in steps logged for key {key}:" + f"{missing_steps} missing" + ) break def _keep_common_steps(self, all_steps_for_one_key): @@ -251,62 +286,76 @@ def _keep_common_steps(self, all_steps_for_one_key): return common_steps def _remove_events_if_step_missing_somewhere( - 
self, common_steps, all_scalar_events_per_key, key_idx): + self, common_steps, all_scalar_events_per_key, key_idx + ): all_scalar_events_per_key[key_idx] = [ - [scalar_event for scalar_event in scalar_events_batch - if scalar_event.step in common_steps] + [ + scalar_event + for scalar_event in scalar_events_batch + if scalar_event.step in common_steps + ] for scalar_events_batch in all_scalar_events_per_key[key_idx] ] return all_scalar_events_per_key def _get_values_per_step_per_key(self, all_scalar_events_per_key): values_per_key = [ - [[scalar_event.value - for scalar_event in scalar_events] - for scalar_events in all_scalar_events] - for all_scalar_events in all_scalar_events_per_key] + [ + [scalar_event.value for scalar_event in scalar_events] + for scalar_events in all_scalar_events + ] + for all_scalar_events in all_scalar_events_per_key + ] return values_per_key - def _aggregate_to_csv(self, main_path, aggregation_ops, - extracts_per_group): + def _aggregate_to_csv( + self, main_path, aggregation_ops, extracts_per_group + ): for group_key, all_per_key in extracts_per_group.items(): if all_per_key is None: continue for key, (steps, values) in all_per_key.items(): - aggregations = {key: aggregation_operation(values, axis=0) - for key, aggregation_operation in - aggregation_ops.items()} + aggregations = { + key: aggregation_operation(values, axis=0) + for key, aggregation_operation in aggregation_ops.items() + } self._write_csv(main_path, group_key, key, aggregations, steps) def _write_csv(self, main_path, group_key, key, aggregations, steps): main_path_split = os.path.split(main_path) group_dir = self._get_valid_filename(group_key) - save_group_dir = os.path.join(self.save_dir, group_dir) \ - if group_key != GROUP_KEY_NONE else self.save_dir + save_group_dir = ( + os.path.join(self.save_dir, group_dir) + if group_key != GROUP_KEY_NONE + else self.save_dir + ) if not os.path.exists(save_group_dir): os.mkdir(save_group_dir) - file_name = self._get_valid_filename(key) + '-' + \ - main_path_split[-1] + '.csv' + file_name = ( + self._get_valid_filename(key) + "-" + main_path_split[-1] + ".csv" + ) df = pd.DataFrame(aggregations, index=steps) save_dir_file_path = os.path.join(save_group_dir, file_name) - df.to_csv(save_dir_file_path, sep=';') + df.to_csv(save_dir_file_path, sep=";") def _get_valid_filename(self, s): - s = str(s).strip().replace(' ', '_') - return re.sub(r'(?u)[^-\w.]', '', s) - - -class SummaryPlotter(): - def plot_selected_keys(self, - save_dir, - plot_keys, - group_keys, - plot_aggregates, - plot_assemble_tags_in_one_plot, - plot_single_lines, - plot_labels_cleaning, - additional_plot_config_kwargs): + s = str(s).strip().replace(" ", "_") + return re.sub(r"(?u)[^-\w.]", "", s) + + +class SummaryPlotter: + def plot_selected_keys( + self, + save_dir, + plot_keys, + group_keys, + plot_aggregates, + plot_assemble_tags_in_one_plot, + plot_single_lines, + plot_labels_cleaning, + additional_plot_config_kwargs, + ): self.plot_labels_cleaning = plot_labels_cleaning self.plot_aggregates = plot_aggregates @@ -316,15 +365,19 @@ def plot_selected_keys(self, save_dir_path = save_dir file_list = list_all_files_in_one_dir_tree(save_dir_path) file_list = keep_strs_containing_keys(file_list, plot_keys) - csv_file_list = [file_path - for file_path in file_list - if "csv" in file_path] + csv_file_list = [ + file_path for file_path in file_list if "csv" in file_path + ] csv_file_groups = separate_str_in_group_containing_keys( - csv_file_list, group_keys) + csv_file_list, group_keys + ) for 
group_key, csv_files_in_one_group in csv_file_groups.items(): - save_dir_path_group = os.path.join(save_dir_path, group_key) \ - if group_key != GROUP_KEY_NONE else save_dir_path + save_dir_path_group = ( + os.path.join(save_dir_path, group_key) + if group_key != GROUP_KEY_NONE + else save_dir_path + ) if not os.path.exists(save_dir_path_group): os.mkdir(save_dir_path_group) @@ -336,8 +389,10 @@ def plot_selected_keys(self, print("===== Plot assemblages =====") self.plot_several_lines_per_plot( - save_dir_path_group, csv_files_in_one_group, - plot_assemble_tags_in_one_plot) + save_dir_path_group, + csv_files_in_one_group, + plot_assemble_tags_in_one_plot, + ) def plot_one_graph(self, save_dir_path, csv_file_list, y_label=None): data_groups = {} @@ -349,33 +404,40 @@ def plot_one_graph(self, save_dir_path, csv_file_list, y_label=None): if "min_max" in self.plot_aggregates: assert "one_std" not in self.plot_aggregates - df = df.rename(columns={'min': f'mean{LOWER_ENVELOPE_SUFFIX}', - 'max': f'mean{UPPER_ENVELOPE_SUFFIX}'}) + df = df.rename( + columns={ + "min": f"mean{LOWER_ENVELOPE_SUFFIX}", + "max": f"mean{UPPER_ENVELOPE_SUFFIX}", + } + ) else: - df = df.drop(columns=['min', 'max']) + df = df.drop(columns=["min", "max"]) if "one_std" in self.plot_aggregates: assert "min_max" not in self.plot_aggregates - df[f'mean{LOWER_ENVELOPE_SUFFIX}'] = df['mean'] - df['std'] - df[f'mean{UPPER_ENVELOPE_SUFFIX}'] = df['mean'] + df['std'] - df = df.drop(columns=['std', 'var', "median"]) + df[f"mean{LOWER_ENVELOPE_SUFFIX}"] = df["mean"] - df["std"] + df[f"mean{UPPER_ENVELOPE_SUFFIX}"] = df["mean"] + df["std"] + df = df.drop(columns=["std", "var", "median"]) data_groups[tag] = df plot_options = PlotConfig( xlabel="steps", ylabel=fing_longer_substr(all_tags_seen).strip("_") - if y_label is None else y_label, + if y_label is None + else y_label, save_dir_path=save_dir_path, - **self.additional_plot_config_kwargs) + **self.additional_plot_config_kwargs, + ) plot_helper = PlotHelper(plot_options) plot_helper.plot_lines(data_groups) def plot_several_lines_per_plot( - self, save_dir_path, csv_file_list, - plot_assemble_tags_in_one_plot): + self, save_dir_path, csv_file_list, plot_assemble_tags_in_one_plot + ): - for assemblage_idx, list_of_tags_in_assemblage in \ - enumerate(plot_assemble_tags_in_one_plot): + for assemblage_idx, list_of_tags_in_assemblage in enumerate( + plot_assemble_tags_in_one_plot + ): assert isinstance(list_of_tags_in_assemblage, Iterable) # select files for one assemblage assemblage_list = self._group_csv_file_in_aggregates( @@ -384,19 +446,25 @@ def plot_several_lines_per_plot( if len(assemblage_list) > 0: # plot one assemblage - y_label = f"{assemblage_idx}_" + \ - " or ".join(list_of_tags_in_assemblage) + y_label = f"{assemblage_idx}_" + " or ".join( + list_of_tags_in_assemblage + ) self.plot_one_graph( - save_dir_path, assemblage_list, - y_label=y_label) + save_dir_path, assemblage_list, y_label=y_label + ) def _group_csv_file_in_aggregates( - self, csv_file_list, list_of_tags_in_assemblage): + self, csv_file_list, list_of_tags_in_assemblage + ): print(f"Start the {list_of_tags_in_assemblage} assemblage") assemblage_list = [] for csv_file in csv_file_list: - if any([select_key in csv_file - for select_key in list_of_tags_in_assemblage]): + if any( + [ + select_key in csv_file + for select_key in list_of_tags_in_assemblage + ] + ): assemblage_list.append(csv_file) # print("csv files selected for assemblage", assemblage_list) assemblage_list = sorted(assemblage_list) @@ -418,31 +486,36 
@@ def extract_tag_from_file_name(self, csv_file): return tag -def add_summary_plots(main_path: str, - ignore_keys: Iterable = ( - "aggregates", "same_cross_play"), - group_keys: Iterable = (), - output: str = "csv", - plot_keys: Iterable = ("policy_reward_mean", - "loss", - "entropy", - "entropy_avg", - "td_error"), - plot_aggregates: Iterable = ("mean", "min_max"), - plot_assemble_tags_in_one_plot=(("policy_reward_mean",), - ("loss", "td_error"), - ("entropy",), - ("entropy_avg",)), - plot_single_lines=False, - plot_labels_cleaning: Iterable = ( - ("learner_stats_", ""), - ("info_learner_", ""), - ("player_", "pl_") - ), - additional_plot_config_kwargs={ - "figsize": (8, 8), - "legend_fontsize": "small"}, - ): +def add_summary_plots( + main_path: str, + ignore_keys: Iterable = ("aggregates", "same_cross_play"), + group_keys: Iterable = (), + output: str = "csv", + plot_keys: Iterable = ( + "policy_reward_mean", + "loss", + "entropy", + "entropy_avg", + "td_error", + ), + plot_aggregates: Iterable = ("mean", "min_max"), + plot_assemble_tags_in_one_plot=( + ("policy_reward_mean",), + ("loss", "td_error"), + ("entropy",), + ("entropy_avg",), + ), + plot_single_lines=False, + plot_labels_cleaning: Iterable = ( + ("learner_stats_", ""), + ("info_learner_", ""), + ("player_", "pl_"), + ), + additional_plot_config_kwargs={ + "figsize": (8, 8), + "legend_fontsize": "small", + }, +): """ Aggregates multiple tensorboard runs into mean, min, max, median, std and save that in tensorboard files or in csv. @@ -489,9 +562,11 @@ def add_summary_plots(main_path: str, :return: """ - if output not in ['summary', 'csv']: - raise ValueError("output must be one of ['summary', 'csv']" - f"current output: {output}") + if output not in ["summary", "csv"]: + raise ValueError( + "output must be one of ['summary', 'csv']" + f"current output: {output}" + ) main_path = os.path.expanduser(main_path) @@ -501,9 +576,15 @@ def add_summary_plots(main_path: str, if output == "csv": plotter = SummaryPlotter() plotter.plot_selected_keys( - save_dir, plot_keys, group_keys, plot_aggregates, - plot_assemble_tags_in_one_plot, plot_single_lines, - plot_labels_cleaning, additional_plot_config_kwargs) + save_dir, + plot_keys, + group_keys, + plot_aggregates, + plot_assemble_tags_in_one_plot, + plot_single_lines, + plot_labels_cleaning, + additional_plot_config_kwargs, + ) def param_list(param): @@ -513,42 +594,61 @@ def param_list(param): return p_list -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Args for aggregation data - parser.add_argument("--main_path", - type=str, - help="main path for tensorboard files", - default=os.getcwd()) - parser.add_argument("--group_keys", - type=param_list, - help="keys used to separate files in groups", - default=[]) - parser.add_argument("--ignore_keys", - type=param_list, - help="keys used to ignore files", - default=["aggregates", "same_cross_play"]) - parser.add_argument("--output", - type=str, - help="aggregation can be saves as " - "tensorboard file (summary) or as table (csv)", - default='csv') + parser.add_argument( + "--main_path", + type=str, + help="main path for tensorboard files", + default=os.getcwd(), + ) + parser.add_argument( + "--group_keys", + type=param_list, + help="keys used to separate files in groups", + default=[], + ) + parser.add_argument( + "--ignore_keys", + type=param_list, + help="keys used to ignore files", + default=["aggregates", "same_cross_play"], + ) + parser.add_argument( + "--output", + type=str, + 
help="aggregation can be saves as " + "tensorboard file (summary) or as table (csv)", + default="csv", + ) # Args for plotting - parser.add_argument("--plot_keys", - type=param_list, - help="keys used to select tensorboard tags to plot", - default=['reward', 'loss', 'entropy']) - parser.add_argument("--plot_aggregates", - type=param_list, - help="which results of aggregation operations to plot", - default=['mean', 'min_max']) - parser.add_argument("--plot_assemble_tags_in_one_plot", - type=param_list, - help="keys used to select tensorboard tags to " - "aggregated plots", - default=[['reward']]) + parser.add_argument( + "--plot_keys", + type=param_list, + help="keys used to select tensorboard tags to plot", + default=["reward", "loss", "entropy"], + ) + parser.add_argument( + "--plot_aggregates", + type=param_list, + help="which results of aggregation operations to plot", + default=["mean", "min_max"], + ) + parser.add_argument( + "--plot_assemble_tags_in_one_plot", + type=param_list, + help="keys used to select tensorboard tags to " "aggregated plots", + default=[["reward"]], + ) args = parser.parse_args() - add_summary_plots(args.main_path, args.ignore_keys, args.group_keys, - args.output, args.plot_keys, args.plot_aggregates, - args.plot_assemble_tags_in_one_plot) + add_summary_plots( + args.main_path, + args.ignore_keys, + args.group_keys, + args.output, + args.plot_keys, + args.plot_aggregates, + args.plot_assemble_tags_in_one_plot, + ) diff --git a/marltoolbox/scripts/alpha-rank.py b/marltoolbox/scripts/alpha-rank.py new file mode 100644 index 0000000..b5bb73d --- /dev/null +++ b/marltoolbox/scripts/alpha-rank.py @@ -0,0 +1,518 @@ +import matplotlib.pyplot as plt + +# import feasible_set_figure as game_utils +# import nashpy as nash +import numpy as np +from open_spiel.python.algorithms import projected_replicator_dynamics +from open_spiel.python.algorithms.psro_v2 import utils +from open_spiel.python.egt import alpharank + +# import copy + +# import prd +# from scipy.special import expit +# import pdb + +SMALL_META_PAYOUT_1 = np.array( + [[4.0, 0.0, 4.0], [0.0, 2.0, 2.0], [4.0, 2.0, 3.0]] +) +SMALL_META_PAYOUT_2 = np.array( + [[1.0, 0.0, 1.0], [0.0, 2.0, 2.0], [1.0, 2.0, 1.5]] +) + +# Util, Egal, Mix, [Util, Egal], [Util, Mix] [Egal, Mix], [All] +# LARGE_META_PAYOUT_1 = np.array( +# [ +# [4.0, 0.0, 2.0, 4.0, 4.0, 2.0, 4.0], +# [0.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0], +# [2.0, 1.0, 3.0, 1.0, 3.0, 3.0, 3.0], +# [4.0, 2.0, 3.0, 1.0, 3.0, 3.0, 3.0], +# ] +# ) +# LARGE_META_PAYOUT_2 = np.array( +# [[1.0, 0.0, 1.0], [0.0, 2.0, 2.0], [1.0, 2.0, 1.5]] +# ) +MIS = 0.0 +# Util, Egal, Mix, [Util, Egal], [Util, Mix] [Egal, Mix], [All] +META_PAYOUT_1 = np.array( + [ + [ + 4.0, + MIS, + MIS, + 4.0, + 4.0, + MIS, + 4.0, + ], + [ + MIS, + 2.0, + MIS, + 2.0, + MIS, + 2.0, + 2.0, + ], + [ + MIS, + MIS, + 3.0, + MIS, + 3.0, + 3.0, + 3.0, + ], + [4.0, 2.0, MIS, 3.0, 4.0, 2.0, 3.0], + [4.0, MIS, 3.0, 4.0, 3.5, 3.0, 3.5], + [MIS, 2.0, 3.0, 2.0, 3.0, 2.5, 2.5], + [4.0, 2.0, 3.0, 3.0, 3.5, 2.5, 3.0], + ] +) +META_PAYOUT_2 = np.array( + [ + [1.0, MIS, MIS, 1.0, 1, MIS, 1.0], + [ + MIS, + 2.0, + MIS, + 2.0, + MIS, + 2.0, + 2.0, + ], + [ + MIS, + MIS, + 1.5, + MIS, + 1.5, + 1.5, + 1.5, + ], + [1.0, 2.0, MIS, 1.5, 1.0, 2.0, 1.5], + [1.0, MIS, 1.5, 1.0, 1.25, 1.5, 1.75], + [MIS, 2.0, 1.5, 2.0, 1.5, 1.75, 1.75], + [1.0, 2.0, 1.5, 1.5, 1.25, 1.75, 1.5], + ] +) + +# "(OrderedSet(['egalitarian', 'mixed', 'utilitarian']), +# OrderedSet(['egalitarian', 'mixed']), +# OrderedSet(['egalitarian', 'utilitarian']), +# 
OrderedSet(['egalitarian']), +# OrderedSet(['mixed', 'utilitarian']), +# OrderedSet(['mixed']), +# OrderedSet(['utilitarian']))" +EMPIRICAL_META_PAYOFFS = np.array( + [ + [ + [2.764299878107763, 1.3754318747969299], + [2.2615328446673537, 1.5902076414957933], + [2.8053149442189285, 1.3916713428125362], + [1.840795943449276, 1.83746234422587], + [3.2260518454370057, 1.1444166400824596], + [2.682269745885431, 1.3429529387657166], + [3.769833944988581, 0.9458803413992027], + ], + [ + [2.2615328446673537, 1.5902076414957933], + [2.2615328446673537, 1.5902076414957933], + [1.840795943449276, 1.83746234422587], + [1.840795943449276, 1.83746234422587], + [2.682269745885431, 1.3429529387657166], + [2.682269745885431, 1.3429529387657166], + [0.7788309720751803, 0.4545001718301003], + ], + [ + [2.8053149442189285, 1.3916713428125362], + [1.840795943449276, 1.83746234422587], + [2.8053149442189285, 1.3916713428125362], + [1.840795943449276, 1.83746234422587], + [3.769833944988581, 0.9458803413992027], + [0.7788309720751803, 0.4545001718301003], + [3.769833944988581, 0.9458803413992027], + ], + [ + [1.840795943449276, 1.83746234422587], + [1.840795943449276, 1.83746234422587], + [1.840795943449276, 1.83746234422587], + [1.840795943449276, 1.83746234422587], + [0.7788309720751803, 0.4545001718301003], + [0.7788309720751803, 0.4545001718301003], + [0.7788309720751803, 0.4545001718301003], + ], + [ + [3.2260518454370057, 1.1444166400824596], + [2.682269745885431, 1.3429529387657166], + [3.769833944988581, 0.9458803413992027], + [0.7788309720751803, 0.4545001718301003], + [3.2260518454370057, 1.1444166400824596], + [2.682269745885431, 1.3429529387657166], + [3.769833944988581, 0.9458803413992027], + ], + [ + [2.682269745885431, 1.3429529387657166], + [2.682269745885431, 1.3429529387657166], + [0.7788309720751803, 0.4545001718301003], + [0.7788309720751803, 0.4545001718301003], + [2.682269745885431, 1.3429529387657166], + [2.682269745885431, 1.3429529387657166], + [0.7788309720751803, 0.4545001718301003], + ], + [ + [3.769833944988581, 0.9458803413992027], + [0.7788309720751803, 0.4545001718301003], + [3.769833944988581, 0.9458803413992027], + [0.7788309720751803, 0.4545001718301003], + [3.769833944988581, 0.9458803413992027], + [0.7788309720751803, 0.4545001718301003], + [3.769833944988581, 0.9458803413992027], + ], + ] +) +EMPIRICAL_META_PAYOFFS_P1 = EMPIRICAL_META_PAYOFFS[:, :, 0] +EMPIRICAL_META_PAYOFFS_P2 = EMPIRICAL_META_PAYOFFS[:, :, 1] + + +def sweep_and_plot_alpha(u1_, u2_, plot=False): + nP = u1_.shape[0] * u1_.shape[1] + alpha_list = np.logspace(-4, 2, 100) + pi_list = np.zeros((0, nP)) + for alpha in alpha_list: + try: + _, _, pi, _, _ = alpharank.compute([u1_, u2_], alpha=alpha) + pi_list = np.vstack((pi_list, pi)) + print("pi", np.argmax(pi)) + except ValueError: + pass + + marginals = utils.get_alpharank_marginals([u1_, u2_], pi_list[-1]) + marginals = [marginals[0].round(2), marginals[1].round(2)] + print("marginals", marginals) + # pdb.set_trace() + if plot: + plt.plot(pi_list) + plt.show() + return pi_list[-1] + + +def sweep_and_plot_epsilon(u1_, u2_): + nP = u1_.shape[0] * u1_.shape[1] + eps_list = np.linspace(0, 1, 100) + pi_list = np.zeros((0, nP)) + for eps in eps_list: + try: + _, _, pi, _, _ = alpharank.compute( + [u1_, u2_], use_inf_alpha=True, inf_alpha_eps=eps + ) + pi_list = np.vstack((pi_list, pi)) + except ValueError: + pass + plt.plot(pi_list) + plt.show() + + +def exploitation_value(ui, u_mi, si, vi, normalize=True): + counterpart_payoffs = np.dot(u_mi, si) + exploiter = 
np.argmax(counterpart_payoffs) + exploitation_value = np.dot(si, ui)[exploiter] - vi + if normalize: + exploitation_value /= ui.max() - ui.min() + return exploitation_value + + +def is_dominated(u1, u2, v1, v2): + better_than_v1 = u1 > v1 + better_than_v2 = u2 > v2 + where_dominated = better_than_v1 * better_than_v2 + is_dominated_anyhere = where_dominated.sum() > 0 + return is_dominated_anyhere + + +def evaluate_game_and_profile(u1, u2, s1, s2): + v1_ = np.dot(s1, np.dot(u1, s2)) + v2_ = np.dot(s2, np.dot(u2, s1)) + exploitation_value_1 = exploitation_value(u1, u2, s1, v1_) + exploitation_value_2 = exploitation_value(u2, u1, s2, v2_) + is_dominated_ = is_dominated(u1, u2, v1_, v2_) + return is_dominated_, exploitation_value_1, exploitation_value_2 + + +def get_profile_from_meta_solver(u1, u2, meta_solver="alpharank"): + if meta_solver == "alpharank": + # joint_arank = sweep_and_plot_alpha(u1, u2) + joint_arank = alpharank.sweep_pi_vs_alpha([u1, u2]) + s1, s2 = utils.get_alpharank_marginals([u1, u2], joint_arank) + elif meta_solver == "rd": + nA1, nA2 = u1.shape + prd_initial_strategies = [ + np.random.dirichlet(alpha=np.ones(nA1) / nA1), + np.random.dirichlet(alpha=np.ones(nA2) / nA2), + ] + s1, s2 = projected_replicator_dynamics.projected_replicator_dynamics( + [u1, u2], + prd_gamma=0.0, + prd_initial_strategies=prd_initial_strategies, + ) + return s1, s2 + + +# def asymmetric_mcp_meta_solver_eval( +# upper_bound=10, +# grid_size=100, +# meta_solver="alpharank", +# game="bos", +# plot=False, +# ): +# k = 5 +# extra_points_labels = ["ks", "nash", "egal", "util", "meta"] +# meta_distances = np.zeros((grid_size, 4)) +# payoff_profile = np.zeros((grid_size, 5, 2)) +# asymmetries = np.linspace(0, upper_bound, grid_size) +# for rep, asymmetry in enumerate(asymmetries): +# if game == "bos": +# d1, d2 = 0.0, 0.0 +# base_payoff_1 = np.array( +# [[2 + asymmetry, 0.0], [0.0, 1.0 + expit(k * asymmetry) - 0.5]] +# ) +# base_payoff_2 = np.array( +# [[1.0, 0.0], [0.0, 2 - expit(asymmetry * k) + 0.5]] +# ) +# elif game == "chicken": +# d1, d2 = -5, -5 +# base_payoff_1 = np.array( +# [ +# [1.0, 0.0 + expit(asymmetry * k) - 0.5], +# [2 + asymmetry, -5.0], +# ] +# ) +# base_payoff_2 = np.array( +# [[1.0, 2 - expit(asymmetry * k) + 0.5], [0.0, -5.0]] +# ) +# +# s1, s2 = get_profile_from_meta_solver( +# base_payoff_1, base_payoff_2, meta_solver=meta_solver +# ) +# print(s1.round(2), s2.round(2)) +# u1_meta = np.dot(s1, np.dot(base_payoff_1, s2)) +# u2_meta = np.dot(s2, np.dot(base_payoff_2.T, s1)) +# +# ( +# (u1_ks, u2_ks), +# (u1_nash, u2_nash), +# (u1_egal, u2_egal), +# (u1_util, u2_util), +# ) = game_utils.optimize_welfare_discrete( +# base_payoff_1, base_payoff_2, d1, d2, restrict_to_equilibria=True +# ) +# extra_points = np.array( +# [ +# [u1_ks, u1_nash, u1_egal, u1_util, u1_meta], +# [u2_ks, u2_nash, u2_egal, u2_util, u2_meta], +# ] +# ) +# for i in range(5): +# # meta_distances[rep, i] = np.mean((extra_points[:, i] - extra_points[:, 3])**2) +# # meta_distances[rep, i] = np.allclose(extra_points[:, i], extra_points[:, 4], atol=0.2) +# payoff_profile[rep, i, :] = extra_points[:, i] + np.random.normal( +# scale=0.05, size=2 +# ) +# print(f"welfare opt payoffs:\n{extra_points.round(2)}") +# # if plot: +# # game_utils.create_figure(base_payoff_1, base_payoff_2, None, show_points=True, extra_points=extra_points, +# # extra_points_labels=extra_points_labels, fill=False) +# if plot: +# # plt.plot(meta_distances) +# for i in range(payoff_profile.shape[1]): +# # plt.scatter(payoff_profile[:, i, 0], 
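# Usage sketch of the evaluation helpers above (assumes it runs inside this
# module): solve a small random bimatrix game with the alpha-rank
# meta-solver, then check domination and exploitability of the profile.
import numpy as np

toy_u1 = np.random.random(size=(3, 3))
toy_u2 = np.random.random(size=(3, 3))
toy_s1, toy_s2 = get_profile_from_meta_solver(
    toy_u1, toy_u2, meta_solver="alpharank"
)
dominated, exploitation_1, exploitation_2 = evaluate_game_and_profile(
    toy_u1, toy_u2, toy_s1, toy_s2
)
print(dominated, exploitation_1, exploitation_2)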
payoff_profile[:, i, 1], label=extra_points_labels[i]) +# plt.plot( +# asymmetries, +# payoff_profile[:, i, 0], +# label=extra_points_labels[i], +# ) +# plt.xlabel("xi") +# plt.ylabel("player 1 payoff") +# plt.legend() +# plt.show() +# return + + +# def random_mcp_meta_solver_eval( +# n_rep=1, meta_solver="alpharank", game="bos", plot=False +# ): +# restrict_to_equilibria = True +# if game == "bos": +# d1, d2 = 0.0, 0.0 +# base_payoff_1 = np.array([[2.0, 0.0], [0.0, 1.0]]) +# base_payoff_2 = np.array([[1.0, 0.0], [0.0, 2.0]]) +# elif game == "chicken": +# d1, d2 = -5.0, -5.0 +# base_payoff_1 = np.array([[1.0, 0.0], [2.0, -5.0]]) +# base_payoff_2 = np.array([[1.0, 2.0], [0.0, -5.0]]) +# elif game == "bospd": +# d1, d2 = 0.0, 0.0 +# base_payoff_1 = np.array( +# [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] +# ) +# base_payoff_2 = np.array( +# [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]] +# ) +# elif game == "random": +# nA = 5 +# base_payoff_1 = np.zeros((nA, nA)) +# base_payoff_2 = np.zeros((nA, nA)) +# d1, d2 = None, None +# restrict_to_equilibria = False +# +# extra_points_labels = ["ks", "nash", "egal", "util", "meta"] +# meta_distances = np.zeros((0, len(extra_points_labels) - 1)) +# for rep in range(n_rep): +# perturbation_1 = np.random.normal(scale=0.5, size=2) +# perturbation_2 = np.random.normal(scale=0.5, size=2) +# +# perturbed_payoff_1 = copy.copy(base_payoff_1) +# perturbed_payoff_2 = copy.copy(base_payoff_2) +# +# if game in ["bos", "bospd"]: +# perturbed_payoff_1[[0, 1], [0, 1]] += perturbation_1 +# perturbed_payoff_2[[0, 1], [0, 1]] += perturbation_2 +# if game == "bospd": +# best_bos = np.max( +# ( +# perturbed_payoff_1[:2, :2].max(), +# perturbed_payoff_2[:2, :2].max(), +# ) +# ) +# worst_bos = np.min( +# ( +# perturbed_payoff_1[:2, :2].min(), +# perturbed_payoff_2[:2, :2].min(), +# ) +# ) +# mutual_defect_payoff = worst_bos - 1 +# # defection_payoff = best_bos + 1. +# # sucker_payoff = worst_bos - 2. 
+# # ToDo: these are assuming iterated bospd, where defections are grim triggered +# defection_payoff = mutual_defect_payoff +# sucker_payoff = mutual_defect_payoff +# perturbed_payoff_1[2, 2] = mutual_defect_payoff +# perturbed_payoff_2[2, 2] = mutual_defect_payoff +# perturbed_payoff_1[2, :2] = defection_payoff +# perturbed_payoff_2[:2, 2] = defection_payoff +# perturbed_payoff_2[2, :2] = sucker_payoff +# perturbed_payoff_1[:2, 2] = sucker_payoff +# +# elif game == "chicken": +# perturbed_payoff_1[[0, 1], [1, 0]] += perturbation_1 +# perturbed_payoff_2[[0, 1], [1, 0]] += perturbation_2 +# elif game == "random": +# perturbed_payoff_1 = base_payoff_1 + np.random.normal( +# size=(nA, nA) +# ) +# perturbed_payoff_2 = base_payoff_2 + np.random.normal( +# size=(nA, nA) +# ) +# multiple_equilibria = ( +# len( +# list( +# nash.Game( +# perturbed_payoff_1, perturbed_payoff_2 +# ).support_enumeration() +# ) +# ) +# > 1 +# ) +# # multiple_equilibria = True +# +# s1, s2 = get_profile_from_meta_solver( +# perturbed_payoff_1, perturbed_payoff_2, meta_solver=meta_solver +# ) +# print(s1.round(2), s2.round(2)) +# u1_meta = np.dot(s1, np.dot(perturbed_payoff_1, s2)) +# u2_meta = np.dot(s2, np.dot(perturbed_payoff_2.T, s1)) +# +# if game != "random" or (game == "random" and multiple_equilibria): +# ( +# (u1_ks, u2_ks), +# (u1_nash, u2_nash), +# (u1_egal, u2_egal), +# (u1_util, u2_util), +# ) = game_utils.optimize_welfare_discrete( +# perturbed_payoff_1, +# perturbed_payoff_2, +# d1, +# d2, +# restrict_to_equilibria=restrict_to_equilibria, +# ) +# extra_points = np.array( +# [ +# [u1_ks, u1_nash, u1_egal, u1_util, u1_meta], +# [u2_ks, u2_nash, u2_egal, u2_util, u2_meta], +# ] +# ) +# n_welfare = len(extra_points[0]) - 1 +# meta_distances_rep = np.zeros(n_welfare) +# for i in range(n_welfare): +# # meta_distances[rep, i] = np.mean((extra_points[:, i] - extra_points[:, 3])**2) +# meta_distances_rep[i] = np.allclose( +# extra_points[:, i], extra_points[:, n_welfare], atol=0.2 +# ) +# meta_distances = np.vstack((meta_distances, meta_distances_rep)) +# print(f"welfare opt payoffs:\n{extra_points.round(2)}") +# if plot: +# game_utils.create_figure( +# perturbed_payoff_1, +# perturbed_payoff_2, +# None, +# show_points=True, +# extra_points=extra_points, +# extra_points_labels=extra_points_labels, +# fill=False, +# ) +# print(meta_distances.mean(axis=0).round(2)) +# return + + +def random_meta_solver_eval(nA=3, n_rep=20, meta_solver="alpharank"): + is_dominated_list = np.zeros(n_rep) + exploitation_1_list = np.zeros(n_rep) + exploitation_2_list = np.zeros(n_rep) + + for rep in range(n_rep): + print(rep) + u1_ = np.random.random(size=(nA, nA)) + u2_ = np.random.random(size=(nA, nA)) + + s1, s2 = get_profile_from_meta_solver( + u1_, u2_, meta_solver=meta_solver + ) + + ( + is_dominated_, + exploitation_value_1, + exploitation_value_2, + ) = evaluate_game_and_profile(u1_, u2_, s1, s2) + is_dominated_list[rep] = is_dominated_ + exploitation_1_list[rep] = exploitation_value_1 + exploitation_2_list[rep] = exploitation_value_2 + + pct_dominated = is_dominated_list.mean() + exploitation_1_mean = exploitation_1_list.mean() + exploitation_2_mean = exploitation_2_list.mean() + + print(pct_dominated, exploitation_1_mean, exploitation_2_mean) + + return + + +if __name__ == "__main__": + # ToDo: do the same thing but for continuous barg meta games + # ToDo: incorporate utilitarian outcome + # sweep_and_plot_alpha(SMALL_META_PAYOUT_1, SMALL_META_PAYOUT_1, plot=True) + sweep_and_plot_alpha( + EMPIRICAL_META_PAYOFFS_P1, 
EMPIRICAL_META_PAYOFFS_P2, plot=True + ) + # random_mcp_meta_solver_eval(n_rep=20, game='bos', meta_solver='rd') + # asymmetric_mcp_meta_solver_eval(game='bos', grid_size=10, plot=True, meta_solver='rd') + # asymmetric_mcp_meta_solver_eval(game='chicken', plot=True) diff --git a/marltoolbox/scripts/analyse_psro_meta_policies.py b/marltoolbox/scripts/analyse_psro_meta_policies.py new file mode 100644 index 0000000..3e9da88 --- /dev/null +++ b/marltoolbox/scripts/analyse_psro_meta_policies.py @@ -0,0 +1,76 @@ +import json +import os + +import numpy as np +from marltoolbox.utils import path + + +def main(): + prefix, training_folders, n_players = _get_inputs() + replicates_folders_per_exp = _preprocess_inputs(prefix, training_folders) + + for replicates_folders, (_, n_base_policies) in zip( + replicates_folders_per_exp, + training_folders): + _print_stats_for_exp(replicates_folders, n_base_policies) + + +def _get_inputs(): + prefix = "~/dev-maxime/CLR/vm-data/" + training_folders = [ + ("instance-60-cpu-2-preemtible/PSRO_hardcoded/2021_05_26/11_49_58/", + 10) + ] + n_players = 2 + return prefix, training_folders, n_players + + +def _preprocess_inputs(prefix, training_folders): + replicates_folders_per_exp = [] + for training_folder, _ in training_folders: + training_folder_path = os.path.join(prefix, training_folder) + training_folder_path = os.path.expanduser(training_folder_path) + all_replicates = path.get_children_paths_wt_selecting_filter( + training_folder_path, _filter="PSROTrainer" + ) + print("all_replicates", len(all_replicates)) + replicates_folders_per_exp.append(all_replicates) + + return replicates_folders_per_exp + + +def _print_stats_for_exp(replicates_folders, n_base_policies): + + for replicates_folder in replicates_folders: + results = path.get_results_for_replicate(replicates_folder) + last_psro_iter_results = results[-1] + print("last_psro_ier_results", last_psro_iter_results) + player_0_meta_policy = last_psro_iter_results["player_0_meta_policy"] + player_1_meta_policy = last_psro_iter_results["player_1_meta_policy"] + for i in range(n_base_policies): + + # file_path = os.path.expanduser(file) + # with (open(file_path, "rb")) as f: + # file_content = json.load(f) + # file_content = _format_2nd_into_1st_format(file_content, file_data) + # values_per_replicat_per_player = np.array(file_content) + # + # assert values_per_replicat_per_player.ndim == 2 + # n_replicates_in_content = values_per_replicat_per_player.shape[0] + # n_players_in_content = values_per_replicat_per_player.shape[1] + # assert n_players_in_content == n_players + # + # values_per_replicat_per_player = ( + # values_per_replicat_per_player / file_data[1] + # ) + # + # mean_per_player = values_per_replicat_per_player.mean(axis=0) + # std_dev_per_player = values_per_replicat_per_player.std(axis=0) + # std_err_per_player = std_dev_per_player / np.sqrt( + # n_replicates_in_content + # ) + # return mean_per_player, std_dev_per_player, std_err_per_player + + +if __name__ == "__main__": + main() diff --git a/marltoolbox/scripts/average_saved_meta_game_results.py b/marltoolbox/scripts/average_saved_meta_game_results.py new file mode 100644 index 0000000..d3162d6 --- /dev/null +++ b/marltoolbox/scripts/average_saved_meta_game_results.py @@ -0,0 +1,202 @@ +import json +import os + +import numpy as np + + +def main(debug): + prefix, files_data, n_players = _get_inputs() + files_to_process = _preprocess_inputs(prefix, files_data) + + for file, file_data in zip(files_to_process, files_data): + ( + mean_per_player, + 
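# The per-base-policy loop in _print_stats_for_exp above is left unfinished;
# the helper below is one hypothetical way to aggregate the final meta
# policies across replicates, assuming player_0_meta_policy and
# player_1_meta_policy are lists of per-base-policy probabilities (an
# assumption, not something confirmed by this code).
import numpy as np


def mean_meta_policy(meta_policies_per_replicate, n_base_policies):
    stacked = np.asarray(meta_policies_per_replicate)
    assert stacked.shape[1] == n_base_policies
    # Mean and std of the probability put on each base policy.
    return stacked.mean(axis=0), stacked.std(axis=0)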
std_dev_per_player, + std_err_per_player, + coordination_success, + ) = _get_stats_for_file(file, n_players, file_data) + + print( + file_data[0], + "mean:", + mean_per_player, + "std_dev:", + std_dev_per_player, + "std_err:", + std_err_per_player, + "mean coordination_success:", + coordination_success, + ) + + +def _get_inputs(): + # Files on Maxime's local machine + # prefix = "~/dev-maxime/CLR/vm-data/" + # Files in unzipped folder + prefix = "./data/" + files_data = ( + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(PG) & BASE(LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_07/13_46_27/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(LOLA-Exact) & BASE(LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_07/12_32_57/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank pure strategies) & BASE(LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_07/12_05_00/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank mixed strategies) & BASE(LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_07/12_00_06/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(replicator dynamic) & BASE(LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_07/11_24_50/final_base_game/final_eval_in_base_game.json", + ), + # + # + # BELOW WT ANNOUNCEMENT + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(Uniform(announcement+tau=0)) & BASE(announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/LOLA_Exact/2021_05_14/13_50_55" + "/final_eval_in_base_game.json", + "2nd_format_placeholder", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank mixed on welfare sets) & BASE(announcement + " + "LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_37_24" + "/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank pure on welfare sets) & BASE(announcement + " + "LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_39_47/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(replicator dynamic random init on welfare sets) & BASE(" + "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_42_10/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(replicator dynamic default init on welfare sets) & BASE(" + "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_46_23/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(baseline random) & BASE(" "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_50_36" + "/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(PG) & BASE(" 
"announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/10_52_43/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(LOLA-Exact) & BASE(" "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/11_00_02" + "/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(SOS-Exact) & BASE(" "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_14/12_38_59" + "/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(Minimum) & BASE(" "announcement + LOLA-Exact)", + 200, + "instance-60-cpu-4-preemtible/meta_game_compare/" + "2021_05_27" + "/19_24_36/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ) + n_players = 2 + return prefix, files_data, n_players + + +def _preprocess_inputs(prefix, files_data): + files_to_process = [ + os.path.join(prefix, file_data[2]) for file_data in files_data + ] + return files_to_process + + +def _get_stats_for_file(file, n_players, file_data): + file_path = os.path.expanduser(file) + with (open(file_path, "rb")) as f: + file_content = json.load(f) + file_content = _format_2nd_into_1st_format(file_content, file_data) + if isinstance(file_content, dict): + coordination_success = file_content["mean_coordination_success"] + file_content = file_content["results"] + else: + coordination_success = ( + "N.A. => need to use the more recent " + "saves stored under date 2021_05_26" + ) + values_per_replicat_per_player = np.array(file_content) + + assert values_per_replicat_per_player.ndim == 2 + n_replicates_in_content = values_per_replicat_per_player.shape[0] + n_players_in_content = values_per_replicat_per_player.shape[1] + assert n_players_in_content == n_players + + values_per_replicat_per_player = ( + values_per_replicat_per_player / file_data[1] + ) + + mean_per_player = values_per_replicat_per_player.mean(axis=0) + std_dev_per_player = values_per_replicat_per_player.std(axis=0) + std_err_per_player = std_dev_per_player / np.sqrt( + n_replicates_in_content + ) + return ( + mean_per_player, + std_dev_per_player, + std_err_per_player, + coordination_success, + ) + + +def _format_2nd_into_1st_format(file_content, file_data): + if len(file_data) == 4: + file_content = file_content[0][2] + new_format = [] + for p1_content, p2_content in zip(file_content[0], file_content[1]): + new_format.append((p1_content, p2_content)) + file_content = new_format + return file_content + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/scripts/checking_cross_play_values.py b/marltoolbox/scripts/checking_cross_play_values.py new file mode 100644 index 0000000..0ed27b3 --- /dev/null +++ b/marltoolbox/scripts/checking_cross_play_values.py @@ -0,0 +1,148 @@ +import os, pickle +import numpy as np +import matplotlib.pyplot as plt + + +def _print_metrix_perf( + self_play_payoff_matrices_7x7, cross_play_payoff_matrices_7x7 +): + for ia_coeff in range(7): + print("ia_coeff player 0:", ia_coeff / 10) + pl0_self_play_mean = self_play_payoff_matrices_7x7[ + ia_coeff, :, 0, : + ].mean(axis=-1) + pl1_self_play_mean = self_play_payoff_matrices_7x7[ + ia_coeff, :, 1, : + ].mean(axis=-1) + # print("Player 0 self_play_mean", pl0_self_play_mean) + # print("Player 1 self_play_mean", 
pl1_self_play_mean) + + pl0_cross_play_mean = cross_play_payoff_matrices_7x7[ + ia_coeff, :, 0, : + ].mean(axis=-1) + pl1_cross_play_mean = cross_play_payoff_matrices_7x7[ + ia_coeff, :, 1, : + ].mean(axis=-1) + print("Player 0 cross_play_mean", pl0_cross_play_mean) + print("Player 1 cross_play_mean", pl1_cross_play_mean) + + +def _print_perf_accross_differences(cross_play_payoff_matrices_7x7): + cross_play_payoff_matrices_7x7 = cross_play_payoff_matrices_7x7.mean( + axis=-1 + ) + normalization_factor = 0.5 + cross_play_same_pref = ( + ( + np.diag(cross_play_payoff_matrices_7x7[..., 0]).mean() + + np.diag(cross_play_payoff_matrices_7x7[..., 1]).mean() + ) + / 2 + / normalization_factor + ) + print("cross_play_same_pref", cross_play_same_pref) + mat_len = cross_play_payoff_matrices_7x7.shape[0] + dia = np.diag_indices(mat_len) + dia_sum = np.sum(cross_play_payoff_matrices_7x7[dia]) + cross_play_same_pref_bis = ( + np.mean(cross_play_payoff_matrices_7x7[dia]) / normalization_factor + ) + off_dia_sum = np.sum(cross_play_payoff_matrices_7x7) - dia_sum + cross_play_diff_pref = ( + off_dia_sum / (mat_len * (mat_len - 1)) / 2 / normalization_factor + ) + + print("cross_play_same_pref_bis", cross_play_same_pref_bis) + print("cross_play_diff_pref", cross_play_diff_pref) + + # print("cross_play_payoff_matrices_7x7", cross_play_payoff_matrices_7x7) + # plt.plot(cross_play_payoff_matrices_7x7[..., 0]) + # plt.show() + + cross_play_wt_diff_pref_diff_range = [] + for i in range(mat_len - 1): + print("i", i + 1) + cross_play_diff_pref = ( + np.mean(cross_play_payoff_matrices_7x7[0, i + 1, :]) + / normalization_factor + ) + print("diff i") + print("cross_play_diff_pref", cross_play_diff_pref) + cross_play_wt_diff_pref_diff_range.append(cross_play_diff_pref) + plt.plot(cross_play_wt_diff_pref_diff_range) + plt.xlabel("IA coeff of the 2nd player") + plt.ylabel("Normalized scores with IA coeff 1st player = 0.0") + plt.show() + + dia_pos = list(dia) + dia_neg = list(dia) + cross_play_wt_diff_pref_diff_range = [] + for i in range(mat_len): + + pos_idx_kept = np.logical_and( + dia_pos[1] >= 0, dia_pos[1] <= mat_len - 1 + ) + filtered_dia_pos = dia_pos[0][pos_idx_kept], dia_pos[1][pos_idx_kept] + neg_idx_kept = np.logical_and( + dia_neg[1] >= 0, dia_neg[1] <= mat_len - 1 + ) + filtered_dia_neg = dia_neg[0][neg_idx_kept], dia_neg[1][neg_idx_kept] + print("i", i, "filtered_dia_pos", filtered_dia_pos) + print("i", i, "filtered_dia_neg", filtered_dia_neg) + cross_play_diff_pref = ( + ( + np.mean(cross_play_payoff_matrices_7x7[filtered_dia_pos]) + + np.mean(cross_play_payoff_matrices_7x7[filtered_dia_neg]) + ) + / 2 + / normalization_factor + ) + print("diff i") + print("cross_play_diff_pref", cross_play_diff_pref) + cross_play_wt_diff_pref_diff_range.append(cross_play_diff_pref) + + dia_pos[1] = dia_pos[1] + 1 + dia_neg[1] = dia_neg[1] - 1 + plt.plot(cross_play_wt_diff_pref_diff_range) + plt.xlabel("abs(IA coeff 1st player - IA coeff 2nd player)") + plt.ylabel("Normalized scores") + plt.show() + + +if __name__ == "__main__": + + data_prefix = ( + "/home/maxime/ssd_space/CLR/marltoolbox/marltoolbox/experiments" + "/tune_class_api/" + ) + + file_path = os.path.join( + data_prefix, "empirical_game_matrices_prosociality_coeff_0.3" + ) + with open(file_path, "rb") as f: + payoff_matrix_35x35 = pickle.load(f) + + self_play_payoff_matrices_7x7 = [] + cross_play_payoff_matrices_7x7 = [] + size = 7 + for i in range(5): + for j in range(5): + sub_mat = payoff_matrix_35x35[ + i * size : (i + 1) * size, j * size : (j + 1) * 
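# Sketch of the diagonal / off-diagonal split used in
# _print_perf_accross_differences above: diagonal entries pair two policies
# trained with the same IA coefficient, off-diagonal entries pair different
# coefficients (the game-specific normalization factor is omitted here).
import numpy as np


def same_vs_diff_pref_means(mean_payoff_matrix):
    # mean_payoff_matrix: (n_coeffs, n_coeffs) cross-play payoffs, already
    # averaged over players and replicates.
    n = mean_payoff_matrix.shape[0]
    diag = np.diag_indices(n)
    same_pref = mean_payoff_matrix[diag].mean()
    off_diag_sum = mean_payoff_matrix.sum() - mean_payoff_matrix[diag].sum()
    diff_pref = off_diag_sum / (n * (n - 1))
    return same_pref, diff_pref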
size, : + ] + if i == j: + self_play_payoff_matrices_7x7.append(sub_mat) + else: + cross_play_payoff_matrices_7x7.append(sub_mat) + + self_play_payoff_matrices_7x7 = np.stack( + self_play_payoff_matrices_7x7, axis=-1 + ) + cross_play_payoff_matrices_7x7 = np.stack( + cross_play_payoff_matrices_7x7, axis=-1 + ) + + _print_metrix_perf( + self_play_payoff_matrices_7x7, cross_play_payoff_matrices_7x7 + ) + _print_perf_accross_differences(cross_play_payoff_matrices_7x7) diff --git a/marltoolbox/scripts/negotiation_plot.py b/marltoolbox/scripts/negotiation_plot.py new file mode 100644 index 0000000..25064c3 --- /dev/null +++ b/marltoolbox/scripts/negotiation_plot.py @@ -0,0 +1,64 @@ +import json +import numpy as np +import matplotlib.pyplot as plt + +perf_path = ( + "/home/maxime/dev-maxime/CLR/vm-data/negotiation_meta_game_payoffs.json" +) +exploit_path = "/home/maxime/dev-maxime/CLR/vm-data/negotiation_meta_game_exploitability.json" + +with (open(perf_path, "rb")) as f: + perf_content = json.load(f) +print("perf_content", perf_content) + +with (open(exploit_path, "rb")) as f: + exploit_content = json.load(f) +print("exploit_content", exploit_content) + +exploit_order = [ + exploit_one_meta_solver[0]["meta_game"] + for exploit_one_meta_solver in exploit_content +] + +print("exploit_order", exploit_order) +exploit_order[5] = "PolicyGradient" + +perf_order = [ + "alpha rank mixed on welfare sets", + "alpha rank pure on welfare sets", + "replicator dynamic init on welfare sets", + "replicator dynamic random on welfare sets", + "baseline random", + "PG", + "LOLA-Exact", + "SOS-Exact", + "(announcement+tau=0)", +] + + +perf = [np.array(perf_content[key]).mean() for key in perf_order] + +exploit = [ + ( + +exploit_one_meta_solver[0]["pl_2_mean"] + + exploit_one_meta_solver[1]["pl_1_mean"] + ) + / 2 + for exploit_one_meta_solver in exploit_content +] + + +print("perf", perf) +print("exploit", exploit) + +fig, ax = plt.subplots() + +ax.scatter(perf, exploit) +plt.xlabel("mean payoff in cross play") +plt.ylabel("mean payoff while exploited") +plt.xlim((0.455, 0.466)) +plt.ylim((0.442, 0.45)) +for i, txt in enumerate(exploit_order): + ax.annotate(txt, (perf[i], exploit[i]), rotation=10 * i) + +plt.show() diff --git a/marltoolbox/scripts/plot_bar_chart_from_saved_results.py b/marltoolbox/scripts/plot_bar_chart_from_saved_results.py new file mode 100644 index 0000000..c3a920f --- /dev/null +++ b/marltoolbox/scripts/plot_bar_chart_from_saved_results.py @@ -0,0 +1,1308 @@ +import json +import os +from collections import namedtuple + +import matplotlib.pyplot as plt +import numpy as np + +plt.switch_backend("agg") +plt.style.use("seaborn-whitegrid") +plt.rcParams.update({"font.size": 12}) + +COLORS = [ + "#377eb8", + "#ff7f00", + "#4daf4a", + "#f781bf", + "#a65628", + "#984ea3", + "#999999", + "#e41a1c", + "#dede00", +] + +Exp_data = namedtuple("Exp_data", ["base_algo", "env", "perf"]) +Perf = namedtuple("Perf", ["mean", "std_dev", "std_err", "raw"]) +File_data = namedtuple( + "File_data", + [ + "base_algo", + "env", + "reward_adaptation_divider", + "path_to_self_play", + "path_to_preferences", + "max_r_by_players", + "min_r_by_players", + "welfare_optim", + "welfare_functions", + ], +) +Final_values = namedtuple( + "Final_values", + [ + "base_algo", + "env", + "self_play", + "cross_play", + "cross_play_same", + "cross_play_diff", + ], +) +NA = "N/A" + +PLAYER_0 = 0 +PLAYER_1 = 1 + +VALUE_TERM = "value" +SPLIT_NEGO = False +REMOVE_STARS = True + +USING_METRIC = 1 +# USING_METRIC = 2 +# USING_METRIC = 3 +# 
USING_METRIC = 4 +ABCG_USE_MEAN_DISAGREEMENT = None +# metric (1) Normalize(mean of ideals given welfare functions) +# NORMALIZE_USE_EMPIRICAL_MAX = False & USE_DISTANCE_TO_WELFARE_OPTIM = False +# metric (2) Normalize(empirical max and ideal min) +# NORMALIZE_USE_EMPIRICAL_MAX = True & USE_DISTANCE_TO_WELFARE_OPTIM = False +# metric (3) Min distance to the welfare-optimal profiles +# NORMALIZE_USE_EMPIRICAL_MAX = False & USE_DISTANCE_TO_WELFARE_OPTIM = True +if USING_METRIC == 1: + NORMALIZE_USE_EMPIRICAL_MAX = False + USE_DISTANCE_TO_WELFARE_OPTIM = False +elif USING_METRIC == 2: + NORMALIZE_USE_EMPIRICAL_MAX = True + USE_DISTANCE_TO_WELFARE_OPTIM = False +elif USING_METRIC == 3: + NORMALIZE_USE_EMPIRICAL_MAX = False + USE_DISTANCE_TO_WELFARE_OPTIM = True + ABCG_USE_MEAN_DISAGREEMENT = False +elif USING_METRIC == 4: + NORMALIZE_USE_EMPIRICAL_MAX = False + USE_DISTANCE_TO_WELFARE_OPTIM = False + +if USE_DISTANCE_TO_WELFARE_OPTIM: + NORMALIZED_NAME = f"Distance" + assert not NORMALIZE_USE_EMPIRICAL_MAX +else: + NORMALIZED_NAME = f"Normalized {VALUE_TERM}" + +LEGEND = [ + "Self-play", + "Cross-play, identical welfare functions", + "Cross-play, different welfare functions", +] + +LEGEND_NO_SPLIT = [ + "Self-play", + "Cross-play", +] + +GLOBAL_CROSS_PLAY_IN_LOLA = False +if GLOBAL_CROSS_PLAY_IN_LOLA: + LEGEND = [ + "Self-play", + "Cross-play", + "Cross-play, identical welfare functions", + "Cross-play, different welfare functions", + ] +# LOLA_EXACT_WT_IPD_IDX = 1 +N_NO_MCP = 4 +NEGOTIATION_RATIO = 0.66 + +CG_N_STEPS = 100.0 +ABCG_N_STEPS = 100.0 +if USE_DISTANCE_TO_WELFARE_OPTIM: + IPD_MAX = None + IPD_MIN = None + ASYMIBOS_MAX = None + ASYMIBOS_MIN = None + CG_MAX = None + CG_MIN = None + ABCG_MAX = None + ABCG_MIN = None + EMPIRICAL_WELFARE_OPTIMUM_CG = ( + (34.9875 / CG_N_STEPS, 34.89375 / CG_N_STEPS), + (34.14375 / CG_N_STEPS, 34.05625 / CG_N_STEPS), + ) + EMPIRICAL_WELFARE_OPTIMUM_ABCG = ( + (14.9640625 / ABCG_N_STEPS, 112.25625 / ABCG_N_STEPS), + (34.41875 / ABCG_N_STEPS, 34.60625 / ABCG_N_STEPS), + ) + ABCG_MEAN_DISAGREEMENT = ( + (17.886574074074073 + 2.6090686274509802) / (2 * ABCG_N_STEPS), + (18.39814814814815 + 11.79656862745098) / (2 * ABCG_N_STEPS), + ) +else: + IPD_MAX = (-1, -1) + IPD_MIN = (-3, -3) + ASYMIBOS_MAX = ((4 + 2) / 2.0, (2 + 1) / 2.0) + ASYMIBOS_MIN = (0, 0) + N_CELLS_AT_1_STEP = 4 + N_CELLS_AT_2_STEPS = 4 + N_CELLS_EXCLUDING_CURRENT = 8 + MAX_PICK_SPEED = ( + N_CELLS_AT_1_STEP / N_CELLS_EXCLUDING_CURRENT / 1 + + N_CELLS_AT_2_STEPS / N_CELLS_EXCLUDING_CURRENT / 2 + ) # 0.75 + CG_MAX = (1.0 * MAX_PICK_SPEED / 2.0, 1.0 * MAX_PICK_SPEED / 2.0) + CG_MIN = (0, 0) + ABCG_MAX = ( + (2 / 2.0 + 1 / 2.0) * MAX_PICK_SPEED / 2.0, + (3 / 2.0 + 1 / 2.0) * MAX_PICK_SPEED / 2.0, + ) + ABCG_MIN = (0, 0) + EMPIRICAL_WELFARE_OPTIMUM_CG = None + EMPIRICAL_WELFARE_OPTIMUM_ABCG = None + ABCG_USE_MEAN_DISAGREEMENT = None + ABCG_MEAN_DISAGREEMENT = None + +UTILITARIAN_W = lambda xy: xy.sum(axis=1) +EGALITARIAN_W = lambda xy: xy.min(axis=1) +NASH_W = lambda xy: xy[:, 0] * xy[:, 1] + +if USING_METRIC == 1: + # (Mean, Std_err) + negotiation_self_play = ( + 0.4567 / NEGOTIATION_RATIO, + 0.0003 / NEGOTIATION_RATIO, + ) + negotiation_cross_play = ( + 0.4272 / NEGOTIATION_RATIO, + 0.0002 / NEGOTIATION_RATIO, + ) + negotiation_same_play = ( + 0.4526 / NEGOTIATION_RATIO, + 0.0003 / NEGOTIATION_RATIO, + ) + negotiation_diff_play = ( + 0.4017 / NEGOTIATION_RATIO, + 0.0004 / NEGOTIATION_RATIO, + ) +elif USING_METRIC == 2: + # (Mean, Std_err) + raise ValueError("not value for NEGOTIATION with mertic 
2") +elif USING_METRIC == 3: + # (Mean, Std_err) + negotiation_self_play = (0.4014471264303261, 0.00041116296647391) + negotiation_cross_play = (0.4673496367353227, 0.00030417354727912266) + negotiation_same_play = (0.4041803573397903, 0.0004165843958125686) + negotiation_diff_play = (0.5305189161308548, 0.000424625506684431) +elif USING_METRIC == 4: + # (Mean, Std_err) + # TODO fill Negotiation values for metric 4 here + # negotiation_self_play = (..., ...) + # negotiation_cross_play = (..., ...) + # negotiation_same_play = (..., ...) + # negotiation_diff_play = (..., ...) + pass +else: + raise ValueError() + + +def main(debug): + prefix, files_data, n_players = _get_inputs() + files_to_process = _preprocess_inputs(prefix, files_data) + + perf_per_mode_per_files = [] + for file_paths, file_data in zip(files_to_process, files_data): + perf_per_mode = _get_stats(file_paths, n_players, file_data) + perf_per_mode_per_files.append( + Exp_data(file_data.base_algo, file_data.env, perf_per_mode) + ) + + # Plot with 1 subplot + # _plot_bars(perf_per_mode_per_files, welfare_split=True) + # _plot_bars(perf_per_mode_per_files, welfare_split=False) + + # Plot with 2 subplots + _plot_bars_separate(perf_per_mode_per_files, welfare_split=True) + _plot_bars_separate(perf_per_mode_per_files, welfare_split=False) + + +def _get_inputs(): + # Files on Maxime's local machine + prefix = "~/dev-maxime/CLR/vm-data/" + # Files in unzipped folder + # prefix = "./data/" + files_data = ( + File_data( + "amTFT", + "IPD", + 20.0, + "instance-60-cpu-1-preemtible" + "/amTFT/2021_05_11/07_31_41/eval/2021_05_11/09_20_44" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + "instance-60-cpu-1-preemtible" + "/amTFT/2021_05_11/07_31_41/eval/2021_05_11/09_20_44" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + IPD_MAX, + IPD_MIN, + ( + ((-1.0, -1.0),), + (-3.0, -3.0), + ), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "LOLA-Exact *", + "IPD", + 200.0, + "instance-60-cpu-1-preemtible" + "/LOLA_Exact/2021_05_11/07_46_03/eval/2021_05_11" + "/07_49_14" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + None, + IPD_MAX, + IPD_MIN, + ( + ((-1.0, -1.0),), + (-3.0, -3.0), + ), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "amTFT", + "CG", + CG_N_STEPS, + "instance-20-cpu-1-memory-x2" + "/amTFT/2021_05_15/07_16_37/eval/2021_05_16/08_13_25" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + "instance-20-cpu-1-memory-x2" + "/amTFT/2021_05_15/07_16_37/eval/2021_05_16/08_13_25" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + CG_MAX, + CG_MIN, + (EMPIRICAL_WELFARE_OPTIMUM_CG, (0.0, 0.0)), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "LOLA-PG *", + "CG", + 40.0, + "instance-20-cpu-1-memory-x2" + "/LOLA_PG/2021_05_19/12_18_06/eval/2021_05_20/10_27_09" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + None, + CG_MAX, + CG_MIN, + (EMPIRICAL_WELFARE_OPTIMUM_CG, (0.0, 0.0)), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "amTFT", + "IAsymBoS", + 20.0, + "instance-60-cpu-1-preemtible" + "/amTFT/2021_05_11/07_40_04/eval/2021_05_11/11_43_26" + 
"/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + "instance-60-cpu-1-preemtible" + "/amTFT/2021_05_11/07_40_04/eval/2021_05_11/11_43_26" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + ASYMIBOS_MAX, + ASYMIBOS_MIN, + (((4.0, 1.0), (2.0, 2.0)), (0.0, 0.0)), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "LOLA-Exact", + "IAsymBoS", + 200.0, + "instance-60-cpu-1-preemtible" + "/LOLA_Exact/2021_05_11/07_47_16/eval/2021_05_11" + "/07_50_36" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + "instance-60-cpu-1-preemtible" + "/LOLA_Exact/2021_05_11/07_47_16/eval/2021_05_11/07_50_36" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_col_vs_policy_reward_mean_player_row_matrix.json", + ASYMIBOS_MAX, + ASYMIBOS_MIN, + (((4.0, 1.0), (3.0, 1.5), (2.0, 2.0)), (0.0, 0.0)), + (UTILITARIAN_W, EGALITARIAN_W, NASH_W), + ), + File_data( + "amTFT", + "ABCG", + ABCG_N_STEPS, + "instance-10-cpu-2" + "/amTFT/2021_05_17/18_08_40/eval/2021_05_20/04_51_12" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + "instance-10-cpu-2" + "/amTFT/2021_05_17/18_08_40/eval/2021_05_20/04_51_12" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + ABCG_MAX, + ABCG_MIN, + ( + EMPIRICAL_WELFARE_OPTIMUM_ABCG, + ABCG_MEAN_DISAGREEMENT + if ABCG_USE_MEAN_DISAGREEMENT + else (0.0, 0.0), + ), + (UTILITARIAN_W, EGALITARIAN_W), + ), + File_data( + "LOLA-PG **", + "ABCG", + 40.0, + "instance-60-cpu-2-preemtible" + "/LOLA_PG/2021_05_19/08_17_37/eval/2021_05_19/18_02_21" + "/plot_self_crossself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + "instance-60-cpu-2-preemtible" + "/LOLA_PG/2021_05_19/08_17_37/eval/2021_05_19/18_02_21" + "/plot_same_and_diff_prefself_and_cross_play_policy_reward_mean_player_blue_vs_policy_reward_mean_player_red_matrix.json", + ABCG_MAX, + ABCG_MIN, + ( + EMPIRICAL_WELFARE_OPTIMUM_ABCG, + ABCG_MEAN_DISAGREEMENT + if ABCG_USE_MEAN_DISAGREEMENT + else (0.0, 0.0), + ), + (UTILITARIAN_W, EGALITARIAN_W), + ), + Final_values( + "Negotiation", + "REINFORCE-PS", + # "self-play", + (0.4567 / NEGOTIATION_RATIO, 0.0003 / NEGOTIATION_RATIO) + if not USE_DISTANCE_TO_WELFARE_OPTIM + else (0.4014471264303261, 0.00041116296647391), + # "cross-play", + (0.4272 / NEGOTIATION_RATIO, 0.0002 / NEGOTIATION_RATIO) + if not USE_DISTANCE_TO_WELFARE_OPTIM + else (0.4673496367353227, 0.00030417354727912266), + # "cross-play same", + (0.4526 / NEGOTIATION_RATIO, 0.0003 / NEGOTIATION_RATIO) + if not USE_DISTANCE_TO_WELFARE_OPTIM + else (0.4041803573397903, 0.0004165843958125686), + # "cross-play diff", + (0.4017 / NEGOTIATION_RATIO, 0.0004 / NEGOTIATION_RATIO) + if not USE_DISTANCE_TO_WELFARE_OPTIM + else (0.5305189161308548, 0.000424625506684431), + ), + ) + n_players = 2 + return prefix, files_data, n_players + + +def _preprocess_inputs(prefix, files_data): + files_to_process = [] + for file_data in files_data: + if isinstance(file_data, Final_values): + value = file_data + elif file_data.path_to_preferences is not None: + value = ( + os.path.join(prefix, file_data.path_to_self_play), + os.path.join(prefix, file_data.path_to_preferences), + ) + else: + value = ( + os.path.join(prefix, file_data.path_to_self_play), + 
None, + ) + files_to_process.append(value) + + return files_to_process + + +def _get_stats(file_paths, n_players, file_data): + if isinstance(file_paths, Final_values): + all_perf = file_paths + else: + self_play_path = file_paths[0] + perf_per_mode = _get_stats_for_file( + self_play_path, n_players, file_data + ) + self_play = perf_per_mode["self-play"] + cross_play = perf_per_mode["cross-play"] + + preference_path = file_paths[1] + if preference_path is not None: + perf_per_mode_bis = _get_stats_for_file( + preference_path, n_players, file_data + ) + same_preferences_cross_play = perf_per_mode_bis[ + "cross-play: same pref vs same pref" + ] + if ( + "cross-play: diff pref vs diff pref" + in perf_per_mode_bis.keys() + ): + diff_preferences_cross_play = perf_per_mode_bis[ + "cross-play: diff pref vs diff pref" + ] + else: + diff_preferences_cross_play = NA + else: + same_preferences_cross_play = NA + diff_preferences_cross_play = NA + + if NORMALIZE_USE_EMPIRICAL_MAX: + ( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ) = _empirical_normalization( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + file_data, + ) + + if USING_METRIC == 4: + ( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ) = _convert_to_metric_4( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + file_data, + ) + + all_perf = [ + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ] + return all_perf + + +def _convert_to_metric_4( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + file_data, +): + all_values = [] + self_play_pair = np.stack( + [self_play[PLAYER_0].raw, self_play[PLAYER_1].raw], axis=1 + ) + all_values.append(self_play_pair) + cross_play_pair = np.stack( + [cross_play[PLAYER_0].raw, cross_play[PLAYER_1].raw], axis=1 + ) + all_values.append(cross_play_pair) + print("self_play_pair", self_play_pair.shape) + print("cross_play_pair", cross_play_pair.shape) + if not isinstance(same_preferences_cross_play, str): + same_play_pair = np.stack( + [ + same_preferences_cross_play[PLAYER_0].raw, + same_preferences_cross_play[PLAYER_1].raw, + ], + axis=1, + ) + all_values.append(same_play_pair) + print("same_play_pair", same_play_pair.shape) + if not isinstance(diff_preferences_cross_play, str): + diff_play_pair = np.stack( + [ + diff_preferences_cross_play[PLAYER_0].raw, + diff_preferences_cross_play[PLAYER_1].raw, + ], + axis=1, + ) + all_values.append(diff_play_pair) + print("diff_play_pair", diff_play_pair.shape) + + all_values = np.concatenate(all_values, axis=0) + print("all_values", all_values.shape) + + for k in [PLAYER_0, PLAYER_1]: + self_play[k] = Perf( + self_play[k].mean, + self_play[k].std_dev, + self_play[k].std_err, + _compute_metric_4( + self_play_pair, + all_values, + file_data, + ), + ) + cross_play[k] = Perf( + cross_play[k].mean, + cross_play[k].std_dev, + cross_play[k].std_err, + _compute_metric_4( + cross_play_pair, + all_values, + file_data, + ), + ) + if not isinstance(same_preferences_cross_play, str): + same_preferences_cross_play[k] = Perf( + same_preferences_cross_play[k].mean, + same_preferences_cross_play[k].std_dev, + same_preferences_cross_play[k].std_err, + _compute_metric_4( + same_play_pair, + all_values, + file_data, + ), + ) + if not isinstance(diff_preferences_cross_play, str): + diff_preferences_cross_play[k] = Perf( + 
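# Worked example of metric (4) as implemented in _compute_metric_4 below,
# for IAsymBoS with disagreement point (0, 0); the set of "observed" payoff
# pairs is a toy set for illustration, not taken from the actual results
# (assumes the module-level UTILITARIAN_W / EGALITARIAN_W lambdas).
import numpy as np

observed = np.array([[4.0, 1.0], [2.0, 2.0]])  # all payoff pairs seen
outcome = np.array([[3.0, 1.5]])  # the profile being scored
disagreement = np.array([[0.0, 0.0]])

for name, welfare in (
    ("utilitarian", UTILITARIAN_W),
    ("egalitarian", EGALITARIAN_W),
):
    numer = welfare(outcome) - welfare(disagreement)
    denom = welfare(observed).max() - welfare(disagreement)
    print(name, (numer / denom)[0])  # utilitarian 0.9, egalitarian 0.75
# Metric (4) keeps the best of the two welfares, i.e. 0.9 for this outcome.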
diff_preferences_cross_play[k].mean, + diff_preferences_cross_play[k].std_dev, + diff_preferences_cross_play[k].std_err, + _compute_metric_4( + diff_play_pair, + all_values, + file_data, + ), + ) + return ( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ) + + +def _empirical_normalization( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + file_data, +): + for k in [PLAYER_0, PLAYER_1]: + possible_max = [] + possible_max.extend(self_play[k].raw) + possible_max.extend(cross_play[k].raw) + if not isinstance(same_preferences_cross_play, str): + possible_max.extend(same_preferences_cross_play[k].raw) + if not isinstance(diff_preferences_cross_play, str): + possible_max.extend(diff_preferences_cross_play[k].raw) + max_ = max(possible_max) + min_ = file_data.min_r_by_players[k] + self_play[k] = Perf( + self_play[k].mean, + self_play[k].std_dev, + self_play[k].std_err, + _normalize(self_play[k].raw, max_, min_), + ) + cross_play[k] = Perf( + cross_play[k].mean, + cross_play[k].std_dev, + cross_play[k].std_err, + _normalize(cross_play[k].raw, max_, min_), + ) + if not isinstance(same_preferences_cross_play, str): + same_preferences_cross_play[k] = Perf( + same_preferences_cross_play[k].mean, + same_preferences_cross_play[k].std_dev, + same_preferences_cross_play[k].std_err, + _normalize(same_preferences_cross_play[k].raw, max_, min_), + ) + if not isinstance(diff_preferences_cross_play, str): + diff_preferences_cross_play[k] = Perf( + diff_preferences_cross_play[k].mean, + diff_preferences_cross_play[k].std_dev, + diff_preferences_cross_play[k].std_err, + _normalize(diff_preferences_cross_play[k].raw, max_, min_), + ) + return ( + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ) + + +def _normalize(values, max_, min_): + values = values - min_ + values = values / (max_ - min_) + return values + + +def _get_stats_for_file(file, n_players, file_data): + perf_per_mode = {} + file_path = os.path.expanduser(file) + with (open(file_path, "rb")) as f: + file_content = json.load(f) + possible_max = [] + for eval_mode, mode_perf in file_content.items(): + perf = [None] * 2 + print("eval_mode", eval_mode) + for metric, metric_perf in mode_perf.items(): + player_idx = _extract_player_idx(metric) + + perf_per_replicat = np.array( + _convert_str_of_list_to_list(metric_perf["raw_data"]) + ) + + n_replicates_in_content = len(perf_per_replicat) + values_per_replicat_per_player = _adapt_values( + perf_per_replicat, file_data, player_idx + ) + + mean_per_player = values_per_replicat_per_player.mean(axis=0) + std_dev_per_player = values_per_replicat_per_player.std(axis=0) + std_err_per_player = std_dev_per_player / np.sqrt( + n_replicates_in_content + ) + perf[player_idx] = Perf( + mean_per_player, + std_dev_per_player, + std_err_per_player, + values_per_replicat_per_player, + ) + + if USE_DISTANCE_TO_WELFARE_OPTIM: + metric = _distance_to_welfare_optimal_profiles( + np.stack([perf[PLAYER_0].raw, perf[PLAYER_1].raw], axis=1), + file_data, + ) + + perf[0] = Perf( + perf[0].mean, + perf[0].std_dev, + perf[0].std_err, + metric, + ) + perf[1] = Perf( + perf[1].mean, + perf[1].std_dev, + perf[1].std_err, + metric, + ) + + perf_per_mode[eval_mode] = perf + + return perf_per_mode + + +def _extract_player_idx(metric): + if "player_row" in metric: + player_idx = PLAYER_0 + elif "player_col" in metric: + player_idx = PLAYER_1 + elif "player_red" in metric: + player_idx = PLAYER_0 + elif "player_blue" 
in metric: + player_idx = PLAYER_1 + else: + raise ValueError() + return player_idx + + +def _adapt_values(values_per_replicat_per_player, file_data, player_idx): + scaled_values = ( + values_per_replicat_per_player / file_data.reward_adaptation_divider + ) + + if not USE_DISTANCE_TO_WELFARE_OPTIM: + if not NORMALIZE_USE_EMPIRICAL_MAX: + if USING_METRIC != 4: + assert USING_METRIC == 1 + normalized_values = _normalize( + scaled_values, + file_data.max_r_by_players[player_idx], + file_data.min_r_by_players[player_idx], + ) + return normalized_values + return scaled_values + + +def _distance_to_welfare_optimal_profiles(scaled_values, file_data): + welfare_optimal_profiles = file_data.welfare_optim[0] + disagrement_profile = file_data.welfare_optim[1] + disagreement_to_optim_wel_prof = [] + for one_welfare_optimal_profile in welfare_optimal_profiles: + # print("one_welfare_optimal_profile", one_welfare_optimal_profile) + dist_disagr = np.linalg.norm( + np.array(disagrement_profile) + - np.array(one_welfare_optimal_profile) + ) + # print("dist_disagr", dist_disagr) + disagreement_to_optim_wel_prof.append(dist_disagr) + disagreement_to_optim_wel_prof = min(disagreement_to_optim_wel_prof) + + # "numerator = min{ ||(u1, u2) - (4, 1)||, ||(u1, u2) - (3, 1.5)||, ||(u1, u2) - (2, 2)|| } + # "denominator = min{ ||(0, 0) - (4, 1)||, ||(0, 0) - (3, 1.5)||, ||(0, 0) - (2, 2)|| }" + + scaled_values = np.array(scaled_values) + print("scaled_values.shape", scaled_values.shape) + values_to_optim_wel_prof = [] + for one_welfare_optimal_profile in welfare_optimal_profiles: + one_welfare_optimal_profile = np.array(one_welfare_optimal_profile) + one_welfare_optimal_profile = np.stack( + [one_welfare_optimal_profile] * len(scaled_values), axis=0 + ) + + print("one_welfare_optimal_profile", one_welfare_optimal_profile.shape) + dist_ = np.linalg.norm( + np.array(scaled_values) - np.array(one_welfare_optimal_profile), + axis=1, + ) + print("dist_", dist_.shape) + values_to_optim_wel_prof.append(dist_) + values_to_optim_wel_prof = np.stack(values_to_optim_wel_prof, axis=1) + print("values_to_optim_wel_prof.shape", values_to_optim_wel_prof.shape) + values_to_optim_wel_prof = values_to_optim_wel_prof.min(axis=1) + print( + "reduced values_to_optim_wel_prof.shape", + values_to_optim_wel_prof.shape, + ) + + scaled_distance = values_to_optim_wel_prof / disagreement_to_optim_wel_prof + return scaled_distance + + +def _compute_metric_4( + payoffs_for_pi, + payoffs_for_all_pi, + file_data, +): + # max_w { [ w(outcome) - w(disagreement) ] + # / [ max_\pi w(\pi) - w(disagreement)] } + + metrics_for_all_welfares = [] + for welfare_fn in file_data.welfare_functions: + welfares = welfare_fn(payoffs_for_pi) + welfares_all_pi = welfare_fn(payoffs_for_all_pi) + disagreement_profile = file_data.welfare_optim[1] + disagreement_welfare = welfare_fn(np.array([disagreement_profile])) + max_welfare = welfares_all_pi.max() + metric_for_one_welfare = (welfares - disagreement_welfare) / ( + max_welfare - disagreement_welfare + ) + + metrics_for_all_welfares.append(metric_for_one_welfare) + metrics_for_all_welfares = np.stack(metrics_for_all_welfares, axis=1) + print("metrics_for_all_welfares.shape", metrics_for_all_welfares.shape) + metrics = metrics_for_all_welfares.max(axis=1) + return metrics + + +def _convert_str_of_list_to_list(str_of_list): + return [ + float(v) + for v in str_of_list.replace("[", "") + .replace("]", "") + .replace(" ", "") + .split(",") + ] + + +def _plot_bars(perf_per_mode_per_files, welfare_split): + 
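# Worked example of the distance metric (3) implemented in
# _distance_to_welfare_optimal_profiles above, using the IAsymBoS
# welfare-optimal profiles (4, 1), (3, 1.5), (2, 2) and the disagreement
# point (0, 0) listed in _get_inputs.
import numpy as np

optima = np.array([[4.0, 1.0], [3.0, 1.5], [2.0, 2.0]])
disagreement = np.array([0.0, 0.0])
outcome = np.array([2.0, 2.0])  # e.g. a perfectly egalitarian profile

numerator = np.linalg.norm(outcome - optima, axis=1).min()  # 0.0
denominator = np.linalg.norm(disagreement - optima, axis=1).min()  # ~2.83
scaled_distance = numerator / denominator  # 0.0, i.e. exactly welfare-optimal
print(scaled_distance)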
plt.figure(figsize=(10, 3)) + + legend, x, groups = _plot_merged_players( + perf_per_mode_per_files, plot_all=True, welfare_split=welfare_split + ) + plt.xticks(x, groups, rotation=15) + plt.ylabel(NORMALIZED_NAME) + plt.ylim((0.0, 1.0)) + + plt.legend( + legend, + frameon=True, + bbox_to_anchor=(1.0, -0.23), + ) + + # Save the figure and show + plt.tight_layout(rect=[0, -0.05, 1.0, 1.0]) + plt.savefig(f"bar_plot_vanilla_split_{welfare_split}.png") + + +def _plot_bars_separate(perf_per_mode_per_files, welfare_split): + plt.figure(figsize=(10, 3)) + rotation = 15 + plt.subplot(121) + _, x, groups = _plot_merged_players( + perf_per_mode_per_files, mcp=False, welfare_split=welfare_split + ) + plt.xticks(x, groups, rotation=rotation, ha="right") + plt.ylabel(NORMALIZED_NAME) + # plt.ylim((0.0, 1.0)) + + plt.subplot(122) + legend, x, groups = _plot_merged_players( + perf_per_mode_per_files, mcp=True, welfare_split=welfare_split + ) + plt.xticks(x, groups, rotation=rotation, ha="right") + # plt.ylabel("Normalized scores") + # plt.ylim((0.0, 1.0)) + + if USE_DISTANCE_TO_WELFARE_OPTIM: + if welfare_split: + plt.tight_layout() + else: + plt.tight_layout() + plt.legend( + legend, + frameon=True, + bbox_to_anchor=(0.0, 0.00, -0.4, 1.0), + ) + else: + if welfare_split: + if GLOBAL_CROSS_PLAY_IN_LOLA: + plt.tight_layout(rect=[0, 0.20, 1.0, 1.0]) + plt.legend( + legend, + frameon=True, + bbox_to_anchor=(-0.4, -0.35), + ) + else: + plt.tight_layout(rect=[0, 0.14, 1.0, 1.0]) + plt.legend( + legend, + frameon=True, + bbox_to_anchor=(-0.4, -0.35), + ) + else: + plt.tight_layout(rect=[0, 0.07, 1.0, 1.0]) + plt.legend( + legend, + frameon=True, + bbox_to_anchor=(-0.7, -0.30), + ) + + # Save the figure and show + # if not USE_DISTANCE_TO_WELFARE_OPTIM: + # if welfare_split: + # plt.tight_layout(rect=[0, -0.07, 1.0, 1.0]) + # else: + # plt.tight_layout(rect=[0, -0.05, 1.0, 1.0]) + plt.savefig(f"bar_plot_separated_split_{welfare_split}.png") + + +def _plot_merged_players( + perf_per_mode_per_files, + mcp: bool = None, + plot_all=False, + welfare_split=True, +): + all_perf = [el.perf for el in perf_per_mode_per_files] + groups = [f"{el.env} + {el.base_algo}" for el in perf_per_mode_per_files] + groups = [group.strip(" + ") for group in groups] + width = 0.1 + + ( + self_play, + cross_play, + same_pref_perf, + diff_pref_perf, + self_play_err, + cross_play_err, + same_pref_perf_err, + diff_pref_perf_err, + ) = _preproces_values(all_perf) + + if plot_all: + plt.text(1.08, 0.04, NA, fontdict={"fontsize": 10.0, "rotation": 90}) + plt.text(3.08, 0.04, NA, fontdict={"fontsize": 10.0, "rotation": 90}) + else: + if not mcp: + if welfare_split: + plt.text( + 1.08, 0.05, NA, fontdict={"fontsize": 10.0, "rotation": 90} + ) + plt.text( + 3.08, 0.05, NA, fontdict={"fontsize": 10.0, "rotation": 90} + ) + self_play = self_play[:N_NO_MCP] + cross_play = cross_play[:N_NO_MCP] + same_pref_perf = same_pref_perf[:N_NO_MCP] + diff_pref_perf = diff_pref_perf[:N_NO_MCP] + self_play_err = self_play_err[:N_NO_MCP] + cross_play_err = cross_play_err[:N_NO_MCP] + same_pref_perf_err = same_pref_perf_err[:N_NO_MCP] + diff_pref_perf_err = diff_pref_perf_err[:N_NO_MCP] + groups = groups[:N_NO_MCP] + # plt.text( + # 1.35, + # -0.35, + # "a)", + # fontdict={"fontsize": 14.0, "weight": "bold"}, + # ) + else: + self_play = self_play[N_NO_MCP:] + cross_play = cross_play[N_NO_MCP:] + same_pref_perf = same_pref_perf[N_NO_MCP:] + diff_pref_perf = diff_pref_perf[N_NO_MCP:] + self_play_err = self_play_err[N_NO_MCP:] + cross_play_err = 
cross_play_err[N_NO_MCP:] + same_pref_perf_err = same_pref_perf_err[N_NO_MCP:] + diff_pref_perf_err = diff_pref_perf_err[N_NO_MCP:] + groups = groups[N_NO_MCP:] + # plt.text( + # 1.85, + # -0.35, + # "b)", + # fontdict={"fontsize": 14.0, "weight": "bold"}, + # ) + + x = np.arange(len(self_play)) + if welfare_split: + # replace/remove None + same_pref_perf_x = [] + diff_pref_perf_x = [] + new_diff_pref_perf = [] + new_diff_pref_perf_err = [] + new_same_pref_perf = [] + new_same_pref_perf_err = [] + rm_x = [] + rm_cross = [] + rm_cross_err = [] + for (i,( + x_i, + cross, + cross_err, + same_pref, + same_pref_err, + diff_pref, + diff_pref_err, + )) in enumerate(zip( + x, + cross_play, + cross_play_err, + same_pref_perf, + same_pref_perf_err, + diff_pref_perf, + diff_pref_perf_err,) + ): + pass_ = False + if GLOBAL_CROSS_PLAY_IN_LOLA: + if mcp: + if i == 1 or i == 3: + pass_ =True + else: + if i == 1 or i ==3: + pass_ = True + if pass_: + rm_x.append(x_i) + rm_cross.append(cross) + rm_cross_err.append(cross_err) + continue + same_pref_perf_x.append(x_i) + if same_pref is None: + assert same_pref_err is None + new_same_pref_perf.append(cross) + new_same_pref_perf_err.append(cross_err) + else: + new_same_pref_perf.append(same_pref) + new_same_pref_perf_err.append(same_pref_err) + if diff_pref is not None: + assert diff_pref_err is not None + diff_pref_perf_x.append(x_i) + new_diff_pref_perf.append(diff_pref) + new_diff_pref_perf_err.append(diff_pref_err) + + # same_pref_perf = [ + # cross if same_pref is None else same_pref + # f + # ] + # same_pref_perf_err = [ + # cross if same_pref is None else same_pref + # for cross, same_pref in zip(cross_play_err, same_pref_perf_err) + # ] + + same_pref_perf = new_same_pref_perf + same_pref_perf_err = new_same_pref_perf_err + diff_pref_perf = new_diff_pref_perf + diff_pref_perf_err = new_diff_pref_perf_err + + plt.bar( + x - width * 1.0 - 0.02, + self_play, + width, + yerr=self_play_err, + color=COLORS[0], + ecolor="black", + capsize=3, + ) + if welfare_split: + if GLOBAL_CROSS_PLAY_IN_LOLA: + plt.bar( + np.array(rm_x) + width * 0.0, + rm_cross, + width, + yerr=rm_cross_err, + color=COLORS[1], + ecolor="black", + capsize=3, + ) + + if welfare_split: + if welfare_split and GLOBAL_CROSS_PLAY_IN_LOLA: + use_x = np.array(same_pref_perf_x) + else: + use_x = x + plt.bar( + use_x + width * 0.0, + same_pref_perf, + width, + yerr=same_pref_perf_err, + color=COLORS[2] if GLOBAL_CROSS_PLAY_IN_LOLA else COLORS[1], + ecolor="black", + capsize=3, + ) + # if mcp and SPLIT_NEGO: + # # x = x[:-1] + # # diff_pref_perf = diff_pref_perf[:-1] + # # diff_pref_perf_err = diff_pref_perf_err[:-1] + # diff_pref_perf[-1] = 0 + # diff_pref_perf_err[-1] = 0 + if welfare_split: + use_x = np.array(diff_pref_perf_x) + else: + use_x = x + plt.bar( + use_x + width * 1.0 + 0.02, + diff_pref_perf, + width, + yerr=diff_pref_perf_err, + color=COLORS[3] if GLOBAL_CROSS_PLAY_IN_LOLA else COLORS[2], + ecolor="black", + capsize=3, + ) + legend = LEGEND + else: + plt.bar( + x + width * 0.0, + cross_play, + width, + yerr=cross_play_err, + color=COLORS[1], + ecolor="black", + capsize=3, + ) + legend = LEGEND_NO_SPLIT + + + if mcp and SPLIT_NEGO: + # extra for negotiation env + n_values = 6 + x_nego_shit = 4 + with_nego = width / 2 # - 0.001 + space_between_lines = 0.02 + negotiation_x = np.array( + [i * with_nego + i * space_between_lines for i in range(n_values)] + ) + diff_pref_perf_err = [diff_pref_perf_err[-1] for i in range(n_values)] + + # [0.50369, 0.49126, 0.46351, 0.44590, 0.42530, 0.41031] 
+ # [0.50626, 0.47991, 0.46098, 0.43771, 0.41802, 0.41276] + + negotiation_y = [0.47328, 0.44078, 0.41710, 0.39257, 0.37393, 0.31258] + negotiation_y = [el / NEGOTIATION_RATIO for el in negotiation_y] + diff_pref_perf_err = [ + 0.00089, + 0.00088, + 0.00087, + 0.00088, + 0.00087, + 0.00087, + ] + negotiation_y = [ + el / NEGOTIATION_RATIO / np.sqrt(2) for el in negotiation_y + ] + + plt.bar( + x_nego_shit + negotiation_x + width * 1.0, + negotiation_y, + with_nego, + yerr=diff_pref_perf_err, + color=COLORS[2], + ecolor="black", + capsize=3, + ) + + if not welfare_split: + # groups = [el.strip("*") if el.endswith(" *") else el for el in groups] + groups = [ + el.strip("*") + if el.endswith(" *") + else el.replace("**", "*").replace("***", "**") + for el in groups + ] + + if REMOVE_STARS: + groups = [el.replace("*", "").strip() for el in groups] + + return legend, x, groups + + +def _preproces_values(all_perf): + self_play_p0 = _extract_value(all_perf, 0, PLAYER_0, "raw") + self_play_p1 = _extract_value(all_perf, 0, PLAYER_1, "raw") + cross_play_p0 = _extract_value(all_perf, 1, PLAYER_0, "raw") + cross_play_p1 = _extract_value(all_perf, 1, PLAYER_1, "raw") + same_pref_p0 = _extract_value(all_perf, 2, PLAYER_0, "raw") + same_pref_p1 = _extract_value(all_perf, 2, PLAYER_1, "raw") + diff_pref_p0 = _extract_value(all_perf, 3, PLAYER_0, "raw") + diff_pref_p1 = _extract_value(all_perf, 3, PLAYER_1, "raw") + + if USE_DISTANCE_TO_WELFARE_OPTIM: + self_play = self_play_p0 + cross_play = cross_play_p0 + same_pref_perf = same_pref_p0 + diff_pref_perf = diff_pref_p0 + else: + self_play = _avg_over_players(self_play_p0, self_play_p1) + cross_play = _avg_over_players(cross_play_p0, cross_play_p1) + same_pref_perf = _avg_over_players(same_pref_p0, same_pref_p1) + diff_pref_perf = _avg_over_players(diff_pref_p0, diff_pref_p1) + + _log_n_replicates( + self_play, + cross_play, + same_pref_perf, + diff_pref_perf, + ) + + self_play_err = _get_std_err(self_play) + cross_play_err = _get_std_err(cross_play) + same_pref_perf_err = _get_std_err(same_pref_perf) + diff_pref_perf_err = _get_std_err(diff_pref_perf) + + self_play = _get_mean(self_play) + cross_play = _get_mean(cross_play) + same_pref_perf = _get_mean(same_pref_perf) + diff_pref_perf = _get_mean(diff_pref_perf) + + self_play = _replace_final_values(all_perf, self_play, "self-play", "mean") + cross_play = _replace_final_values( + all_perf, cross_play, "cross-play", "mean" + ) + same_pref_perf = _replace_final_values( + all_perf, same_pref_perf, "cross-play same", "mean" + ) + diff_pref_perf = _replace_final_values( + all_perf, diff_pref_perf, "cross-play diff", "mean" + ) + + self_play_err = _replace_final_values( + all_perf, self_play_err, "self-play", "std err" + ) + cross_play_err = _replace_final_values( + all_perf, cross_play_err, "cross-play", "std err" + ) + same_pref_perf_err = _replace_final_values( + all_perf, same_pref_perf_err, "cross-play same", "std err" + ) + diff_pref_perf_err = _replace_final_values( + all_perf, diff_pref_perf_err, "cross-play diff", "std err" + ) + + return ( + self_play, + cross_play, + same_pref_perf, + diff_pref_perf, + self_play_err, + cross_play_err, + same_pref_perf_err, + diff_pref_perf_err, + ) + + +def _log_n_replicates( + self_play, + cross_play, + same_pref_perf, + diff_pref_perf, +): + print("\n_log_n_replicates") + print("self_play", [None if el is None else el.shape for el in self_play]) + print( + "cross_play", [None if el is None else el.shape for el in cross_play] + ) + print( + "same_pref_perf", + [None 
if el is None else el.shape for el in same_pref_perf], + ) + print( + "diff_pref_perf", + [None if el is None else el.shape for el in diff_pref_perf], + ) + + ratio = [] + for cross, cross_same, cross_diff in zip( + cross_play, same_pref_perf, diff_pref_perf + ): + # if len(cross_same.shape) > 0: + if cross_same is not None: + assert cross.shape[0] == ( + cross_same.shape[0] + cross_diff.shape[0] + ), f"{cross.shape[0]} == {cross_same.shape[0]} + {cross_diff.shape[0]}" + ratio.append(cross_same.shape[0] / cross_diff.shape[0]) + else: + ratio.append(None) + print("cross_same / cross_diff", ratio) + + +def _extract_value(all_perf, idx, player_idx, attrib): + + values = [] + for el in all_perf: + if isinstance(el, Final_values): + values.append(None) + else: + if hasattr(el[idx][player_idx], attrib): + values.append(getattr(el[idx][player_idx], attrib)) + else: + values.append(None) + return values + + +def _avg_over_players(values_player0, values_player1): + return [ + (np.array(v_p0) + np.array(v_p1)) / 2 if v_p0 is not None else None + for v_p0, v_p1 in zip(values_player0, values_player1) + ] + + +def _get_std_err(values): + return [ + v.std() / np.sqrt(v.shape[0]) if v is not None else None + for v in values + ] + + +def _get_mean(values): + return [v.mean() if v is not None else None for v in values] + + +def _replace_final_values(all_perf, values, cat, mode): + new_values = [] + for el, value in zip(all_perf, values): + if isinstance(el, Final_values): + if mode == "mean": + idx = 0 + elif mode == "std err": + idx = 1 + else: + raise ValueError() + + if cat == "self-play": + v = el.self_play[idx] + elif cat == "cross-play": + v = el.cross_play[idx] + elif cat == "cross-play same": + v = el.cross_play_same[idx] + elif cat == "cross-play diff": + v = el.cross_play_diff[idx] + else: + raise ValueError() + else: + v = value + new_values.append(v) + return new_values + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/scripts/plot_meta_policies.py b/marltoolbox/scripts/plot_meta_policies.py new file mode 100644 index 0000000..8686177 --- /dev/null +++ b/marltoolbox/scripts/plot_meta_policies.py @@ -0,0 +1,433 @@ +import json +import os + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +EPSILON = 1e-6 + + +def main(debug): + prefix, files_data, n_players = _get_inputs() + files_to_process = _preprocess_inputs(prefix, files_data) + + for i, (file_path, file_data) in enumerate( + zip(files_to_process, files_data) + ): + meta_policies, welfares = _get_policies(file_path) + actions_names = _get_actions_names(welfares) + plot_policies( + meta_policies, + actions_names, + title=file_data[0], + path_suffix=f"_{i}", + announcement_protocol=True, + ) + + +def _get_inputs(): + prefix = "~/dev-maxime/CLR/vm-data/" + files_data = ( + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank mixed on welfare sets) & BASE(announcement + " + "LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14" + "/10_37_24/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(alpha-rank pure on welfare sets) & BASE(announcement + " + "LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14/10_39_47/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(replicator dynamic random init on welfare sets) & BASE(" + "announcement + LOLA-Exact)", 
+ "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14/10_42_10/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(replicator dynamic default init on welfare sets) & BASE(" + "announcement + LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14/10_46_23/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(baseline random) & BASE(" "announcement + LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14" + "/10_50_36/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(PG) & BASE(" "announcement + LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14/10_52_43/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(LOLA-Exact) & BASE(" "announcement + LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14" + "/11_00_02/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ( # 20 x 10 replicates in meta wt 30 x 5 replicates in base + "META(SOS-Exact) & BASE(" "announcement + LOLA-Exact)", + "instance-60-cpu-4-preemtible/meta_game_compare/2021_05_14" + "/12_38_59/meta_game/final_base_game/final_eval_in_base_game.json", + ), + ) + n_players = 2 + return prefix, files_data, n_players + + +def _preprocess_inputs(prefix, files_data): + files_to_process = [ + os.path.expanduser(os.path.join(prefix, file_data[1])) + for file_data in files_data + ] + return files_to_process + + +def _get_policies(file_path): + parent_dir, _ = os.path.split(file_path) + parent_parent_dir, _ = os.path.split(parent_dir) + meta_policies_file = os.path.join(parent_parent_dir, "meta_policies.json") + with (open(meta_policies_file, "rb")) as f: + meta_policies = json.load(f) + meta_policies = meta_policies["meta_policies"] + print("meta_policies", type(meta_policies), meta_policies) + + parent_parent_parent_dir, _ = os.path.split(parent_parent_dir) + welfares_file = os.path.join( + parent_parent_parent_dir, "payoffs_matrices_0.json" + ) + with (open(welfares_file, "rb")) as f: + welfares = json.load(f) + welfares = welfares["welfare_fn_sets"] + print("welfares", type(welfares), welfares) + + return meta_policies, welfares + + +def _get_actions_names(welfares): + actions_names = ( + welfares.replace("OrderedSet", "") + .replace("(", "") + .replace(")", "") + .replace("[", "") + .lstrip() + ) + actions_names = actions_names.split("],") + actions_names = [f"({el})" for el in actions_names] + actions_names = [ + el.replace("]", "").replace(" ", "") for el in actions_names + ] + return actions_names + + +def plot_policies( + meta_policies, + actions_names, + title=None, + path_prefix="", + path_suffix="", + announcement_protocol=False, +): + """ + Given the policies of two players, create plots to vizualize them. 
+ + :param meta_policies: + :param actions_names: names of actions of each players + :param title: title of plot + :param path_prefix: prefix to plot save path + :param path_suffix: suffix to plot save path + :return: + """ + plt.style.use("default") + + if len(actions_names) != 2: + actions_names = (actions_names, actions_names) + print("actions_names", actions_names, "len", len(actions_names[0])) + policies_p0 = [] + policies_p1 = [] + for meta_policy in meta_policies: + policies_p0.append(meta_policy["player_row"]) + policies_p1.append(meta_policy["player_col"]) + print("policies_p0", len(policies_p0), "policies_p1", len(policies_p1)) + policies_p0 = np.array(policies_p0) + policies_p1 = np.array(policies_p1) + + if announcement_protocol: + fig, (ax, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(10, 20)) + else: + fig, (ax, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 16)) + + if title is not None: + fig.suptitle( + title, + fontweight="bold", + ) + _plot_means(policies_p0, policies_p1, actions_names, ax) + _plot_std_dev(policies_p0, policies_p1, actions_names, ax2) + _plot_joint_policies_vanilla( + policies_p0, + policies_p1, + actions_names, + ax3, + ) + if announcement_protocol: + _plot_joint_policies_wt_announcement_protocol( + policies_p0, policies_p1, actions_names, ax4 + ) + + plt.tight_layout() + path_prefix = os.path.expanduser(path_prefix) + plt.savefig(f"{path_prefix}meta_policies{path_suffix}.png") + + plt.style.use("seaborn-whitegrid") + + +def _plot_means(policies_p0, policies_p1, actions_names, ax): + policies_p0_mean = policies_p0.mean(axis=0) + policies_p1_mean = policies_p1.mean(axis=0) + policies_mean = np.stack([policies_p0_mean, policies_p1_mean], axis=0) + im, cbar = heatmap( + policies_mean, + ["player_row", "player_col"], + actions_names[0], + ax=ax, + cmap="YlGn", + cbarlabel="MEAN proba", + ) + texts = annotate_heatmap(im, valfmt="{x:.3f}") + + +def _plot_std_dev(policies_p0, policies_p1, actions_names, ax): + policies_p0_std = policies_p0.std(axis=0) + policies_p1_std = policies_p1.std(axis=0) + policies_std = np.stack([policies_p0_std, policies_p1_std], axis=0) + im, cbar = heatmap( + policies_std, + ["player_row", "player_col"], + actions_names[1], + ax=ax, + cmap="YlGn", + cbarlabel="STD", + ) + texts = annotate_heatmap(im, valfmt="{x:.3f}") + + +def _plot_joint_policies_vanilla(policies_p0, policies_p1, actions_names, ax): + policies_p0 = np.expand_dims(policies_p0, axis=-1) + policies_p1 = np.expand_dims(policies_p1, axis=-1) + policies_p1 = np.transpose(policies_p1, (0, 2, 1)) + joint_policies = np.matmul(policies_p0, policies_p1) + # print("joint_policies", joint_policies[0]) + _plot_joint_policies(joint_policies, actions_names, ax) + + +def _plot_joint_policies_wt_announcement_protocol( + policies_p0, policies_p1, actions_names, ax +): + joint_policies, actions_names = _preprocess_announcement_protocol( + policies_p0, policies_p1, actions_names + ) + _plot_joint_policies(joint_policies, actions_names, ax) + + +def _plot_joint_policies(joint_policies, actions_names, ax): + assert np.all( + np.abs(joint_policies.sum(axis=2).sum(axis=1) - 1.0) < EPSILON + ), f"{np.abs(joint_policies.sum(axis=2).sum(axis=1) - 1.0)}" + joint_policies = joint_policies.mean(axis=0) + + im, cbar = heatmap( + joint_policies, + actions_names[0], + actions_names[1], + ax=ax, + cmap="YlGn", + cbarlabel="Joint policy", + ) + texts = annotate_heatmap(im, valfmt="{x:.3f}") + + +def _preprocess_announcement_protocol(policies_p0, policies_p1, actions_names): + welfare_sets_announced, 
welfares = _convert_to_welfare_sets(actions_names) + n_welfare_fn = len(welfares) + n_replicates = len(policies_p0) + joint_policies = np.zeros(shape=(n_replicates, n_welfare_fn, n_welfare_fn)) + for relpi_i, (pi_pl0_repl, pi_pl1_repl) in enumerate( + zip(policies_p0, policies_p1) + ): + for w_set_pl0, pi_0 in zip(welfare_sets_announced[0], pi_pl0_repl): + for w_set_pl1, pi_1 in zip(welfare_sets_announced[1], pi_pl1_repl): + intersection = w_set_pl0 & w_set_pl1 + n_welfare_fn_in_intersec = len(intersection) + if n_welfare_fn_in_intersec > 0: + for welfare in intersection: + welfare_idx = welfares.index(welfare) + joint_policies[relpi_i, welfare_idx, welfare_idx] += ( + pi_0 * pi_1 + ) / n_welfare_fn_in_intersec + else: + welfare_idx_pl0 = welfares.index("utilitarian") + welfare_idx_pl1 = welfares.index("egalitarian") + joint_policies[ + relpi_i, welfare_idx_pl0, welfare_idx_pl1 + ] += (pi_0 * pi_1) + + return joint_policies, [welfares, welfares] + + +def _convert_to_welfare_sets(actions_names): + welfare_sets_announced = [] + all_wefares = [] + for actions_names_of_player_i in actions_names: + actions_names_of_player_i = [ + str(el) for el in actions_names_of_player_i + ] + sets_player_i = [] + for action_name in actions_names_of_player_i: + welfare_set = ( + action_name.replace("'", "") + .replace("(", "") + .replace(")", "") + .replace("OrderedSet", "") + .replace("[", "") + .replace("]", "") + .replace(" ", "") + .split("," "") + ) + sets_player_i.append(set(welfare_set)) + all_wefares.extend(welfare_set) + welfare_sets_announced.append(sets_player_i) + all_wefares = tuple(sorted(tuple(set(all_wefares)))) + assert len(all_wefares) <= 3, f"{all_wefares}" + return welfare_sets_announced, all_wefares + + +def heatmap( + data, row_labels, col_labels, ax=None, cbar_kw={}, cbarlabel="", **kwargs +): + """ + Create a heatmap from a numpy array and two lists of labels. + + Parameters + ---------- + data + A 2D numpy array of shape (N, M). + row_labels + A list or array of length N with the labels for the rows. + col_labels + A list or array of length M with the labels for the columns. + ax + A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If + not provided, use current axes or create a new one. Optional. + cbar_kw + A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. + cbarlabel + The label for the colorbar. Optional. + **kwargs + All other arguments are forwarded to `imshow`. + """ + + if not ax: + ax = plt.gca() + + # Plot the heatmap + im = ax.imshow(data, **kwargs) + + # Create colorbar + cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) + cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") + + # We want to show all ticks... + ax.set_xticks(np.arange(data.shape[1])) + ax.set_yticks(np.arange(data.shape[0])) + # ... and label them with the respective list entries. + ax.set_xticklabels(col_labels) + ax.set_yticklabels(row_labels) + + # Let the horizontal axes labeling appear on top. + ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) + + # Rotate the tick labels and set their alignment. + plt.setp( + ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor" + ) + + # Turn spines off and create white grid. 
+ # ax.spines[:].set_visible(False) + for k, v in ax.spines.items(): + ax.spines[k].set_visible(False) + + ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) + ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) + ax.grid(which="minor", color="w", linestyle="-", linewidth=3) + ax.tick_params(which="minor", bottom=False, left=False) + return im, cbar + + # return im, None + + +def annotate_heatmap( + im, + data=None, + valfmt="{x:.2f}", + textcolors=("black", "white"), + threshold=None, + **textkw, +): + """ + A function to annotate a heatmap. + + Parameters + ---------- + im + The AxesImage to be labeled. + data + Data used to annotate. If None, the image's data is used. Optional. + valfmt + The format of the annotations inside the heatmap. This should either + use the string format method, e.g. "$ {x:.2f}", or be a + `matplotlib.ticker.Formatter`. Optional. + textcolors + A pair of colors. The first is used for values below a threshold, + the second for those above. Optional. + threshold + Value in data units according to which the colors from textcolors are + applied. If None (the default) uses the middle of the colormap as + separation. Optional. + **kwargs + All other arguments are forwarded to each call to `text` used to create + the text labels. + """ + + if not isinstance(data, (list, np.ndarray)): + data = im.get_array() + + # Normalize the threshold to the images color range. + if threshold is not None: + threshold = im.norm(threshold) + else: + threshold = im.norm(data.max()) / 2.0 + + # Set default alignment to center, but allow it to be + # overwritten by textkw. + kw = dict(horizontalalignment="center", verticalalignment="center") + kw.update(textkw) + + # Get the formatter in case a string is supplied + if isinstance(valfmt, str): + valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) + + # Loop over the data and create a `Text` for each "pixel". + # Change the text's color depending on the data. 
+ texts = [] + for i in range(data.shape[0]): + for j in range(data.shape[1]): + kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) + text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) + texts.append(text) + + return texts + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/scripts/plot_scatter_figs_from_saved_results.py b/marltoolbox/scripts/plot_scatter_figs_from_saved_results.py new file mode 100644 index 0000000..6a84ce9 --- /dev/null +++ b/marltoolbox/scripts/plot_scatter_figs_from_saved_results.py @@ -0,0 +1,564 @@ +import json +import os +from collections import Iterable + +import matplotlib.pyplot as plt +import numpy as np + +from marltoolbox.scripts.plot_bar_chart_from_saved_results import ( + _get_inputs, + Perf, + Exp_data, + Final_values, + NA, + PLAYER_0, + PLAYER_1, + COLORS, + LEGEND_NO_SPLIT, + LEGEND, + REMOVE_STARS, + VALUE_TERM, +) + +plt.switch_backend("agg") +plt.style.use("seaborn-whitegrid") +plt.rcParams.update({"font.size": 12}) + + +def main(debug): + prefix, files_data, n_players = _get_inputs() + files_to_process = _preprocess_inputs(prefix, files_data) + + perf_per_mode_per_files = [] + for file_paths, file_data in zip(files_to_process, files_data): + perf_per_mode = _get_stats(file_paths, n_players, file_data) + perf_per_mode_per_files.append( + Exp_data(file_data.base_algo, file_data.env, perf_per_mode) + ) + + _plot_ipd_iasymbos(perf_per_mode_per_files) + _plot_all(perf_per_mode_per_files) + _plot_ipd_iasymbos(perf_per_mode_per_files, welfare_split=False) + _plot_all(perf_per_mode_per_files, welfare_split=False) + + +def _preprocess_inputs(prefix, files_data): + + # !!!! Remove Negotiation data (not plotted in scatter plot) !!!! + # (don't currently have the raw data to plot it) + files_data = files_data[:-1] + + files_to_process = [] + for file_data in files_data: + if isinstance(file_data, Final_values): + value = file_data + elif file_data.path_to_preferences is not None: + value = ( + os.path.join(prefix, file_data.path_to_self_play), + os.path.join(prefix, file_data.path_to_preferences), + ) + else: + value = ( + os.path.join(prefix, file_data.path_to_self_play), + None, + ) + files_to_process.append(value) + + return files_to_process + + +def _get_stats(file_paths, n_players, file_data): + if isinstance(file_paths, Final_values): + all_perf = file_paths + else: + self_play_path = file_paths[0] + perf_per_mode = _get_stats_for_file( + self_play_path, n_players, file_data + ) + self_play = perf_per_mode["self-play"] + cross_play = perf_per_mode["cross-play"] + + preference_path = file_paths[1] + if preference_path is not None: + perf_per_mode_bis = _get_stats_for_file( + preference_path, n_players, file_data + ) + same_preferences_cross_play = perf_per_mode_bis[ + "cross-play: same pref vs same pref" + ] + if ( + "cross-play: diff pref vs diff pref" + in perf_per_mode_bis.keys() + ): + diff_preferences_cross_play = perf_per_mode_bis[ + "cross-play: diff pref vs diff pref" + ] + else: + diff_preferences_cross_play = NA + else: + same_preferences_cross_play = NA + diff_preferences_cross_play = NA + + all_perf = [ + self_play, + cross_play, + same_preferences_cross_play, + diff_preferences_cross_play, + ] + return all_perf + + +def _get_stats_for_file(file, n_players, file_data): + perf_per_mode = {} + file_path = os.path.expanduser(file) + with (open(file_path, "rb")) as f: + file_content = json.load(f) + for eval_mode, mode_perf in file_content.items(): + perf = [None] * 2 + 
print("eval_mode", eval_mode) + for metric, metric_perf in mode_perf.items(): + player_idx = _extract_player_idx(metric) + + perf_per_replicat = np.array( + _convert_str_of_list_to_list(metric_perf["raw_data"]) + ) + + n_replicates_in_content = len(perf_per_replicat) + values_per_replicat_per_player = _scale_values( + perf_per_replicat, file_data + ) + + mean_per_player = values_per_replicat_per_player.mean(axis=0) + std_dev_per_player = values_per_replicat_per_player.std(axis=0) + std_err_per_player = std_dev_per_player / np.sqrt( + n_replicates_in_content + ) + perf[player_idx] = Perf( + mean_per_player, + std_dev_per_player, + std_err_per_player, + values_per_replicat_per_player, + ) + perf_per_mode[eval_mode] = perf + + return perf_per_mode + + +def _extract_player_idx(metric): + if "player_row" in metric: + player_idx = PLAYER_0 + elif "player_col" in metric: + player_idx = PLAYER_1 + elif "player_red" in metric: + player_idx = PLAYER_0 + elif "player_blue" in metric: + player_idx = PLAYER_1 + else: + raise ValueError() + return player_idx + + +def _scale_values(values_per_replicat_per_player, file_data): + scaled_values = ( + values_per_replicat_per_player / file_data.reward_adaptation_divider + ) + return scaled_values + + +def _convert_str_of_list_to_list(str_of_list): + return [ + float(v) + for v in str_of_list.replace("[", "") + .replace("]", "") + .replace(" ", "") + .split(",") + ] + + +def _plot_all(perf_per_mode_per_files, welfare_split=True): + n_figures = len(perf_per_mode_per_files) + n_row = int(np.sqrt(n_figures) + 0.99) + + n = 100 * n_row + 10 * n_row + 1 + plt.figure(figsize=(10, 10)) + + for i in range(n_figures): + plt.subplot(n + i) + data_idx = i + if perf_per_mode_per_files[i].env == "IPD": + xlim = (-3.5, 0.5) + ylim = (-3.5, 0.5) + jitter = 0.05 + background_area_coord = np.array( + [[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]] + ) + _add_background_area(background_area_coord) + elif perf_per_mode_per_files[i].env == "CG": + xlim = (-0.1, 0.6) + ylim = (-0.1, 0.6) + jitter = 0.00 + elif perf_per_mode_per_files[i].env == "IAsymBoS": + xlim = (-0.5, 4.5) + ylim = (-0.5, 4.5) + jitter = 0.05 + background_area_coord = np.array( + [[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]] + ) + _add_background_area(background_area_coord) + elif perf_per_mode_per_files[i].env == "ABCG": + xlim = (-0.1, 0.8) + ylim = (-0.1, 1.6) + jitter = 0.00 + _plot_one_scatter( + perf_per_mode_per_files, + data_idx, + xlim, + ylim, + jitter, + welfare_split=welfare_split, + plot_x_label=i // n_row == n_row - 1 or i == 5, + plot_y_label=i % n_row == 0, + ) + + if welfare_split: + plt.tight_layout(rect=[0, 0.10, 1.0, 1.0]) + # Save the figure and show + plt.legend( + LEGEND if welfare_split else LEGEND_NO_SPLIT, + frameon=True, + bbox_to_anchor=(1.15, -0.25), + ) + + else: + plt.tight_layout() + # Save the figure and show + plt.legend( + LEGEND if welfare_split else LEGEND_NO_SPLIT, + frameon=True, + bbox_to_anchor=(2.0, 0.55), + ) + plt.savefig(f"scatter_plots_all_split_{welfare_split}.png") + + +def _plot_ipd_iasymbos(perf_per_mode_per_files, welfare_split=True): + if welfare_split: + plt.figure(figsize=(5.8, 5 * 2 / 3)) + else: + plt.figure(figsize=(5, 2.5)) + + margin = 0.1 + + from matplotlib import gridspec + + if welfare_split: + gs = gridspec.GridSpec(1, 2, width_ratios=[1, 3]) + else: + gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3]) + # ax0 = plt.subplot(gs[0]) + + # plt.subplot(121) + plt.subplot(gs[0]) + # , gridspec_kw = {"width_ratios": [1, 2]} + data_idx = 1 + 
xlim = (-3.0 - margin, 0.0 + margin) + if welfare_split: + ylim = (-3 - 3.0 - margin, +3 + 0.0 + margin) + else: + ylim = (-3.0 - margin, 0.0 + margin) + jitter = 0.05 + _plot_one_scatter( + perf_per_mode_per_files, + data_idx, + xlim, + ylim, + jitter, + welfare_split=welfare_split, + ) + background_area_coord = np.array( + [[[-1, -1], [-3, +0]], [[+0, -3], [-2, -2]]] + ) + _add_background_area(background_area_coord) + + # plt.subplot(122) + plt.subplot(gs[1]) + data_idx = 5 + if welfare_split: + xlim = (-0.0 - margin * 2, 4.0 + margin * 2) + ylim = (-0.0 - margin, 3.5 + margin) + else: + xlim = (-0.0 - margin * 2, 4.0 + margin * 2) + ylim = (-0.0 - margin, 2.0 + margin) + jitter = 0.05 + _plot_one_scatter( + perf_per_mode_per_files, + data_idx, + xlim, + ylim, + jitter, + welfare_split=welfare_split, + plot_y_label=False, + ) + background_area_coord = np.array( + [[[+4.0, +1.0], [+0.0, +0.0]], [[+0.0, +0.0], [+2.0, +2.0]]] + ) + _add_background_area(background_area_coord) + + # plt.text( + # -1.0, + # -1.50, + # "c)", + # fontdict={"fontsize": 14.0, "weight": "bold"}, + # ) + + plt.tight_layout(rect=[0, 0.0, 1.0, 1.0]) + if welfare_split: + plt.legend( + LEGEND if welfare_split else LEGEND_NO_SPLIT, + frameon=True, + bbox_to_anchor=(0.01, 0.0, 1.0, 1.0), + ) + else: + plt.legend( + LEGEND if welfare_split else LEGEND_NO_SPLIT, + frameon=True, + bbox_to_anchor=(-0.185, 0.46), + ) + # plt.tight_layout(rect=[0, -0.07, 1.0, 1.0]) + + # Save the figure and show + plt.savefig(f"scatter_plot_ipd_iasymbos_split_{welfare_split}.png") + + +def _plot_one_scatter( + perf_per_mode_per_files, + data_idx, + xlim, + ylim, + jitter, + plot_x_label=True, + plot_y_label=True, + welfare_split=True, +): + _plot(perf_per_mode_per_files, data_idx, jitter, welfare_split) + env = perf_per_mode_per_files[data_idx].env + base_algo = perf_per_mode_per_files[data_idx].base_algo + if REMOVE_STARS: + env = env.replace("*", "").strip() + base_algo = base_algo.replace("*", "").strip() + plt.title(f"{env} + " f"{base_algo}") + if plot_x_label: + plt.xlabel(f"Player 1 {VALUE_TERM}") + if plot_y_label: + plt.ylabel(f"Player 2 {VALUE_TERM}") + if xlim is not None: + plt.xlim(xlim) + if ylim is not None: + plt.ylim(ylim) + + +def _plot(perf_per_mode_per_files, data_idx, jitter, welfare_split): + all_perf = [el.perf for el in perf_per_mode_per_files] + + ( + self_play_p0, + self_play_p1, + cross_play_p0, + cross_play_p1, + same_pref_p0, + same_pref_p1, + diff_pref_p0, + diff_pref_p1, + ) = _preproces_values(all_perf, jitter) + + self_play_p0 = self_play_p0[data_idx] + self_play_p1 = self_play_p1[data_idx] + cross_play_p0 = cross_play_p0[data_idx] + cross_play_p1 = cross_play_p1[data_idx] + same_pref_p0 = same_pref_p0[data_idx] + same_pref_p1 = same_pref_p1[data_idx] + diff_pref_p0 = diff_pref_p0[data_idx] + diff_pref_p1 = diff_pref_p1[data_idx] + + print("cross_play_p0", len(cross_play_p0)) + print("np.array(same_pref_p0).shape", np.array(same_pref_p0).shape) + plot_diff_pref = True + if np.array(same_pref_p0).shape == (): + same_pref_p0 = cross_play_p0 + same_pref_p1 = cross_play_p1 + plot_diff_pref = False + + ( + self_play_p0, + self_play_p1, + cross_play_p0, + cross_play_p1, + same_pref_p0, + same_pref_p1, + diff_pref_p0, + diff_pref_p1, + ) = _keep_same_number_of_points( + self_play_p0, + self_play_p1, + cross_play_p0, + cross_play_p1, + same_pref_p0, + same_pref_p1, + diff_pref_p0, + diff_pref_p1, + ) + + plt.plot( + self_play_p0, + self_play_p1, + markerfacecolor="none", + markeredgecolor=COLORS[0], + linestyle="None", 
+ marker="o", + color=COLORS[0], + # markersize=MARKERSIZE, + ) + if welfare_split: + plt.plot( + same_pref_p0, + same_pref_p1, + markerfacecolor="none", + markeredgecolor=COLORS[1], + linestyle="None", + marker="s", + color=COLORS[1], + # markersize=self.plot_cfg.markersize, + ) + if plot_diff_pref: + plt.plot( + diff_pref_p0, + diff_pref_p1, + markerfacecolor="none", + markeredgecolor=COLORS[2], + linestyle="None", + marker="v", + color=COLORS[2], + # markersize=self.plot_cfg.markersize, + ) + else: + plt.plot( + cross_play_p0, + cross_play_p1, + markerfacecolor="none", + markeredgecolor=COLORS[1], + linestyle="None", + marker="s", + color=COLORS[1], + # markersize=self.plot_cfg.markersize, + ) + + +def _keep_same_number_of_points(*args): + lengths = [len(list_) for list_ in args if isinstance(list_, Iterable)] + min_length = min(lengths) + new_lists = [] + for list_ in args: + if isinstance(list_, Iterable): + list_ = list_[-min_length:] + new_lists.append(list_) + + return new_lists + + +def _add_jitter(values, jitter): + values_wt_jitter = [] + for sub_list in values: + sub_list = np.array(sub_list) + shift = np.random.normal(0.0, jitter, sub_list.shape) + sub_list += shift + values_wt_jitter.append(sub_list.tolist()) + return values_wt_jitter + + +def _preproces_values(all_perf, jitter): + self_play_p0 = _extract_value(all_perf, 0, PLAYER_0, "raw", jitter) + self_play_p1 = _extract_value(all_perf, 0, PLAYER_1, "raw", jitter) + cross_play_p0 = _extract_value(all_perf, 1, PLAYER_0, "raw", jitter) + cross_play_p1 = _extract_value(all_perf, 1, PLAYER_1, "raw", jitter) + same_pref_p0 = _extract_value(all_perf, 2, PLAYER_0, "raw", jitter) + same_pref_p1 = _extract_value(all_perf, 2, PLAYER_1, "raw", jitter) + diff_pref_p0 = _extract_value(all_perf, 3, PLAYER_0, "raw", jitter) + diff_pref_p1 = _extract_value(all_perf, 3, PLAYER_1, "raw", jitter) + + return ( + self_play_p0, + self_play_p1, + cross_play_p0, + cross_play_p1, + same_pref_p0, + same_pref_p1, + diff_pref_p0, + diff_pref_p1, + ) + + +def _log_n_replicates( + self_play, + cross_play, + same_pref_perf, + diff_pref_perf, +): + print("\n_log_n_replicates") + print("self_play", [el.shape for el in self_play]) + print("cross_play", [el.shape for el in cross_play]) + print("same_pref_perf", [el.shape for el in same_pref_perf]) + print("diff_pref_perf", [el.shape for el in diff_pref_perf]) + + ratio = [] + for cross, cross_same, cross_diff in zip( + cross_play, same_pref_perf, diff_pref_perf + ): + if len(cross_same.shape) > 0: + assert cross.shape[0] == ( + cross_same.shape[0] + cross_diff.shape[0] + ) + ratio.append(cross_same.shape[0] / cross_diff.shape[0]) + else: + ratio.append(None) + print("cross_same / cross_diff", ratio) + + +def _extract_value(all_perf, idx, player_idx, attrib, jitter): + values = [] + for el in all_perf: + if isinstance(el, Final_values): + values.append(0.0) + else: + if hasattr(el[idx][player_idx], attrib): + values.append(getattr(el[idx][player_idx], attrib)) + else: + values.append(0.0) + + values = _add_jitter(values, jitter) + + return values + + +def _add_background_area(background_area_coord): + from scipy.spatial import ConvexHull + + assert background_area_coord.ndim == 3 + points_defining_area = background_area_coord.flatten().reshape(-1, 2) + area_hull = ConvexHull(points_defining_area) + plt.fill( + points_defining_area[area_hull.vertices, 0], + points_defining_area[area_hull.vertices, 1], + facecolor="none", + edgecolor="purple", + linewidth=1, + ) + plt.fill( + 
points_defining_area[area_hull.vertices, 0], + points_defining_area[area_hull.vertices, 1], + "purple", + alpha=0.05, + ) + + +if __name__ == "__main__": + debug_mode = False + main(debug_mode) diff --git a/marltoolbox/utils/callbacks.py b/marltoolbox/utils/callbacks.py index 1b5ad23..c66e84f 100644 --- a/marltoolbox/utils/callbacks.py +++ b/marltoolbox/utils/callbacks.py @@ -6,6 +6,7 @@ from ray.rllib.utils.typing import AgentID, PolicyID from marltoolbox.utils.miscellaneous import logger +from marltoolbox.utils import restore if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker diff --git a/marltoolbox/utils/config_helper.py b/marltoolbox/utils/config_helper.py new file mode 100644 index 0000000..39fd01f --- /dev/null +++ b/marltoolbox/utils/config_helper.py @@ -0,0 +1,68 @@ +from collections.abc import Callable + +from ray import tune +from ray.rllib.utils import PiecewiseSchedule + +from marltoolbox.utils.miscellaneous import move_to_key + + +def get_temp_scheduler() -> Callable: + """ + Use the hyperparameter 'temperature_steps_config' stored inside the + env_config dict since there are no control done over the keys of this + dictionary. + This hyperparameter is a list of tuples List[Tuple]. Each tuple define + one step in the scheduler. + + :return: an easily customizable temperature scheduler + """ + return configurable_linear_scheduler("temperature_steps_config") + + +def get_lr_scheduler() -> Callable: + """ + Use the hyperparameter 'lr_steps_config' stored inside the + env_config dict since there are no control done over the keys of this + dictionary. + This hyperparameter is a list of tuples List[Tuple]. Each tuple define + one step in the scheduler. + + :return: an easily customizable temperature scheduler + """ + return configurable_linear_scheduler( + "lr_steps_config", second_term_key="lr" + ) + + +def configurable_linear_scheduler(config_key, second_term_key: str = None): + """Returns a configurable linear scheduler which use the hyperparameters + stop.episodes_total and config.env_config.max_steps fro the RLLib + config.""" + + return tune.sample_from( + lambda spec: PiecewiseSchedule( + endpoints=[ + ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * step_config[0] + ), + step_config[1], + ) + if second_term_key is None + else ( + int( + spec.config["env_config"]["max_steps"] + * spec.stop["episodes_total"] + * step_config[0] + ), + step_config[1] + * move_to_key(spec.config, second_term_key)[2], + ) + for step_config in spec.config["env_config"][config_key] + ], + outside_value=spec.config["env_config"][config_key][-1][1], + framework="torch", + ) + ) diff --git a/marltoolbox/utils/cross_play/__init__.py b/marltoolbox/utils/cross_play/__init__.py new file mode 100644 index 0000000..0db6d6f --- /dev/null +++ b/marltoolbox/utils/cross_play/__init__.py @@ -0,0 +1,2 @@ +import marltoolbox.utils.cross_play.utils +import marltoolbox.utils.cross_play.evaluator diff --git a/marltoolbox/utils/self_and_cross_perf.py b/marltoolbox/utils/cross_play/evaluator.py similarity index 69% rename from marltoolbox/utils/self_and_cross_perf.py rename to marltoolbox/utils/cross_play/evaluator.py index b710a70..e2e931e 100644 --- a/marltoolbox/utils/self_and_cross_perf.py +++ b/marltoolbox/utils/cross_play/evaluator.py @@ -4,24 +4,19 @@ import os import pickle import random -import warnings from typing import Dict -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -plt.style.use("seaborn-whitegrid") - import ray from ray 
import tune from ray.rllib.agents.pg import PGTrainer from ray.tune.analysis import ExperimentAnalysis from ray.tune.integration.wandb import WandbLogger from ray.tune.logger import DEFAULT_LOGGERS +from ray.tune.logger import SafeFallbackEncoder +from marltoolbox import utils from marltoolbox.utils import restore, log, miscellaneous -from marltoolbox.utils.plot import PlotHelper, PlotConfig +from marltoolbox.utils.cross_play import ploter logger = logging.getLogger(__name__) @@ -83,9 +78,9 @@ def perform_evaluation_or_load_data( evaluation_config, stop_config, policies_to_load_from_checkpoint, - tune_analysis_per_exp: list, - TrainerClass=PGTrainer, - TuneTrainerClass=None, + experiment_analysis_per_welfare: list, + rllib_trainer_class=PGTrainer, + tune_trainer_class=None, n_cross_play_per_checkpoint: int = 1, n_self_play_per_checkpoint: int = 1, to_load_path: str = None, @@ -98,12 +93,12 @@ def perform_evaluation_or_load_data( the checkpoints you are going to provide. :param stop_config: Normal stop_config argument provided to tune.run(). :param policies_to_load_from_checkpoint: - :param tune_analysis_per_exp: List of the tune_analysis you want to + :param experiment_analysis_per_welfare: List of the tune_analysis you want to extract the groups of checkpoints from. All the checkpoints in these tune_analysis will be extracted. - :param TrainerClass: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to + :param rllib_trainer_class: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to tune.run(). You should use the one which provides the data flow you need. (Probably a simple PGTrainer will do). - :param TuneTrainerClass: Will only be needed when you are going to evaluate policies created from a Tune + :param tune_trainer_class: Will only be needed when you are going to evaluate policies created from a Tune trainer. You need to provide the class of this trainer. :param n_cross_play_per_checkpoint: (int) How many cross-play experiment per checkpoint you want to run. They are run randomly against the other checkpoints. @@ -114,14 +109,14 @@ def perform_evaluation_or_load_data( """ if to_load_path is None: self.define_the_experiment_to_run( - TrainerClass=TrainerClass, - TuneTrainerClass=TuneTrainerClass, + TrainerClass=rllib_trainer_class, + TuneTrainerClass=tune_trainer_class, evaluation_config=evaluation_config, stop_config=stop_config, policies_to_load_from_checkpoint=policies_to_load_from_checkpoint, ) self.preload_checkpoints_from_tune_results( - tune_results=tune_analysis_per_exp + tune_results=experiment_analysis_per_welfare ) analysis_metrics_per_mode = self.evaluate_performances( n_self_play_per_checkpoint=n_self_play_per_checkpoint, @@ -139,7 +134,7 @@ def define_the_experiment_to_run( stop_config: dict, TuneTrainerClass=None, TrainerClass=PGTrainer, - policies_to_load_from_checkpoint: list = ["All"], + policies_to_load_from_checkpoint: list = ("All"), ): """ :param evaluation_config: Normal config argument provided to tune.run(). @@ -148,8 +143,9 @@ def define_the_experiment_to_run( :param stop_config: Normal stop_config argument provided to tune.run(). :param TuneTrainerClass: Will only be needed when you are going to evaluate policies created from a Tune trainer. You need to provide the class of this trainer. - :param TrainerClass: (default is the PGTrainer class) Normal 1st argument (run_or_experiment) provided to - tune.run(). You should use the one which provides the data flow you need. 
(Probably a simple PGTrainer will do). + :param TrainerClass: (default is the PGTrainer class) The usual 1st + argument provided to tune.run(). You should use the one which + provides the data flow you need. (Probably a simple PGTrainer will do). :param policies_to_load_from_checkpoint: """ @@ -165,14 +161,16 @@ def define_the_experiment_to_run( self.policies_ids_sorted = sorted( list(self.evaluation_config["multiagent"]["policies"].keys()) ) - self.policies_to_load_from_checkpoint = sorted( - [ - policy_id - for policy_id in self.policies_ids_sorted - if self._is_policy_to_load( - policy_id, policies_to_load_from_checkpoint - ) - ] + self.policies_to_load_from_checkpoint = tuple( + sorted( + [ + policy_id + for policy_id in self.policies_ids_sorted + if self._is_policy_to_load( + policy_id, policies_to_load_from_checkpoint + ) + ] + ) ) self.experiment_defined = True @@ -218,8 +216,10 @@ def _extract_groups_of_checkpoints( def _extract_one_group_of_checkpoints( self, one_tune_result: ExperimentAnalysis, group_name ): - checkpoints_in_one_group = miscellaneous.extract_checkpoints( - one_tune_result + checkpoints_in_one_group = ( + utils.restore.extract_checkpoints_from_experiment_analysis( + one_tune_result + ) ) self.checkpoints.extend( [ @@ -370,7 +370,7 @@ def _save_results_as_json(self, available_metrics_list): save_path = self.save_path.split(".")[0:-1] + ["json"] save_path = ".".join(save_path) with open(save_path, "w") as outfile: - json.dump(metrics, outfile) + json.dump(metrics, outfile, cls=SafeFallbackEncoder) def load_results(self, to_load_path): assert to_load_path.endswith(self.results_file_name), ( @@ -443,23 +443,6 @@ def _select_opponent_randomly( ) return opponents - def _split_results_per_mode_and_group_pair_id( - self, all_metadata_wt_results - ): - analysis_per_mode = [] - - metadata_per_modes = self._split_metadata_per_mode( - all_metadata_wt_results - ) - for mode, metadata_for_one_mode in metadata_per_modes.items(): - analysis_per_mode.extend( - self._split_metadata_per_group_pair_id( - metadata_for_one_mode, mode - ) - ) - - return analysis_per_mode - def _split_metadata_per_mode(self, all_results): return { mode: [report for report in all_results if report["mode"] == mode] @@ -469,36 +452,66 @@ def _split_metadata_per_mode(self, all_results): def _split_metadata_per_group_pair_id(self, metadata_for_one_mode, mode): analysis_per_group_pair_id = [] - tune_analysis = [ + experiment_analysis = [ metadata["results"] for metadata in metadata_for_one_mode ] - group_pair_names = [ + pairs_of_group_names = [ self._get_pair_of_group_names(metadata) for metadata in metadata_for_one_mode ] - group_pair_ids = [ - self._get_id_of_pair_of_group_names(one_pair_of_names) - for one_pair_of_names in group_pair_names + ids_of_pairs_of_groups = [ + self._get_id_of_pair_of_group_names(one_pair_of_group_names) + for one_pair_of_group_names in pairs_of_group_names ] - group_pair_ids_in_this_mode = sorted(set(group_pair_ids)) + group_pair_ids_in_this_mode = sorted(set(ids_of_pairs_of_groups)) for group_pair_id in list(group_pair_ids_in_this_mode): ( filtered_analysis_list, - one_pair_of_names, + one_pair_of_group_names, ) = self._find_and_group_results_for_one_group_pair_id( - group_pair_id, tune_analysis, group_pair_ids, group_pair_names + group_pair_id, + experiment_analysis, + ids_of_pairs_of_groups, + pairs_of_group_names, ) analysis_per_group_pair_id.append( ( mode, filtered_analysis_list, group_pair_id, - one_pair_of_names, + one_pair_of_group_names, ) ) return 
analysis_per_group_pair_id + def _get_pair_of_group_names(self, metadata): + # checkpoints_idx_used = [ + # metadata[policy_id]["checkpoint_i"] + # for policy_id in self.policies_to_load_from_checkpoint + # ] + # pair_of_group_names = [ + # self.checkpoints[checkpoint_i]["group_name"] + # for checkpoint_i in checkpoints_idx_used + # ] + checkpoints_idx_used = { + policy_id: metadata[policy_id]["checkpoint_i"] + for policy_id in self.policies_to_load_from_checkpoint + } + pair_of_group_names = { + policy_id: self.checkpoints[checkpoint_i]["group_name"] + for policy_id, checkpoint_i in checkpoints_idx_used.items() + } + return pair_of_group_names + + def _get_id_of_pair_of_group_names(self, pair_of_group_names): + ordered_pair_of_group_names = [ + pair_of_group_names[policy_id] + for policy_id in self.policies_to_load_from_checkpoint + ] + id_of_pair_of_group_names = "".join(ordered_pair_of_group_names) + return id_of_pair_of_group_names + def _find_and_group_results_for_one_group_pair_id( self, group_pair_id, tune_analysis, group_pair_ids, group_pair_names ): @@ -519,20 +532,6 @@ def _find_and_group_results_for_one_group_pair_id( return filtered_tune_analysis, one_pair_of_names - def _extract_all_metrics(self, analysis_per_mode): - analysis_metrics_per_mode = [] - for mode_i, mode_data in enumerate(analysis_per_mode): - mode, analysis_list, group_pair_id, group_pair_name = mode_data - - available_metrics_list = [] - for trial in analysis_list: - available_metrics = trial.metric_analysis - available_metrics_list.append(available_metrics) - analysis_metrics_per_mode.append( - (mode, available_metrics_list, group_pair_id, group_pair_name) - ) - return analysis_metrics_per_mode - def _group_results_and_extract_metrics(self, all_metadata_wt_results): # TODO improve the design to remove these unclear names analysis_per_mode_per_group_pair_id = ( @@ -545,20 +544,36 @@ def _group_results_and_extract_metrics(self, all_metadata_wt_results): ) return analysis_metrics_per_mode_per_group_pair_id - def _get_id_of_pair_of_group_names(self, pair_of_group_names): - id_of_pair_of_group_names = "".join(pair_of_group_names) - return id_of_pair_of_group_names + def _split_results_per_mode_and_group_pair_id( + self, all_metadata_wt_results + ): + analysis_per_mode = [] - def _get_pair_of_group_names(self, metadata): - checkpoints_idx_used = [ - metadata[policy_id]["checkpoint_i"] - for policy_id in self.policies_to_load_from_checkpoint - ] - pair_of_group_names = [ - self.checkpoints[checkpoint_i]["group_name"] - for checkpoint_i in checkpoints_idx_used - ] - return pair_of_group_names + metadata_per_modes = self._split_metadata_per_mode( + all_metadata_wt_results + ) + for mode, metadata_for_one_mode in metadata_per_modes.items(): + analysis_per_mode.extend( + self._split_metadata_per_group_pair_id( + metadata_for_one_mode, mode + ) + ) + + return analysis_per_mode + + def _extract_all_metrics(self, analysis_per_mode): + analysis_metrics_per_mode = [] + for mode_i, mode_data in enumerate(analysis_per_mode): + mode, analysis_list, group_pair_id, group_pair_name = mode_data + + available_metrics_list = [] + for trial in analysis_list: + available_metrics = trial.metric_analysis + available_metrics_list.append(available_metrics) + analysis_metrics_per_mode.append( + (mode, available_metrics_list, group_pair_id, group_pair_name) + ) + return analysis_metrics_per_mode def plot_results( self, @@ -567,213 +582,181 @@ def plot_results( x_axis_metric, y_axis_metric, ): - plotter = SelfAndCrossPlayPlotter() - return 
plotter.plot_results( - exp_parent_dir=self.exp_parent_dir, - metrics_per_mode=analysis_metrics_per_mode, - plot_config=plot_config, - x_axis_metric=x_axis_metric, - y_axis_metric=y_axis_metric, + + vanilla_plot_path = self._plot_as_provided( + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ) + if plot_config.plot_max_n_points is not None: + plot_config.plot_max_n_points *= 2 -class SelfAndCrossPlayPlotter: - def __init__(self): - self.x_axis_metric = None - self.y_axis_metric = None - self.metric_mode = None - self.stat_summary = None - self.data_groups_per_mode = None + self._plot_merge_self_cross( + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, + ) + self._plot_merge_same_and_diff_pref( + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, + ) + return vanilla_plot_path - def plot_results( + def _plot_as_provided( self, - exp_parent_dir: str, - x_axis_metric: str, - y_axis_metric: str, - metrics_per_mode: list, - plot_config: PlotConfig, - metric_mode: str = "avg", + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ): - self._reset(x_axis_metric, y_axis_metric, metric_mode) - for metrics_for_one_evaluation_mode in metrics_per_mode: - self._extract_performance_evaluation_points( - metrics_for_one_evaluation_mode - ) - self.stat_summary.save_summary( - filename_prefix=RESULTS_SUMMARY_FILENAME_PREFIX, - folder_dir=exp_parent_dir, + vanilla_plot_path = self._plot_one_time( + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ) - return self._plot_and_save_fig(plot_config, exp_parent_dir) + return vanilla_plot_path - def _reset(self, x_axis_metric, y_axis_metric, metric_mode): - self.x_axis_metric = x_axis_metric - self.y_axis_metric = y_axis_metric - self.metric_mode = metric_mode - self.stat_summary = StatisticSummary( - self.x_axis_metric, self.y_axis_metric, self.metric_mode + def _plot_merge_self_cross( + self, + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, + ): + self._plot_one_time_with_prefix_and_preprocess( + "_self_cross", + "_merge_into_self_and_cross", + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ) - self.data_groups_per_mode = {} - def _extract_performance_evaluation_points( - self, metrics_for_one_evaluation_mode + def _plot_merge_same_and_diff_pref( + self, + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ): - ( - mode, - available_metrics_list, - group_pair_id, - group_pair_name, - ) = metrics_for_one_evaluation_mode - - label = self._get_label(mode, group_pair_name) - x, y = self._extract_x_y_points(available_metrics_list) - - self.stat_summary.aggregate_stats_on_data_points(x, y, label) - self.data_groups_per_mode[label] = self._format_as_df(x, y) - print("x, y", x, y) - - def _get_label(self, mode, group_pair_name): - # For backward compatibility - if mode == "Same-play" or mode == "same training run": - mode = self.SELF_PLAY_MODE - elif mode == "Cross-play" or mode == "cross training run": - mode = self.CROSS_PLAY_MODE - - print("Evaluator mode:", mode) - if self._suffix_needed(group_pair_name): - print("Using group_pair_name:", group_pair_name) - label = f"{mode}: " + " vs ".join(group_pair_name) - else: - label = mode - label = label.replace("_", " ") - print("label", label) - return label - def _suffix_needed(self, group_pair_name): - return ( - group_pair_name is not None - and all([name is not None for name in group_pair_name]) - and 
all(group_pair_name) + self._plot_one_time_with_prefix_and_preprocess( + "_same_and_diff_pref", + "_merge_into_cross_same_pref_diff_pref", + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, ) - def _extract_x_y_points(self, available_metrics_list): - x, y = [], [] - assert len(available_metrics_list) > 0 - random.shuffle(available_metrics_list) - - for available_metrics in available_metrics_list: - if self.x_axis_metric in available_metrics.keys(): - x_point = available_metrics[self.x_axis_metric][ - self.metric_mode - ] - else: - x_point = 123456789 - warnings.warn( - f"x_axis_metric {self.x_axis_metric}" - " not in available_metrics " - f"{available_metrics.keys()}" - ) - if self.y_axis_metric in available_metrics.keys(): - y_point = available_metrics[self.y_axis_metric][ - self.metric_mode - ] + def _plot_one_time_with_prefix_and_preprocess( + self, + prefix: str, + preprocess: str, + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, + ): + initial_filename_prefix = plot_config.filename_prefix + plot_config.filename_prefix += prefix + metrics_for_same_pref_diff_pref = getattr(self, preprocess)( + analysis_metrics_per_mode + ) + self._plot_one_time( + metrics_for_same_pref_diff_pref, + plot_config, + x_axis_metric, + y_axis_metric, + ) + plot_config.filename_prefix = initial_filename_prefix + + def _merge_into_self_and_cross(self, analysis_metrics_per_mode): + metrics_for_self_and_cross = [] + for selected_play_mode in self.MODES: + metrics_for_self_and_cross.append( + (selected_play_mode, [], "", None) + ) + for ( + play_mode, + metrics, + pair_name, + pair_tuple, + ) in analysis_metrics_per_mode: + if play_mode == selected_play_mode: + metrics_for_self_and_cross[-1][1].extend(metrics) + return metrics_for_self_and_cross + + def _merge_into_cross_same_pref_diff_pref(self, analysis_metrics_per_mode): + analysis_metrics_per_mode_wtout_self_play = self._copy_wtout_self_play( + analysis_metrics_per_mode + ) + one_pair_of_group_names = analysis_metrics_per_mode[0][3] + metrics_for_same_pref_diff_pref = [ + ( + self.CROSS_PLAY_MODE, + [], + "same_prefsame_pref", + {k: "same_pref" for k in one_pair_of_group_names.keys()}, + ), + ( + self.CROSS_PLAY_MODE, + [], + "diff_prefdiff_pref", + {k: "diff_pref" for k in one_pair_of_group_names.keys()}, + ), + ] + for ( + play_mode, + metrics, + pair_name, + pair_of_group_names, + ) in analysis_metrics_per_mode_wtout_self_play: + groups_names = list(pair_of_group_names.values()) + assert len(groups_names) == 2 + if groups_names[0] == groups_names[1]: + metrics_for_same_pref_diff_pref[0][1].extend(metrics) else: - y_point = 123456789 - warnings.warn( - f"y_axis_metric {self.y_axis_metric}" - " not in available_metrics " - f"{available_metrics.keys()}" + metrics_for_same_pref_diff_pref[1][1].extend(metrics) + if len(metrics_for_same_pref_diff_pref[1][1]) == 0: + metrics_for_same_pref_diff_pref.pop(1) + if len(metrics_for_same_pref_diff_pref[0][1]) == 0: + metrics_for_same_pref_diff_pref.pop(0) + return metrics_for_same_pref_diff_pref + + def _copy_wtout_self_play(self, analysis_metrics_per_mode): + analysis_metrics_per_mode_wtout_self_play = [] + for ( + play_mode, + metrics, + pair_name, + pair_tuple, + ) in analysis_metrics_per_mode: + if play_mode == self.CROSS_PLAY_MODE: + analysis_metrics_per_mode_wtout_self_play.append( + (self.CROSS_PLAY_MODE, metrics, pair_name, pair_tuple) ) - x.append(x_point) - y.append(y_point) - return x, y - - def _format_as_df(self, x, y): - group_df_dict = { - "": [ - 
(one_x_point, one_y_point) - for one_x_point, one_y_point in zip(x, y) - ] - } - group_df = pd.DataFrame(group_df_dict) - return group_df - - def _plot_and_save_fig(self, plot_config, exp_parent_dir): - plot_helper = PlotHelper(plot_config) - plot_helper.plot_cfg.save_dir_path = exp_parent_dir - return plot_helper.plot_dots(self.data_groups_per_mode) - - -class StatisticSummary: - def __init__(self, x_axis_metric, y_axis_metric, metric_mode): - self.x_means, self.x_se, self.x_labels, self.x_raw = [], [], [], [] - self.y_means, self.y_se, self.y_labels, self.y_raw = [], [], [], [] - self.matrix_label = [] - self.x_axis_metric, self.y_axis_metric = x_axis_metric, y_axis_metric - self.metric_mode = metric_mode - - def aggregate_stats_on_data_points(self, x, y, label): - # TODO refactor that to use a data structure - # (like per metric and per plot?) - self.x_means.append(sum(x) / len(x)) - self.x_se.append(np.array(x).std() / np.sqrt(len(x))) - self.x_labels.append( - f"Metric:{self.x_axis_metric}, " f"Metric mode:{self.metric_mode}" - ) - - self.y_means.append(sum(y) / len(y)) - self.y_se.append(np.array(y).std() / np.sqrt(len(y))) - self.y_labels.append( - f"Metric:{self.y_axis_metric}, " f"Metric mode:{self.metric_mode}" - ) - - self.matrix_label.append(label) - self.x_raw.append(x) - self.y_raw.append(y) - - def save_summary(self, filename_prefix, folder_dir): - file_name = ( - f"{filename_prefix}_{self.y_axis_metric}_" - f"vs_{self.x_axis_metric}_matrix.json" - ) - file_name = file_name.replace("/", "_") - file_path = os.path.join(folder_dir, file_name) - formated_data = {} - for step_i in range(len(self.x_means)): - ( - x_mean, - x_std_err, - x_lbl, - y_mean, - y_std_err, - y_lbl, - lbl, - x, - y, - ) = self._get_values_from_a_data_point(step_i) - formated_data[lbl] = { - x_lbl: { - "mean": x_mean, - "std_err": x_std_err, - "raw_data": str(x), - }, - y_lbl: { - "mean": y_mean, - "std_err": y_std_err, - "raw_data": str(y), - }, - } - with open(file_path, "w") as f: - json.dump(formated_data, f, indent=4, sort_keys=True) + return analysis_metrics_per_mode_wtout_self_play - def _get_values_from_a_data_point(self, step_i): - return ( - self.x_means[step_i], - self.x_se[step_i], - self.x_labels[step_i], - self.y_means[step_i], - self.y_se[step_i], - self.y_labels[step_i], - self.matrix_label[step_i], - self.x_raw[step_i], - self.y_raw[step_i], + def _plot_one_time( + self, + analysis_metrics_per_mode, + plot_config, + x_axis_metric, + y_axis_metric, + ): + plotter = ploter.SelfAndCrossPlayPlotter() + plot_path = plotter.plot_results( + exp_parent_dir=self.exp_parent_dir, + metrics_per_mode=analysis_metrics_per_mode, + plot_config=plot_config, + x_axis_metric=x_axis_metric, + y_axis_metric=y_axis_metric, ) + return plot_path diff --git a/marltoolbox/utils/cross_play/ploter.py b/marltoolbox/utils/cross_play/ploter.py new file mode 100644 index 0000000..a6a218b --- /dev/null +++ b/marltoolbox/utils/cross_play/ploter.py @@ -0,0 +1,168 @@ +import logging +import random +import copy +import pandas as pd + +from marltoolbox.utils.cross_play import evaluator +from marltoolbox.utils.cross_play.stats_summary import StatisticSummary +from marltoolbox.utils.plot import PlotHelper, PlotConfig + +logger = logging.getLogger(__name__) + + +class SelfAndCrossPlayPlotter: + def __init__(self): + self.x_axis_metric = None + self.y_axis_metric = None + self.metric_mode = None + self.stat_summary = None + self.data_groups_per_mode = None + + def plot_results( + self, + exp_parent_dir: str, + x_axis_metric: str, 
+ y_axis_metric: str, + metrics_per_mode: list, + plot_config: PlotConfig, + metric_mode: str = "avg", + ): + self._reset(x_axis_metric, y_axis_metric, metric_mode) + for metrics_for_one_evaluation_mode in metrics_per_mode: + self._extract_performance_evaluation_points( + metrics_for_one_evaluation_mode + ) + stat_summary_filename_prefix = ( + plot_config.filename_prefix + + evaluator.RESULTS_SUMMARY_FILENAME_PREFIX + ) + self.stat_summary.save_summary( + filename_prefix=stat_summary_filename_prefix, + folder_dir=exp_parent_dir, + ) + return self._plot_and_save_fig(plot_config, exp_parent_dir) + + def _reset(self, x_axis_metric, y_axis_metric, metric_mode): + self.x_axis_metric = x_axis_metric + self.y_axis_metric = y_axis_metric + self.metric_mode = metric_mode + self.stat_summary = StatisticSummary( + self.x_axis_metric, self.y_axis_metric, self.metric_mode + ) + self.data_groups_per_mode = {} + + def _extract_performance_evaluation_points( + self, metrics_for_one_evaluation_mode + ): + ( + mode, + available_metrics_list, + group_pair_id, + group_pair_name, + ) = metrics_for_one_evaluation_mode + + label = self._get_label(mode, group_pair_name) + x, y = self._extract_x_y_points(available_metrics_list) + + self.stat_summary.aggregate_stats_on_data_points(x, y, label) + self.data_groups_per_mode[label] = self._format_as_df(x, y) + print("x, y", x, y) + + def _get_label(self, mode, group_pair_name): + + print("Evaluator mode:", mode) + if self._suffix_needed(group_pair_name): + ordered_group_pair_name = self._order_group_names(group_pair_name) + print( + "Using ordered_group_pair_name:", + ordered_group_pair_name, + "from group_pair_name:", + group_pair_name, + ) + label = f"{mode}: " + " vs ".join(ordered_group_pair_name) + else: + label = mode + label = label.replace("_", " ") + print("label", label) + return label + + def _suffix_needed(self, group_pair_name): + if group_pair_name is None: + return False + return all( + [name is not None for name in group_pair_name.values()] + ) and all(group_pair_name.values()) + + def _order_group_names(self, group_pair_name_original): + group_pair_name = copy.deepcopy(group_pair_name_original) + ordered_group_pair_name = [] + for metric in (self.x_axis_metric, self.y_axis_metric): + for policy_id, one_group_name in group_pair_name.items(): + print( + "_order_group_names policy_id in metric", policy_id, metric + ) + if policy_id in metric: + ordered_group_pair_name.append(one_group_name) + group_pair_name.pop(policy_id) + break + assert len(group_pair_name.keys()) == 0, ( + "group_pair_name_original.keys() " + f"{group_pair_name_original.keys()} not in the metrics provided: " + "(self.x_axis_metric, self.y_axis_metric) " + f"{(self.x_axis_metric, self.y_axis_metric)}" + ) + return ordered_group_pair_name + + def _extract_x_y_points(self, available_metrics_list): + x, y = [], [] + assert len(available_metrics_list) > 0 + random.shuffle(available_metrics_list) + + for available_metrics in available_metrics_list: + if self.x_axis_metric in available_metrics.keys(): + x_point = available_metrics[self.x_axis_metric][ + self.metric_mode + ] + else: + x_point = 123456789 + from ray.util.debug import log_once + + msg = ( + f"x_axis_metric {self.x_axis_metric}" + " not in available_metrics " + f"{available_metrics.keys()}" + ) + if log_once(msg): + logger.warning(msg) + + if self.y_axis_metric in available_metrics.keys(): + y_point = available_metrics[self.y_axis_metric][ + self.metric_mode + ] + else: + y_point = 123456789 + msg = ( + f"y_axis_metric 
{self.y_axis_metric}" + " not in available_metrics " + f"{available_metrics.keys()}" + ) + if log_once(msg): + logger.warning(msg) + x.append(x_point) + y.append(y_point) + return x, y + + def _format_as_df(self, x, y): + group_df_dict = { + "": [ + (one_x_point, one_y_point) + for one_x_point, one_y_point in zip(x, y) + ] + } + group_df = pd.DataFrame(group_df_dict) + return group_df + + def _plot_and_save_fig(self, plot_config, exp_parent_dir): + plot_helper = PlotHelper(plot_config) + plot_helper.plot_cfg.save_dir_path = exp_parent_dir + return plot_helper.plot_dots(self.data_groups_per_mode) diff --git a/marltoolbox/utils/cross_play/stats_summary.py b/marltoolbox/utils/cross_play/stats_summary.py new file mode 100644 index 0000000..a24dafa --- /dev/null +++ b/marltoolbox/utils/cross_play/stats_summary.py @@ -0,0 +1,80 @@ +import json +import os + +import numpy as np + + +class StatisticSummary: + def __init__(self, x_axis_metric, y_axis_metric, metric_mode): + self.x_means, self.x_se, self.x_labels, self.x_raw = [], [], [], [] + self.y_means, self.y_se, self.y_labels, self.y_raw = [], [], [], [] + self.matrix_label = [] + self.x_axis_metric, self.y_axis_metric = x_axis_metric, y_axis_metric + self.metric_mode = metric_mode + + def aggregate_stats_on_data_points(self, x, y, label): + # TODO refactor that to use a data structure + # (like per metric and per plot?) + self.x_means.append(sum(x) / len(x)) + self.x_se.append(np.array(x).std() / np.sqrt(len(x))) + self.x_labels.append( + f"Metric:{self.x_axis_metric}, " f"Metric mode:{self.metric_mode}" + ) + + self.y_means.append(sum(y) / len(y)) + self.y_se.append(np.array(y).std() / np.sqrt(len(y))) + self.y_labels.append( + f"Metric:{self.y_axis_metric}, " f"Metric mode:{self.metric_mode}" + ) + + self.matrix_label.append(label) + self.x_raw.append(x) + self.y_raw.append(y) + + def save_summary(self, filename_prefix, folder_dir): + file_name = ( + f"{filename_prefix}_{self.y_axis_metric}_" + f"vs_{self.x_axis_metric}_matrix.json" + ) + file_name = file_name.replace("/", "_") + file_path = os.path.join(folder_dir, file_name) + formated_data = {} + for step_i in range(len(self.x_means)): + ( + x_mean, + x_std_err, + x_lbl, + y_mean, + y_std_err, + y_lbl, + lbl, + x, + y, + ) = self._get_values_from_a_data_point(step_i) + formated_data[lbl] = { + x_lbl: { + "mean": x_mean, + "std_err": x_std_err, + "raw_data": str(x), + }, + y_lbl: { + "mean": y_mean, + "std_err": y_std_err, + "raw_data": str(y), + }, + } + with open(file_path, "w") as f: + json.dump(formated_data, f, indent=4, sort_keys=True) + + def _get_values_from_a_data_point(self, step_i): + return ( + self.x_means[step_i], + self.x_se[step_i], + self.x_labels[step_i], + self.y_means[step_i], + self.y_se[step_i], + self.y_labels[step_i], + self.matrix_label[step_i], + self.x_raw[step_i], + self.y_raw[step_i], + ) diff --git a/marltoolbox/utils/cross_play/utils.py b/marltoolbox/utils/cross_play/utils.py new file mode 100644 index 0000000..0bf6d55 --- /dev/null +++ b/marltoolbox/utils/cross_play/utils.py @@ -0,0 +1,112 @@ +import copy +import random +from typing import Dict, List + +from ray import tune + + +def mix_policies_in_given_rllib_configs( + all_rllib_configs: List[Dict], n_mix_per_config: int +) -> dict: + """ + Mix the policies of a list of RLLib config dictionaries. Limited to + RLLib config with 2 policies. 
(Not used by the SelfAndCrossPlayEvaluator) + + + :param all_rllib_configs: all rllib config + :param n_mix_per_config: number of mix to create for each rllib config + provided + :return: a single rllib config with a grid search over all the mixed + pair of policies + """ + assert ( + n_mix_per_config <= len(all_rllib_configs) - 1 + ), f" {n_mix_per_config} <= {len(all_rllib_configs) - 1}" + policy_ids = all_rllib_configs[0]["multiagent"]["policies"].keys() + assert len(policy_ids) == 2, ( + "only supporting config dict with 2 RLLib " "policies" + ) + _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids) + + policy_config_variants = _gather_policy_variant_per_policy_id( + all_rllib_configs, policy_ids + ) + + master_config = _create_one_master_config( + all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config + ) + return master_config + + +def _create_one_master_config( + all_rllib_configs, policy_config_variants, policy_ids, n_mix_per_config +): + all_policy_mix = [] + player_1, player_2 = policy_ids + for config_idx, p1_policy_config in enumerate( + policy_config_variants[player_1] + ): + policies_mixes = _produce_n_mix_with_player_2_policies( + policy_config_variants, + player_2, + config_idx, + n_mix_per_config, + player_1, + p1_policy_config, + ) + all_policy_mix.extend(policies_mixes) + + master_config = copy.deepcopy(all_rllib_configs[0]) + print("len(all_policy_mix)", len(all_policy_mix)) + master_config["multiagent"]["policies"] = tune.grid_search(all_policy_mix) + return master_config + + +def _produce_n_mix_with_player_2_policies( + policy_config_variants, + player_2, + config_idx, + n_mix_per_config, + player_1, + p1_policy_config, +): + p2_policy_configs_sampled = _get_p2_policies_samples_excluding_self( + policy_config_variants, player_2, config_idx, n_mix_per_config + ) + policies_mixes = [] + for p2_policy_config in p2_policy_configs_sampled: + policy_mix = { + player_1: p1_policy_config, + player_2: p2_policy_config, + } + policies_mixes.append(policy_mix) + return policies_mixes + + +def _get_p2_policies_samples_excluding_self( + policy_config_variants, player_2, config_idx, n_mix_per_config +): + p2_policy_config_variants = copy.deepcopy(policy_config_variants[player_2]) + p2_policy_config_variants.pop(config_idx) + p2_policy_configs_sampled = random.sample( + p2_policy_config_variants, n_mix_per_config + ) + return p2_policy_configs_sampled + + +def _assert_all_config_use_the_same_policies(all_rllib_configs, policy_ids): + for rllib_config in all_rllib_configs: + assert rllib_config["multiagent"]["policies"].keys() == policy_ids + + +def _gather_policy_variant_per_policy_id(all_rllib_configs, policy_ids): + policy_config_variants = {} + for policy_id in policy_ids: + policy_config_variants[policy_id] = [] + for rllib_config in all_rllib_configs: + policy_config_variants[policy_id].append( + copy.deepcopy( + rllib_config["multiagent"]["policies"][policy_id] + ) + ) + return policy_config_variants diff --git a/marltoolbox/utils/exp_analysis.py b/marltoolbox/utils/exp_analysis.py new file mode 100644 index 0000000..ab14575 --- /dev/null +++ b/marltoolbox/utils/exp_analysis.py @@ -0,0 +1,298 @@ +import logging + +from ray.tune import Trainable +from ray.tune import register_trainable +from ray.tune.analysis.experiment_analysis import ExperimentAnalysis +from ray.tune.checkpoint_manager import Checkpoint +from ray.tune.trial import Trial + +from marltoolbox.utils.miscellaneous import ( + move_to_key, + _get_experiment_state_file_path, +) + 
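+# Helpers for working with tune ExperimentAnalysis objects: extracting
+# per-trial metrics and config values, checking that learning was achieved,
+# filtering trials by metric thresholds, and rebuilding fake analysis objects
+# that only wrap a list of checkpoints.
+#
+# Minimal usage sketch (assuming a completed tune.run call; trainer and
+# config names are illustrative):
+#
+#     tune_analysis = tune.run(MyTrainer, config=rllib_config)
+#     rewards = extract_metrics_for_each_trials(
+#         tune_analysis, metric="episode_reward_mean", metric_mode="last-5-avg"
+#     )
+#     tune_analysis = filter_trials(
+#         tune_analysis,
+#         metric="episode_reward_mean",
+#         metric_threshold=0.0,
+#         threshold_mode=ABOVE,  # keep trials strictly above the threshold
+#     )
+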
+logger = logging.getLogger(__name__) + + +def extract_value_from_last_training_iteration_for_each_trials( + experiment_analysis, + metric="episode_reward_mean", +): + metric_values = [] + for trial in experiment_analysis.trials: + last_results = trial.last_result + _, _, value, found = move_to_key(last_results, key=metric) + assert ( + found + ), f"metric: {metric} not found in last_results: {last_results}" + metric_values.append(value) + return metric_values + + +def extract_metrics_for_each_trials( + experiment_analysis, + metric="episode_reward_mean", + metric_mode="avg", +): + metric_values = [] + for trial in experiment_analysis.trials: + dict_ = trial.metric_analysis + for sub_key in metric.split("."): + try: + dict_ = dict_[sub_key] + except KeyError: + raise KeyError( + f"sub_key {sub_key} not in dict_.keys() {dict_.keys()}" + ) + + value = dict_[metric_mode] + metric_values.append(value) + + return metric_values + + +def check_learning_achieved( + tune_results, + metric="episode_reward_mean", + trial_idx=0, + max_: float = None, + min_: float = None, + equal_: float = None, +): + assert max_ is not None or min_ is not None or equal_ is not None + + last_results = tune_results.trials[trial_idx].last_result + _, _, value, found = move_to_key(last_results, key=metric) + assert ( + found + ), f"metric {metric} not found inside last_results {last_results}" + + msg = ( + f"Trial {trial_idx} achieved " + f"{value}" + f" on metric {metric}. This is a success if the value is below" + f" {max_} or above {min_} or equal to {equal_}." + ) + + logger.info(msg) + print(msg) + if min_ is not None: + assert value >= min_, f"value {value} must be above min_ {min_}" + if max_ is not None: + assert value <= max_, f"value {value} must be below max_ {max_}" + if equal_ is not None: + assert value == equal_, ( + f"value {value} must be equal to equal_ " f"{equal_}" + ) + + +def extract_config_values_from_experiment_analysis( + tune_experiment_analysis, key +): + values = [] + for trial in tune_experiment_analysis.trials: + dict_, k, current_value, found = move_to_key(trial.config, key) + if found: + values.append(current_value) + else: + values.append(None) + return values + + +ABOVE = "above" +EQUAL = "equal" +BELOW = "below" +FILTERING_MODES = (ABOVE, EQUAL, BELOW) + +RLLIB_METRICS_MODES = ( + "avg", + "min", + "max", + "last", + "last-5-avg", + "last-10-avg", +) + + +def filter_trials( + experiment_analysis, + metric, + metric_threshold: float, + metric_mode="last-5-avg", + threshold_mode=ABOVE, +): + """ + Filter trials of an ExperimentAnalysis + + :param experiment_analysis: + :param metric: + :param metric_threshold: + :param metric_mode: + :param threshold_mode: + :return: + """ + assert threshold_mode in FILTERING_MODES, ( + f"threshold_mode {threshold_mode} " f"must be in {FILTERING_MODES}" + ) + assert metric_mode in RLLIB_METRICS_MODES + print("Before trial filtering:", len(experiment_analysis.trials), "trials") + trials_filtered = [] + print( + "metric_threshold", metric_threshold, "threshold_mode", threshold_mode + ) + for trial_idx, trial in enumerate(experiment_analysis.trials): + available_metrics = trial.metric_analysis + try: + metric_value = available_metrics[metric][metric_mode] + except KeyError: + raise KeyError( + f"failed to read metric key:{metric} in " + f"available_metrics:{available_metrics}" + ) + print( + f"trial_idx {trial_idx} " + f"available_metrics[{metric}][{metric_mode}] " + f"{metric_value}" + ) + if threshold_mode == ABOVE and metric_value > metric_threshold: + 
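+            # ABOVE keeps trials strictly above the threshold; the EQUAL and
+            # BELOW branches below handle the two other comparison modes.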
trials_filtered.append(trial) + elif threshold_mode == EQUAL and metric_value == metric_threshold: + trials_filtered.append(trial) + elif threshold_mode == BELOW and metric_value < metric_threshold: + trials_filtered.append(trial) + else: + print(f"filtering out trial {trial_idx}") + + experiment_analysis.trials = trials_filtered + print("After trial filtering:", len(experiment_analysis.trials), "trials") + return experiment_analysis + + +def filter_trials_wt_n_metrics( + experiment_analysis, + metrics: tuple, + metric_thresholds: tuple, + metric_modes: tuple, + threshold_modes: tuple, +): + for threshold_mode in threshold_modes: + assert threshold_mode in FILTERING_MODES, ( + f"threshold_mode {threshold_mode} " f"must be in {FILTERING_MODES}" + ) + for metric_mode in metric_modes: + assert metric_mode in RLLIB_METRICS_MODES + print("Before trial filtering:", len(experiment_analysis.trials), "trials") + trials_filtered = [] + print( + "metric_thresholds", + metric_thresholds, + "threshold_modes", + threshold_modes, + ) + for trial_idx, trial in enumerate(experiment_analysis.trials): + keep = [] + for metric, metric_threshold, metric_mode, threshold_mode in zip( + metrics, metric_thresholds, metric_modes, threshold_modes + ): + + available_metrics = trial.metric_analysis + try: + metric_value = available_metrics[metric][metric_mode] + except KeyError: + raise KeyError( + f"failed to read metric key:{metric} in " + f"available_metrics:{available_metrics}" + ) + print( + f"trial_idx {trial_idx} " + f"available_metrics[{metric}][{metric_mode}] " + f"{metric_value}" + ) + if threshold_mode == ABOVE and metric_value > metric_threshold: + keep.append(True) + elif threshold_mode == EQUAL and metric_value == metric_threshold: + keep.append(True) + elif threshold_mode == BELOW and metric_value < metric_threshold: + keep.append(True) + else: + keep.append(False) + # Logical and between metrics + if all(keep): + trials_filtered.append(trial) + + experiment_analysis.trials = trials_filtered + print("After trial filtering:", len(experiment_analysis.trials), "trials") + return experiment_analysis + + +def load_experiment_analysis_wt_ckpt_only( + checkpoints_paths: list, + result: dict = {"training_iteration": 1, "episode_reward_mean": 1}, + default_metric: "str" = "episode_reward_mean", + default_mode: str = "max", + n_dir_level_between_ckpt_and_exp_state=1, +): + """Helper to re-create a fake ExperimentAnalysis only containing the + checkpoints provided.""" + + assert default_metric in result.keys() + + register_trainable("fake trial", Trainable) + trials = [] + for one_checkpoint_path in checkpoints_paths: + one_trial = Trial(trainable_name="fake trial") + ckpt = Checkpoint( + Checkpoint.PERSISTENT, value=one_checkpoint_path, result=result + ) + one_trial.checkpoint_manager.on_checkpoint(ckpt) + trials.append(one_trial) + + json_file_path = _get_experiment_state_file_path( + checkpoints_paths[0], + split_path_n_times=n_dir_level_between_ckpt_and_exp_state, + ) + experiment_analysis = ExperimentAnalysis( + experiment_checkpoint_path=json_file_path, + trials=trials, + default_mode=default_mode, + default_metric=default_metric, + ) + + for trial in experiment_analysis.trials: + assert len(trial.checkpoint_manager.best_checkpoints()) == 1 + + return experiment_analysis + + +from collections import namedtuple + +FakeExperimentAnalysis = namedtuple("FakeExperimentAnalysis", "trials") + + +def create_fake_experiment_analysis_wt_metrics_only( + results: dict = ({"training_iteration": 1, "episode_reward_mean": 
1},), + default_metric: "str" = "episode_reward_mean", + default_mode: str = "max", +): + """Helper to re-create a fake ExperimentAnalysis only containing the + checkpoints provided.""" + + # assert default_metric in result.keys() + + register_trainable("fake trial", Trainable) + trials = [] + for result in results: + trial = Trial(trainable_name="fake trial") + trial.update_last_result(result, terminate=True) + trials.append(trial) + + # experiment_analysis = ExperimentAnalysis( + # experiment_checkpoint_path="fake_path", + # trials=trials, + # default_mode=default_mode, + # default_metric=default_metric, + # ) + + experiment_analysis = FakeExperimentAnalysis(trials=trials) + + return experiment_analysis diff --git a/marltoolbox/utils/log/__init__.py b/marltoolbox/utils/log/__init__.py new file mode 100644 index 0000000..111e73e --- /dev/null +++ b/marltoolbox/utils/log/__init__.py @@ -0,0 +1,19 @@ +from marltoolbox.utils.log.log import * +from marltoolbox.utils.log.log import log_learning_rate +from marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger +from marltoolbox.utils.log.model_summarizer import ModelSummarizer + +__all__ = [ + "log_learning_rate", + "FullEpisodeLogger", + "ModelSummarizer", + "log_learning_rate", + "pprint_saved_metrics", + "save_metrics", + "extract_all_metrics_from_results", + "log_in_current_day_dir", + "compute_entropy_from_raw_q_values", + "augment_stats_fn_wt_additionnal_logs", + "get_log_from_policy", + "add_entropy_to_log", +] diff --git a/marltoolbox/utils/full_epi_logger.py b/marltoolbox/utils/log/full_epi_logger.py similarity index 56% rename from marltoolbox/utils/full_epi_logger.py rename to marltoolbox/utils/log/full_epi_logger.py index 98017c2..32f957b 100644 --- a/marltoolbox/utils/full_epi_logger.py +++ b/marltoolbox/utils/log/full_epi_logger.py @@ -4,18 +4,32 @@ import numpy as np from ray.rllib.evaluation import MultiAgentEpisode -from ray.tune.logger import _SafeFallbackEncoder +from ray.tune.logger import SafeFallbackEncoder logger = logging.getLogger(__name__) class FullEpisodeLogger: - - def __init__(self, logdir, log_interval, log_ful_epi_one_hot_obs): + """ + Helper to log the entire history of one episode as txt + """ + + def __init__( + self, logdir: str, log_interval: int, convert_one_hot_obs_to_idx: bool + ): + """ + + :param logdir: dir where to save the log file with the full episode + :param log_interval: interval (in number of episode) between the log + of two episodes + :param convert_one_hot_obs_to_idx: bool flag to chose to convert the + observation to their idx + (indented to bue used when dealing with one hot observations) + """ self.log_interval = log_interval - self.log_ful_epi_one_hot_obs = log_ful_epi_one_hot_obs + self.log_ful_epi_one_hot_obs = convert_one_hot_obs_to_idx - file_path = os.path.join(logdir, f"full_episodes_logs.json") + file_path = os.path.join(logdir, "full_episodes_logs.json") self.file_path = os.path.expanduser(file_path) logger.info(f"FullEpisodeLogger: using as file_path: {self.file_path}") @@ -23,7 +37,6 @@ def __init__(self, logdir, log_interval, log_ful_epi_one_hot_obs): self.internal_episode_counter = -1 self.step_counter = 0 self.episode_finised = True - self._first_fake_step_done = False self.json_logger = JsonSimpleLogger(self.file_path) @@ -45,7 +58,8 @@ def _init_logging_new_full_episode(self): self._log_full_epi_tmp_data = {} def on_episode_step( - self, episode: MultiAgentEpisode = None, step_data: dict = None): + self, episode: MultiAgentEpisode = None, step_data: dict = None + 
): if not self._log_current_full_episode: return None @@ -56,22 +70,23 @@ def on_episode_step( step_data = {} for agent_id, policy in episode._policies.items(): - if self._first_fake_step_done: - if agent_id in self._log_full_epi_tmp_data.keys(): - obs_before_act = self._log_full_epi_tmp_data[agent_id] - else: - obs_before_act = None - action = episode.last_action_for(agent_id).tolist() - epi = episode.episode_id - rewards = episode._agent_reward_history[agent_id] - reward = rewards[-1] if len(rewards) > 0 else None - info = episode.last_info_for(agent_id) - if hasattr(policy, "to_log"): - info.update(policy.to_log) - else: - logger.info(f"policy {policy} doesn't have attrib " - "to_log. hasattr(policy, 'to_log'): " - f"{hasattr(policy, 'to_log')}") + if agent_id in self._log_full_epi_tmp_data.keys(): + obs_before_act = self._log_full_epi_tmp_data[agent_id] + else: + obs_before_act = None + action = episode.last_action_for(agent_id).tolist() + epi = episode.episode_id + rewards = episode._agent_reward_history[agent_id] + reward = rewards[-1] if len(rewards) > 0 else None + info = episode.last_info_for(agent_id) + if hasattr(policy, "to_log"): + info.update(policy.to_log) + else: + logger.info( + f"policy {policy} doesn't have attrib " + "to_log. hasattr(policy, 'to_log'): " + f"{hasattr(policy, 'to_log')}" + ) # Episode provide the last action with the given last # observation produced by this action. But we need the # observation that cause the agent to play this action @@ -79,40 +94,37 @@ def on_episode_step( obs_after_act = episode.last_observation_for(agent_id) self._log_full_epi_tmp_data[agent_id] = obs_after_act - if self._first_fake_step_done: - if self.log_ful_epi_one_hot_obs: - obs_before_act = np.argwhere(obs_before_act) - obs_after_act = np.argwhere(obs_after_act) - - step_data[agent_id] = { - "obs_before_act": obs_before_act, - "obs_after_act": obs_after_act, - "action": action, - "reward": reward, - "info": info, - "epi": epi} - - if self._first_fake_step_done: - self.json_logger.write_json(step_data) - self.json_logger.write("\n") - self.step_counter += 1 - else: - logger.info("FullEpisodeLogger: don't log first fake step") - self._first_fake_step_done = True + if self.log_ful_epi_one_hot_obs: + obs_before_act = np.argwhere(obs_before_act) + obs_after_act = np.argwhere(obs_after_act) + + step_data[agent_id] = { + "obs_before_act": obs_before_act, + "obs_after_act": obs_after_act, + "action": action, + "reward": reward, + "info": info, + "epi": epi, + } + + self.json_logger.write_json(step_data) + self.json_logger.write("\n") + self.step_counter += 1 def on_episode_end(self, base_env=None): if self._log_current_full_episode: if base_env is not None: env = base_env.get_unwrapped()[0] if hasattr(env, "max_steps"): - assert self.step_counter == env.max_steps, \ - "The number of steps written to full episode " \ - "log file must be equal to the number of step in an " \ - f"episode self.step_counter {self.step_counter} " \ - f"must equal env.max_steps {env.max_steps}. " \ - "Otherwise there are some issue with the " \ - "state of the callback object, maybe being used by " \ + assert self.step_counter == env.max_steps, ( + "The number of steps written to full episode " + "log file must be equal to the number of step in an " + f"episode self.step_counter {self.step_counter} " + f"must equal env.max_steps {env.max_steps}. " + "Otherwise there are some issue with the " + "state of the callback object, maybe being used by " "several experiments at the same time." 
+ ) self.json_logger.write_json( {"status": f"end of episode {self.internal_episode_counter}"} ) @@ -126,12 +138,19 @@ def on_episode_end(self, base_env=None): class JsonSimpleLogger: + """ + Simple logger in json format + """ def __init__(self, file_path): + """ + + :param file_path: file path to the file to save to + """ self.local_file = file_path def write_json(self, json_data): - json.dump(json_data, self, cls=_SafeFallbackEncoder) + json.dump(json_data, self, cls=SafeFallbackEncoder) def write(self, b): self.local_out.write(b) diff --git a/marltoolbox/utils/log.py b/marltoolbox/utils/log/log.py similarity index 68% rename from marltoolbox/utils/log.py rename to marltoolbox/utils/log/log.py index d7f38eb..6625549 100644 --- a/marltoolbox/utils/log.py +++ b/marltoolbox/utils/log/log.py @@ -1,14 +1,12 @@ import copy import datetime import logging -import math import numbers import os import pickle import pprint import re from collections import Iterable -from typing import Dict, Callable, TYPE_CHECKING import gym import torch @@ -17,12 +15,13 @@ from ray.rllib.evaluation import MultiAgentEpisode from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils.typing import PolicyID, TensorType -from scipy.special import softmax -from torch.nn import Module +from ray.util.debug import log_once +from torch.distributions import Categorical +from typing import Dict, Callable, TYPE_CHECKING, Optional, List -from marltoolbox.utils.full_epi_logger import FullEpisodeLogger +from marltoolbox.utils.log.full_epi_logger import FullEpisodeLogger +from marltoolbox.utils.log.model_summarizer import ModelSummarizer if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -33,6 +32,7 @@ def get_logging_callbacks_class( log_env_step: bool = True, log_from_policy: bool = True, + log_from_policy_in_evaluation: bool = False, log_full_epi: bool = False, log_full_epi_interval: int = 100, log_ful_epi_one_hot_obs: bool = True, @@ -49,7 +49,7 @@ def on_episode_start( base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, - env_index: int, + env_index: Optional[int] = None, **kwargs, ): if log_full_epi: @@ -67,90 +67,13 @@ def _init_full_episode_logging(self, worker): self._full_episode_logger = FullEpisodeLogger( logdir=worker.io_context.log_dir, log_interval=log_full_epi_interval, - log_ful_epi_one_hot_obs=log_ful_epi_one_hot_obs, + convert_one_hot_obs_to_idx=log_ful_epi_one_hot_obs, ) logger.info("_full_episode_logger init done") def _log_model_sumamry(self, worker): - if not hasattr(self, "_log_model_sumamry_done"): - self._log_model_sumamry_done = True - self._for_every_policy_print_model_stats(worker) - - def _for_every_policy_print_model_stats(self, worker): - for policy_id, policy in worker.policy_map.items(): - msg = f"===== Models summaries policy_id {policy_id} =====" - print(msg) - logger.info(msg) - self._print_model_summary(policy) - self._count_parameters_in_every_modules(policy) - - @staticmethod - def _print_model_summary(policy): - if isinstance(policy, TorchPolicy): - for k, v in policy.__dict__.items(): - if isinstance(v, Module): - msg = f"{k}, {v}" - print(msg) - logger.info(msg) - - def _count_parameters_in_every_modules(self, policy): - if isinstance(policy, TorchPolicy): - for k, v in policy.__dict__.items(): - if isinstance(v, Module): - self._count_and_log_for_one_module(policy, k, v) - - def _count_and_log_for_one_module(self, policy, module_name, 
module): - n_param = self._count_parameters(module, module_name) - n_param_shared_counted_once = self._count_parameters( - module, module_name, count_shared_once=True - ) - n_param_trainable = self._count_parameters( - module, module_name, only_trainable=True - ) - self._log_values_in_to_log( - policy, - { - f"module_{module_name}_n_param": n_param, - f"module_{module_name}_n_param_shared_counted_once": n_param_shared_counted_once, - f"module_{module_name}_n_param_trainable": n_param_trainable, - }, - ) - - @staticmethod - def _log_values_in_to_log(policy, dictionary): - if hasattr(policy, "to_log"): - policy.to_log.update(dictionary) - - @staticmethod - def _count_parameters( - m: torch.nn.Module, - module_name: str, - count_shared_once: bool = False, - only_trainable: bool = False, - ): - """ - returns the total number of parameters used by `m` (only counting - shared parameters once); if `only_trainable` is True, then only - includes parameters with `requires_grad = True` - """ - parameters = m.parameters() - if only_trainable: - parameters = list(p for p in parameters if p.requires_grad) - if count_shared_once: - parameters = dict( - (p.data_ptr(), p) for p in parameters - ).values() - number_of_parameters = sum(p.numel() for p in parameters) - - msg = ( - f"{module_name}: " - f"number_of_parameters: {number_of_parameters} " - f"(only_trainable: {only_trainable}, " - f"count_shared_once: {count_shared_once})" - ) - print(msg) - logger.info(msg) - return number_of_parameters + if log_once("model_summaries"): + ModelSummarizer.for_every_policy_print_model_stats(worker) def on_episode_step( self, @@ -158,9 +81,11 @@ def on_episode_step( worker: "RolloutWorker", base_env: BaseEnv, episode: MultiAgentEpisode, - env_index: int, + env_index: Optional[int] = None, **kwargs, ): + if log_from_policy_in_evaluation: + self._update_epi_info_wt_to_log(worker, episode) if log_env_step: self._add_env_info_to_custom_metrics(worker, episode) if log_full_epi: @@ -173,7 +98,7 @@ def on_episode_end( base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode, - env_index: int, + env_index: Optional[int] = None, **kwargs, ): if log_full_epi: @@ -192,6 +117,11 @@ def on_train_result(self, *, trainer, result: dict, **kwargs): self._update_train_result_wt_to_log( trainer, result, function_to_exec=get_log_from_policy ) + self._update_train_result_wt_to_log( + trainer, + result, + function_to_exec=get_explore_temperature_from_policy, + ) if log_weights: if not hasattr(self, "on_train_result_counter"): self.on_train_result_counter = 0 @@ -199,24 +129,10 @@ def on_train_result(self, *, trainer, result: dict, **kwargs): self._update_train_result_wt_to_log( trainer, result, - function_to_exec=self._get_weights_from_policy, + function_to_exec=get_weights_from_policy, ) self.on_train_result_counter += 1 - @staticmethod - def _get_weights_from_policy( - policy: Policy, policy_id: PolicyID - ) -> dict: - """Gets the to_log var from a policy and rename its keys, adding the policy_id as a prefix.""" - to_log = {} - weights = policy.get_weights() - - for k, v in weights.items(): - if isinstance(v, Iterable): - to_log[f"{policy_id}/{k}"] = v - - return to_log - @staticmethod def _add_env_info_to_custom_metrics(worker, episode): @@ -239,10 +155,28 @@ def _update_train_result_wt_to_log( def exec_in_each_policy(worker): return worker.foreach_policy(function_to_exec) - # to_log_list = trainer.workers.foreach_policy(function_to_exec) to_log_list_list = trainer.workers.foreach_worker( exec_in_each_policy ) + 
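+            # `foreach_worker` returns one list per rollout worker, each
+            # containing one `to_log` dict per policy; `_unroll_into_logs`
+            # flattens them into the Tune `result` dict and only warns
+            # (instead of raising) when a key is already present.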
self._unroll_into_logs(result, to_log_list_list) + + def _update_epi_info_wt_to_log( + self, worker, episode: MultiAgentEpisode + ): + """ + Add logs from every policies (from policy.to_log:dict) + to the info (to be plotted in Tensorboard). + To be called from the on_episode_end callback. + """ + + for policy_id, policy in worker.policy_map.items(): + to_log = get_log_from_policy(policy, policy_id) + episode._agent_to_last_info[policy_id].update(to_log) + + @staticmethod + def _unroll_into_logs( + dict_: Dict, to_log_list_list: List[List[Dict]] + ) -> Dict: for worker_idx, to_log_list in enumerate(to_log_list_list): for to_log in to_log_list: for k, v in to_log.items(): @@ -252,12 +186,13 @@ def exec_in_each_policy(worker): else: key = k - if key not in result.keys(): - result[key] = v + if key not in dict_.keys(): + dict_[key] = v else: - raise ValueError( - f"key:{key} already exists in result.keys(): {result.keys()}" + logger.warning( + f"key:{key} already exists in result.keys()" ) + return dict_ return LoggingCallbacks @@ -275,6 +210,33 @@ def get_log_from_policy(policy: Policy, policy_id: PolicyID) -> dict: return to_log +def get_explore_temperature_from_policy( + policy: Policy, policy_id: PolicyID +) -> dict: + """ + It is exist get the temperature from the exploration policy of a Policy + """ + to_log = {} + if hasattr(policy, "exploration"): + exploration_obj = policy.exploration + if hasattr(exploration_obj, "temperature"): + to_log[f"{policy_id}/temperature"] = exploration_obj.temperature + + return to_log + + +def get_weights_from_policy(policy: Policy, policy_id: PolicyID) -> dict: + """Gets the to_log var from a policy and rename its keys, adding the policy_id as a prefix.""" + to_log = {} + weights = policy.get_weights() + + for k, v in weights.items(): + if isinstance(v, Iterable): + to_log[f"{policy_id}/{k}"] = v + + return to_log + + def augment_stats_fn_wt_additionnal_logs( stats_function: Callable[[Policy, SampleBatch], Dict[str, TensorType]] ): @@ -296,7 +258,9 @@ def wt_additional_info( if policy.config["framework"] == "torch": stats_to_log.update(_log_action_prob_pytorch(policy, train_batch)) else: - logger.warning("wt_additional_info workin only for PyTorch") + logger.warning( + "wt_additional_info (stats_fn) working only for PyTorch" + ) return stats_to_log @@ -316,15 +280,20 @@ def _log_action_prob_pytorch( # TODO add entropy to_log = {} if isinstance(policy.action_space, gym.spaces.Discrete): - - assert ( - train_batch["action_dist_inputs"].dim() == 2 - ), "Do not support nested discrete spaces" - - to_log = _add_action_distrib_to_log(policy, train_batch, to_log) - to_log = _add_entropy_to_log(train_batch, to_log) - to_log = _add_proba_of_action_played(train_batch, to_log) - to_log = _add_q_values(policy, train_batch, to_log) + if train_batch.ACTION_DIST_INPUTS in train_batch.keys(): + assert ( + train_batch[train_batch.ACTION_DIST_INPUTS].dim() == 2 + ), "Do not support nested discrete spaces" + + to_log = _add_action_distrib_to_log(policy, train_batch, to_log) + to_log = add_entropy_to_log(train_batch, to_log) + to_log = _add_proba_of_action_played(train_batch, to_log) + to_log = _add_q_values(policy, train_batch, to_log) + else: + logger.warning( + "Key ACTION_DIST_INPUTS not found in train_batch. " + "Can't perform _log_action_prob_pytorch." 
+ ) else: raise NotImplementedError() return to_log @@ -344,7 +313,7 @@ def _add_action_distrib_to_log(policy, train_batch, to_log): return to_log -def _add_entropy_to_log(train_batch, to_log): +def add_entropy_to_log(train_batch, to_log): actions_proba_batch = train_batch["action_dist_inputs"] if _is_cuda_tensor(actions_proba_batch): @@ -357,8 +326,7 @@ def _add_entropy_to_log(train_batch, to_log): actions_proba_batch ) - entropy_avg = _entropy_batch_proba_distrib(actions_proba_batch) - entropy_single = _entropy_proba_distrib(actions_proba_batch[-1, :]) + entropy_avg, entropy_single = _compute_entropy_pytorch(actions_proba_batch) to_log[f"entropy_buffer_samples_avg"] = entropy_avg to_log[f"entropy_buffer_samples_single"] = entropy_single @@ -371,24 +339,7 @@ def _is_cuda_tensor(tensor): def _entropy_batch_proba_distrib(proba_distrib_batch): assert len(proba_distrib_batch) > 0 - entropy_batch = [ - _entropy_proba_distrib(proba_distrib_batch[batch_idx, ...]) - for batch_idx in range(len(proba_distrib_batch)) - ] - mean_entropy = sum(entropy_batch) / len(entropy_batch) - return mean_entropy - - -def _entropy_proba_distrib(proba_distrib): - return sum([_entropy_proba(proba) for proba in proba_distrib]) - - -def _entropy_proba(proba): - assert proba >= 0.0, f"proba currently is {proba}" - if proba == 0.0: - return 0.0 - else: - return -proba * math.log(proba) + return Categorical(probs=proba_distrib_batch).entropy() def _add_proba_of_action_played(train_batch, to_log): @@ -399,7 +350,7 @@ def _add_proba_of_action_played(train_batch, to_log): def _convert_q_values_batch_to_proba_batch(q_values_batch): - return softmax(q_values_batch, axis=1) + return torch.nn.functional.softmax(q_values_batch, dim=1) def _add_q_values(policy, train_batch, to_log): @@ -416,15 +367,21 @@ def _add_q_values(policy, train_batch, to_log): return to_log -def _compute_entropy_from_raw_q_values(policy, q_values): +def compute_entropy_from_raw_q_values(policy, q_values): actions_proba_batch = _apply_exploration(policy, dist_inputs=q_values) - if _is_cuda_tensor(actions_proba_batch): - actions_proba_batch = actions_proba_batch.cpu() actions_proba_batch = _convert_q_values_batch_to_proba_batch( actions_proba_batch ) - entropy_avg = _entropy_batch_proba_distrib(actions_proba_batch) - entropy_single = _entropy_proba_distrib(actions_proba_batch[-1, :]) + return _compute_entropy_pytorch(actions_proba_batch) + + +def _compute_entropy_pytorch(actions_proba_batch): + entropies = _entropy_batch_proba_distrib(actions_proba_batch) + entropy_avg = entropies.mean() + entropy_single = entropies[-1] + if _is_cuda_tensor(actions_proba_batch): + entropy_avg = entropy_avg.cpu() + entropy_single = entropy_single.cpu() return entropy_avg, entropy_single @@ -500,13 +457,15 @@ def pprint_saved_metrics(file_path, keywords_to_print=None): pp.pprint(metrics) -def _log_learning_rate(policy): +def log_learning_rate(policy): to_log = {} if hasattr(policy, "cur_lr"): to_log["cur_lr"] = policy.cur_lr for j, opt in enumerate(policy._optimizers): if hasattr(opt, "param_groups"): to_log[f"opt{j}_lr"] = [p["lr"] for p in opt.param_groups][0] + else: + print("opt doesn't have attr param_groups") return to_log diff --git a/marltoolbox/utils/log/model_summarizer.py b/marltoolbox/utils/log/model_summarizer.py new file mode 100644 index 0000000..5aa1380 --- /dev/null +++ b/marltoolbox/utils/log/model_summarizer.py @@ -0,0 +1,100 @@ +import logging + +import torch +from ray.rllib.policy.torch_policy import TorchPolicy +from torch.nn import Module +from 
ray.rllib.evaluation import RolloutWorker + +logger = logging.getLogger(__name__) + + +class ModelSummarizer: + """ + Helper to log for every torch.nn modules in every policies the + architecture and some parameter statistics. + """ + + @staticmethod + def for_every_policy_print_model_stats(worker: RolloutWorker): + """ + For every policies in the worker, log the archi of all torch modules + and some statistiques about their parameters + + :param worker: + """ + for policy_id, policy in worker.policy_map.items(): + msg = f"===== Models summaries policy_id {policy_id} =====" + print(msg) + logger.info(msg) + ModelSummarizer._print_model_summary(policy) + ModelSummarizer._count_parameters_in_every_modules(policy) + + @staticmethod + def _print_model_summary(policy: TorchPolicy): + if isinstance(policy, TorchPolicy): + for k, v in policy.__dict__.items(): + if isinstance(v, Module): + msg = f"{k}, {v}" + print(msg) + logger.info(msg) + + @staticmethod + def _count_parameters_in_every_modules(policy: TorchPolicy): + if isinstance(policy, TorchPolicy): + for k, v in policy.__dict__.items(): + if isinstance(v, Module): + ModelSummarizer._count_and_log_for_one_module(policy, k, v) + + @staticmethod + def _count_and_log_for_one_module( + policy: TorchPolicy, module_name: str, module: torch.nn.Module + ): + n_param = ModelSummarizer._count_parameters(module, module_name) + n_param_shared_counted_once = ModelSummarizer._count_parameters( + module, module_name, count_shared_once=True + ) + n_param_trainable = ModelSummarizer._count_parameters( + module, module_name, only_trainable=True + ) + ModelSummarizer._log_values_in_to_log( + policy, + { + f"module_{module_name}_n_param": n_param, + f"module_{module_name}_n_param_shared_counted_once": n_param_shared_counted_once, + f"module_{module_name}_n_param_trainable": n_param_trainable, + }, + ) + + @staticmethod + def _log_values_in_to_log(policy, dictionary): + if hasattr(policy, "to_log"): + policy.to_log.update(dictionary) + + @staticmethod + def _count_parameters( + m: torch.nn.Module, + module_name: str, + count_shared_once: bool = False, + only_trainable: bool = False, + ): + """ + returns the total number of parameters used by `m` (only counting + shared parameters once); if `only_trainable` is True, then only + includes parameters with `requires_grad = True` + """ + parameters = m.parameters() + if only_trainable: + parameters = list(p for p in parameters if p.requires_grad) + if count_shared_once: + parameters = dict((p.data_ptr(), p) for p in parameters).values() + number_of_parameters = sum(p.numel() for p in parameters) + + msg = ( + f"{module_name}: " + f"number_of_parameters: {number_of_parameters} " + f"(only_trainable: {only_trainable}, " + f"count_shared_once: {count_shared_once})" + ) + print(msg) + logger.info(msg) + return number_of_parameters diff --git a/marltoolbox/utils/miscellaneous.py b/marltoolbox/utils/miscellaneous.py index 290bf52..c8e083e 100644 --- a/marltoolbox/utils/miscellaneous.py +++ b/marltoolbox/utils/miscellaneous.py @@ -7,11 +7,6 @@ import numpy as np from ray.rllib.policy.sample_batch import SampleBatch -from ray.tune import Trainable -from ray.tune import register_trainable -from ray.tune.analysis.experiment_analysis import ExperimentAnalysis -from ray.tune.checkpoint_manager import Checkpoint -from ray.tune.trial import Trial if TYPE_CHECKING: pass @@ -71,7 +66,7 @@ def move_to_key(dict_: dict, key: str): :param dict_: dict or nesyed dict :param key: key or serie of key joined by a '.' 
- :return: (the lower level dict, lower level key, the final value, + :return: Tuple(the lower level dict, lower level key, the final value, boolean for final value found) """ assert isinstance(dict_, dict) @@ -79,7 +74,10 @@ def move_to_key(dict_: dict, key: str): found = True for k in key.split("."): if not found: - print(f"Intermediary key: {k} not found in full key: {key}") + print( + f"Intermediary key: {k} not found with full key: {key} " + f"and dict: {dict_}" + ) return dict_ = current_value if k in current_value.keys(): @@ -89,42 +87,6 @@ def move_to_key(dict_: dict, key: str): return dict_, k, current_value, found -def extract_checkpoints(tune_experiment_analysis): - logger.info("start extract_checkpoints") - - for trial in tune_experiment_analysis.trials: - checkpoints = tune_experiment_analysis.get_trial_checkpoints_paths( - trial, tune_experiment_analysis.default_metric - ) - assert len(checkpoints) > 0 - - all_best_checkpoints_per_trial = [ - tune_experiment_analysis.get_best_checkpoint( - trial, - metric=tune_experiment_analysis.default_metric, - mode=tune_experiment_analysis.default_mode, - ) - for trial in tune_experiment_analysis.trials - ] - - for checkpoint in all_best_checkpoints_per_trial: - assert checkpoint is not None - - logger.info("end extract_checkpoints") - return all_best_checkpoints_per_trial - - -def extract_config_values_from_tune_analysis(tune_experiment_analysis, key): - values = [] - for trial in tune_experiment_analysis.trials: - dict_, k, current_value, found = move_to_key(trial.config, key) - if found: - values.append(current_value) - else: - values.append(None) - return values - - def merge_policy_postprocessing_fn(*postprocessing_fn_list): """ Merge several callback class together. @@ -192,74 +154,12 @@ def set_config_for_evaluation( return config_copy -def filter_tune_results( - tune_analysis, - metric, - metric_threshold: float, - metric_mode="last-5-avg", - threshold_mode="above", -): - assert threshold_mode in ("above", "equal", "below") - assert metric_mode in ( - "avg", - "min", - "max", - "last", - "last-5-avg", - "last-10-avg", - ) - print("Before trial filtering:", len(tune_analysis.trials), "trials") - trials_filtered = [] - print( - "metric_threshold", metric_threshold, "threshold_mode", threshold_mode - ) - for trial_idx, trial in enumerate(tune_analysis.trials): - available_metrics = trial.metric_analysis - print( - f"trial_idx {trial_idx} " - f"available_metrics[{metric}][{metric_mode}] " - f"{available_metrics[metric][metric_mode]}" - ) - if ( - threshold_mode == "above" - and available_metrics[metric][metric_mode] > metric_threshold - ): - trials_filtered.append(trial) - elif ( - threshold_mode == "equal" - and available_metrics[metric][metric_mode] == metric_threshold - ): - trials_filtered.append(trial) - elif ( - threshold_mode == "below" - and available_metrics[metric][metric_mode] < metric_threshold - ): - trials_filtered.append(trial) - else: - print(f"filter trial {trial_idx}") - tune_analysis.trials = trials_filtered - print("After trial filtering:", len(tune_analysis.trials), "trials") - return tune_analysis - - def get_random_seeds(n_seeds): timestamp = int(time.time()) seeds = [seed + timestamp for seed in list(range(n_seeds))] return seeds -def list_all_files_in_one_dir_tree(path): - if not os.path.exists(path): - raise FileExistsError(f"path doesn't exist: {path}") - file_list = [] - for root, dirs, files in os.walk(path): - for file in files: - # append the file name to the list - file_list.append(os.path.join(root, 
file)) - print(len(file_list), "files found") - return file_list - - def ignore_str_containing_keys(str_list, ignore_keys): str_list_filtered = [ file_path @@ -317,45 +217,6 @@ def fing_longer_substr(str_list): return substr -def load_one_tune_analysis( - checkpoints_paths: list, - result: dict = {"training_iteration": 1, "episode_reward_mean": 1}, - default_metric: "str" = "episode_reward_mean", - default_mode: str = "max", - n_dir_level_between_ckpt_and_exp_state=1, -): - """Helper to re-create a fake tune_analysis only containing the - checkpoints provided.""" - - assert default_metric in result.keys() - - register_trainable("fake trial", Trainable) - trials = [] - for one_checkpoint_path in checkpoints_paths: - one_trial = Trial(trainable_name="fake trial") - ckpt = Checkpoint( - Checkpoint.PERSISTENT, value=one_checkpoint_path, result=result - ) - one_trial.checkpoint_manager.on_checkpoint(ckpt) - trials.append(one_trial) - - json_file_path = _get_experiment_state_file_path( - checkpoints_paths[0], - split_path_n_times=n_dir_level_between_ckpt_and_exp_state, - ) - one_tune_analysis = ExperimentAnalysis( - experiment_checkpoint_path=json_file_path, - trials=trials, - default_mode=default_mode, - default_metric=default_metric, - ) - - for trial in one_tune_analysis.trials: - assert len(trial.checkpoint_manager.best_checkpoints()) == 1 - - return one_tune_analysis - - def _get_experiment_state_file_path(one_checkpoint_path, split_path_n_times=1): one_checkpoint_path = os.path.expanduser(one_checkpoint_path) parent_dir = one_checkpoint_path @@ -368,41 +229,6 @@ def _get_experiment_state_file_path(one_checkpoint_path, split_path_n_times=1): return json_file_path -def check_learning_achieved( - tune_results, - metric="episode_reward_mean", - trial_idx=0, - max_: float = None, - min_: float = None, - equal_: float = None, -): - assert max_ is not None or min_ is not None or equal_ is not None - - last_results = tune_results.trials[trial_idx].last_result - _, _, value, found = move_to_key(last_results, key=metric) - assert ( - found - ), f"metric {metric} not found inside last_results {last_results}" - - msg = ( - f"Trial {trial_idx} achieved " - f"{value}" - f" on metric {metric}. This is a success if the value is below" - f" {max_} or above {min_} or equal to {equal_}." 
- ) - - logger.info(msg) - print(msg) - if min_ is not None: - assert value >= min_, f"value {value} must be above min_ {min_}" - if max_ is not None: - assert value <= max_, f"value {value} must be below max_ {max_}" - if equal_ is not None: - assert value == equal_, ( - f"value {value} must be equal to equal_ " f"{equal_}" - ) - - def assert_if_key_in_dict_then_args_are_none(dict_, key, *args): if key in dict_.keys(): for arg in args: @@ -422,28 +248,12 @@ def read_from_dict_default_to_args(dict_, key, *args): def filter_sample_batch( samples: SampleBatch, filter_key, remove=True, copy_data=False ) -> SampleBatch: - filter = samples.data[filter_key] + filter = samples.columns([filter_key])[0] if remove: - # torch logical not + assert isinstance( + filter, np.ndarray + ), f"type {type(filter)} for filter_key {filter_key}" filter = ~filter return SampleBatch( - { - k: np.array(v, copy=copy_data)[filter] - for (k, v) in samples.data.items() - } + {k: np.array(v, copy=copy_data)[filter] for (k, v) in samples.items()} ) - - -def extract_metric_values_per_trials( - tune_analysis, - metric="episode_reward_mean", -): - metric_values = [] - for trial in tune_analysis.trials: - last_results = trial.last_result - _, _, value, found = move_to_key(last_results, key=metric) - assert ( - found - ), f"metric: {metric} not found in last_results: {last_results}" - metric_values.append(value) - return metric_values diff --git a/marltoolbox/utils/path.py b/marltoolbox/utils/path.py new file mode 100644 index 0000000..c9d3053 --- /dev/null +++ b/marltoolbox/utils/path.py @@ -0,0 +1,215 @@ +import json +import os +from typing import List + +from marltoolbox.utils import miscellaneous, exp_analysis + + +def get_unique_child_dir(_dir: str): + """ + Return the path to the unique dir inside the given dir. + + :param _dir: path to given dir + :return: path to the unique dir inside the given dir + """ + + list_child_dir = os.listdir(_dir) + list_child_dir = [ + os.path.join(_dir, child_dir) for child_dir in list_child_dir + ] + list_child_dir = keep_dirs_only(list_child_dir) + assert len(list_child_dir) == 1, f"{list_child_dir}" + unique_child_dir = list_child_dir[0] + return unique_child_dir + + +def try_get_unique_child_dir(_dir: str): + """ + If it exists, returns the path to the unique dir inside the given dir. + Otherwise returns None. + + :param _dir: path to given dir + :return: path to the unique dir inside the given dir or if it doesn't + exist None + """ + + try: + unique_child_dir = get_unique_child_dir(_dir) + return unique_child_dir + except AssertionError: + return None + + +def list_all_files_in_one_dir_tree(path: str) -> List[str]: + """ + List all the files in the tree starting at the given path. + + :param path: + :return: list of all the files + """ + if not os.path.exists(path): + raise FileExistsError(f"path doesn't exist: {path}") + file_list = [] + for root, dirs, files in os.walk(path): + for file in files: + # append the file name to the list + file_list.append(os.path.join(root, file)) + print(len(file_list), "files found") + return file_list + + +def get_children_paths_wt_selecting_filter( + parent_dir_path: str, _filter: str +) -> List[str]: + """ + Return all children dir paths after selecting those containing the + _filter. + + :param parent_dir_path: + :param _filter: to select the paths to keep + :return: list of paths which contain the given filter. 
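+
+    Illustrative example: with children ["trial_000", "trial_001",
+    "notes.txt"] and _filter="trial", only the two "trial_*" entries
+    (joined with parent_dir_path) are returned.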
+ """ + return _get_children_paths_filters( + parent_dir_path, selecting_filter=_filter + ) + + +def get_children_paths_wt_discarding_filter( + parent_dir_path: str, _filter: str +) -> List[str]: + """ + Return all children dir paths after selecting those NOT containing the + _filter. + + :param parent_dir_path: + :param _filter: to select the paths to remove + :return: list of paths which don't contain the given filter. + """ + + return _get_children_paths_filters( + parent_dir_path, discarding_filter=_filter + ) + + +def _get_children_paths_filters( + parent_dir_path: str, + selecting_filter: str = None, + discarding_filter: str = None, +): + filtered_children = os.listdir(parent_dir_path) + if selecting_filter is not None: + filtered_children = [ + filename + for filename in filtered_children + if selecting_filter in filename + ] + if discarding_filter is not None: + filtered_children = [ + filename + for filename in filtered_children + if discarding_filter not in filename + ] + filtered_children_path = [ + os.path.join(parent_dir_path, filename) + for filename in filtered_children + ] + return filtered_children_path + + +def get_params_for_replicate(trial_dir_path: str) -> dict: + """ + Get the parameters from the json file saved in the dir of an Tune/RLLib + trial. + + :param trial_dir_path: patht to a single tune.Trial (inside an experiment) + :return: dict of parameters used for the trial + """ + parameter_json_path = os.path.join(trial_dir_path, "params.json") + params = _read_json_file(parameter_json_path) + return params + + +def get_results_for_replicate(trial_dir_path: str) -> list: + """ + Get the results for all episodes from the file saved in the + dir of an Tune/RLLib trial. + + :param trial_dir_path: patht to a single tune.Trial (inside an experiment) + :return: list of lines of results (one line per episode) + """ + results_file_path = os.path.join(trial_dir_path, "result.json") + results = _read_all_lines_of_file(results_file_path) + # Remove empty last line + if len(results[-1]) == 0: + results = results[:-1] + results = [json.loads(line) for line in results] + return results + + +def _read_json_file(json_file_path: str): + with open(json_file_path) as json_file: + json_object = json.load(json_file) + return json_object + + +def _read_all_lines_of_file(file_path: str) -> list: + with open(file_path) as file: + lines = list(file) + return lines + + +def keep_dirs_only(paths: list) -> list: + """Keep only the directories""" + return [path for path in paths if os.path.isdir(path)] + + +def filter_list_of_replicates_by_results( + replicate_paths: list, + filter_key: str, + filter_threshold: float, + filter_mode: str = exp_analysis.ABOVE, +) -> list: + print("Going to start filtering replicate_paths") + print("len(replicate_paths)", len(replicate_paths)) + filtered_replicate_paths = [] + for replica_path in replicate_paths: + replica_results = get_results_for_replicate(replica_path) + last_result = replica_results[-1] + assert isinstance(last_result, dict) + _, _, current_value, found = miscellaneous.move_to_key( + last_result, filter_key + ) + assert found, ( + f"filter_key {filter_key} not found in last_result " + f"{last_result}" + ) + if ( + filter_mode == exp_analysis.ABOVE + and current_value > filter_threshold + ): + filtered_replicate_paths.append(replica_path) + elif ( + filter_mode == exp_analysis.EQUAL + and current_value == filter_threshold + ): + filtered_replicate_paths.append(replica_path) + elif ( + filter_mode == exp_analysis.BELOW + and current_value < 
filter_threshold + ): + filtered_replicate_paths.append(replica_path) + else: + print(f"filtering out replica_path {replica_path}") + print("After filtering:") + print("len(filtered_replicate_paths)", len(filtered_replicate_paths)) + return filtered_replicate_paths + + +def get_exp_dir_from_exp_name(exp_name: str): + """ + :param exp_name: exp_name provided to tune.run + :return: path to the experiment analysis repository (ray log dir) + """ + exp_dir = os.path.join("~/ray_results", exp_name) + exp_dir = os.path.expanduser(exp_dir) + return exp_dir diff --git a/marltoolbox/utils/plot.py b/marltoolbox/utils/plot.py index f478894..0b09b15 100644 --- a/marltoolbox/utils/plot.py +++ b/marltoolbox/utils/plot.py @@ -5,6 +5,10 @@ import matplotlib.pyplot as plt import numpy as np +plt.switch_backend("agg") +plt.style.use("seaborn-whitegrid") +plt.rcParams.update({"font.size": 12}) + COLORS = list(mcolors.TABLEAU_COLORS) + list(mcolors.XKCD_COLORS) RANDOM_MARKERS = ["1", "2", "3", "4", "8", "s", "p", "P", "*", "h", "H", "+"] MARKERS = ["o", "s", "v", "^", "<", ">", "P", "X", "D", "*"] + RANDOM_MARKERS @@ -20,7 +24,7 @@ def __init__( xlabel: str = None, ylabel: str = None, display_legend: bool = True, - legend_fontsize: str = "small", + legend_fontsize: str = "medium", save_dir_path: str = None, title: str = None, xlim: str = None, @@ -74,6 +78,7 @@ def __init__( class PlotHelper: def __init__(self, plot_config: PlotConfig): self.plot_cfg = plot_config + self.additional_filename_suffix = "" def plot_lines(self, data_groups: dict): """ @@ -157,13 +162,32 @@ def _get_label(self, group_id, col): label = group_id return label - def _finalize_plot(self, fig): + def _finalize_plot(self, fig, remove_err_bars_from_legend=False): if self.plot_cfg.display_legend: - plt.legend( - numpoints=1, - frameon=True, - fontsize=self.plot_cfg.legend_fontsize, - ) + if remove_err_bars_from_legend: + ax = plt.gca() + # get handles + handles, labels = ax.get_legend_handles_labels() + # remove the errorbars + handles = [h[0] for h in handles] + # use them in the legend + ax.legend( + handles, + labels, + numpoints=1, + frameon=True, + fontsize=self.plot_cfg.legend_fontsize, + ) + else: + plt.legend( + numpoints=1, + frameon=True, + fontsize=self.plot_cfg.legend_fontsize, + ) + + # bbox_to_anchor = (0.66, -0.20), + # # loc="upper left", + # ) if self.plot_cfg.xlabel is not None: plt.xlabel(self.plot_cfg.xlabel) if self.plot_cfg.ylabel is not None: @@ -176,10 +200,11 @@ def _finalize_plot(self, fig): plt.ylim(self.plot_cfg.ylim) if self.plot_cfg.background_area_coord is not None: self._add_background_area() + # plt.tight_layout(rect=[0, -0.05, 1.0, 1.0]) if self.plot_cfg.save_dir_path is not None: file_name = ( f"{self.plot_cfg.filename_prefix}_{self.plot_cfg.ylabel}_vs" - f"_{self.plot_cfg.xlabel}.png" + f"_{self.plot_cfg.xlabel}{self.additional_filename_suffix}.png" ) file_name = file_name.replace("/", "_") file_path = os.path.join(self.plot_cfg.save_dir_path, file_name) @@ -195,6 +220,19 @@ def plot_dots(self, data_groups: dict): :param data_groups: dict of groups (same color and label prefix) containing a DataFrame containing (x, y) tuples. Each column in a group DataFrame has a different marker. 
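+
+        Two figures are saved: one scatter plot with every (x, y) point, and
+        one where each group is collapsed to its mean with standard-error
+        bars (same filename with the "_wt_err_bars" suffix).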
""" + self._plot_dots_multiple_points(data_groups) + self._plot_dots_one_point_wt_std_dev_bars(data_groups) + + def _plot_dots_multiple_points(self, data_groups: dict): + return self._plot_dots_customizable(data_groups) + + def _plot_dots_one_point_wt_std_dev_bars(self, data_groups: dict): + self.additional_filename_suffix = "_wt_err_bars" + plot_filname = self._plot_dots_customizable(data_groups, err_bars=True) + self.additional_filename_suffix = "" + return plot_filname + + def _plot_dots_customizable(self, data_groups: dict, err_bars=False): fig = self._init_plot() self.counter_labels = 0 @@ -203,47 +241,96 @@ def plot_dots(self, data_groups: dict): data_groups.items() ): new_labels_plotted = self._plot_dotes_for_one_group( - self.plot_cfg.colors[group_index], group_id, group_df + self.plot_cfg.colors[group_index], group_id, group_df, err_bars ) all_label_plotted.extend(new_labels_plotted) print("all_label_plotted", all_label_plotted) - return self._finalize_plot(fig) + return self._finalize_plot(fig, remove_err_bars_from_legend=err_bars) - def _plot_dotes_for_one_group(self, group_color, group_id, group_df): + def _plot_dotes_for_one_group( + self, group_color, group_id, group_df, err_bars=False + ): label_plotted = [] for col in group_df.columns: - x, y = self._select_n_points_to_plot(group_df, col) - x, y = self._add_jitter_to_points(x, y) - x, y = self._apply_scale_multiplier(x, y) - label = self._get_label(group_id, col) - - plt.plot( - x, - y, - markerfacecolor="none" - if self.plot_cfg.empty_markers - else group_color, - markeredgecolor=group_color, - linestyle="None", - marker=self.plot_cfg.markers[self.counter_labels], - color=group_color, - label=label, - alpha=self.plot_cfg.alpha, - markersize=self.plot_cfg.markersize, - ) - self.counter_labels += 1 + if not err_bars: + label = self._plot_wtout_err_bars( + group_df, col, group_id, group_color + ) + else: + label = self._plot_wt_err_bars( + group_df, col, group_id, group_color + ) label_plotted.append(label) return label_plotted + def _plot_wtout_err_bars(self, group_df, col, group_id, group_color): + x, y = self._select_n_points_to_plot(group_df, col) + x, y = self._add_jitter_to_points(x, y) + x, y = self._apply_scale_multiplier(x, y) + label = self._get_label(group_id, col) + + plt.plot( + x, + y, + markerfacecolor="none" + if self.plot_cfg.empty_markers + else group_color, + markeredgecolor=group_color, + linestyle="None", + marker=self.plot_cfg.markers[self.counter_labels], + color=group_color, + label=label, + alpha=self.plot_cfg.alpha, + markersize=self.plot_cfg.markersize, + ) + self.counter_labels += 1 + return label + + def _plot_wt_err_bars(self, group_df, col, group_id, group_color): + x, y = self._select_all_points_to_plot(group_df, col) + x, y = self._apply_scale_multiplier(x, y) + label = self._get_label(group_id, col) + x_mean = np.array(x).mean() + y_mean = np.array(y).mean() + x_std_err = np.array(x).std() / np.sqrt(len(x)) + y_std_err = np.array(y).std() / np.sqrt(len(y)) + + plt.errorbar( + x_mean, + y_mean, + xerr=x_std_err, + yerr=y_std_err, + markerfacecolor="none" + if self.plot_cfg.empty_markers + else group_color, + markeredgecolor=group_color, + linestyle="None", + marker=self.plot_cfg.markers[self.counter_labels], + color=group_color, + label=label, + alpha=self.plot_cfg.alpha, + markersize=36 * 2.0 + if self.plot_cfg.markersize is None + else self.plot_cfg.markersize * 2.0, + ) + self.counter_labels += 1 + return label + def _select_n_points_to_plot(self, group_df, col): if 
self.plot_cfg.plot_max_n_points is not None: n_points_to_plot = min( self.plot_cfg.plot_max_n_points, len(group_df) ) print(f"Selected {n_points_to_plot} n_points_to_plot") + return self._get_points_to_plot(group_df, n_points_to_plot, col) else: - n_points_to_plot = len(group_df) + return self._select_all_points_to_plot(group_df, col) + + def _select_all_points_to_plot(self, group_df, col): + return self._get_points_to_plot(group_df, len(group_df), col) + + def _get_points_to_plot(self, group_df, n_points_to_plot, col): group_df_sample = group_df.sample(n=int(n_points_to_plot)) points = group_df_sample[col].tolist() x, y = [p[0] for p in points], [p[1] for p in points] diff --git a/marltoolbox/utils/policy.py b/marltoolbox/utils/policy.py index c7a59a2..77449db 100644 --- a/marltoolbox/utils/policy.py +++ b/marltoolbox/utils/policy.py @@ -5,7 +5,6 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.typing import TrainerConfigDict - from marltoolbox.utils.restore import LOAD_FROM_CONFIG_KEY @@ -21,39 +20,53 @@ def get_tune_policy_class(PolicyClass): """ class FrozenPolicyFromTuneTrainer(PolicyClass): - - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - config: TrainerConfigDict): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + config: TrainerConfigDict, + ): print("__init__ FrozenPolicyFromTuneTrainer") self.tune_config = config["tune_config"] TuneTrainerClass = self.tune_config["TuneTrainerClass"] self.tune_trainer = TuneTrainerClass(config=self.tune_config) self.load_checkpoint( - config.pop(LOAD_FROM_CONFIG_KEY, (None, None))) + config.pop(LOAD_FROM_CONFIG_KEY, (None, None)) + ) + self._to_log = {} super().__init__(observation_space, action_space, config) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - actions, state_out, extra_fetches = \ - self.tune_trainer.compute_actions(self.policy_id, obs_batch) + def _compute_action_helper( + self, + input_dict, + *args, + **kwargs, + ): + # print('input_dict["obs"]', input_dict["obs"]) + ( + actions, + state_out, + extra_fetches, + ) = self.tune_trainer.compute_actions( + self.policy_id, input_dict["obs"] + ) return actions, state_out, extra_fetches + def _initialize_loss_from_dummy_batch(self, *args, **kwargs): + pass + def learn_on_batch(self, samples): raise NotImplementedError( - "FrozenPolicyFromTuneTrainer policy can't be trained") + "FrozenPolicyFromTuneTrainer policy can't be trained" + ) def get_weights(self): - return {"checkpoint_path": self.checkpoint_path, - "policy_id": self.policy_id} + return { + "checkpoint_path": self.checkpoint_path, + "policy_id": self.policy_id, + } def set_weights(self, weights): checkpoint_path = weights["checkpoint_path"] @@ -65,6 +78,33 @@ def load_checkpoint(self, checkpoint_tuple): if self.checkpoint_path is not None: self.tune_trainer.load_checkpoint(self.checkpoint_path) + @property + def to_log(self): + to_log = { + "frozen_policy": self._to_log, + "nested_tune_policy": { + f"policy_{algo_idx}": algo.to_log + for algo_idx, algo in enumerate(self.algorithms) + if hasattr(algo, "to_log") + }, + } + return to_log + + @to_log.setter + def to_log(self, value): + if value == {}: + for algo in self.algorithms: + if hasattr(algo, "to_log"): + algo.to_log = {} + + self._to_log = value + + def 
on_episode_start(self, *args, **kwargs): + if hasattr(self.tune_trainer, "reset_compute_actions_state"): + self.tune_trainer.reset_compute_actions_state() + if hasattr(self.tune_trainer, "on_episode_start"): + self.tune_trainer.on_episode_start() + return FrozenPolicyFromTuneTrainer @@ -81,14 +121,17 @@ def __init__(self, lr, lr_schedule): else: if isinstance(lr_schedule, Iterable): self.lr_schedule = PiecewiseSchedule( - lr_schedule, outside_value=lr_schedule[-1][-1], - framework=None) + lr_schedule, + outside_value=lr_schedule[-1][-1], + framework=None, + ) else: self.lr_schedule = lr_schedule -def my_setup_early_mixins(policy: Policy, obs_space, action_space, - config: TrainerConfigDict) -> None: - MyLearningRateSchedule.__init__(policy, - config["lr"], - config["lr_schedule"]) +def my_setup_early_mixins( + policy: Policy, obs_space, action_space, config: TrainerConfigDict +) -> None: + MyLearningRateSchedule.__init__( + policy, config["lr"], config["lr_schedule"] + ) diff --git a/marltoolbox/utils/postprocessing.py b/marltoolbox/utils/postprocessing.py index 894e3e1..76c0069 100644 --- a/marltoolbox/utils/postprocessing.py +++ b/marltoolbox/utils/postprocessing.py @@ -4,14 +4,17 @@ import numpy as np from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.evaluation import MultiAgentEpisode -from ray.rllib.evaluation.postprocessing import discount +from ray.rllib.evaluation.postprocessing import discount_cumsum from ray.rllib.evaluation.sampler import _get_or_raise from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.typing import AgentID, PolicyID +from ray.rllib.utils.schedules import Schedule -from marltoolbox.utils.miscellaneous import \ - assert_if_key_in_dict_then_args_are_none, read_from_dict_default_to_args +from marltoolbox.utils.miscellaneous import ( + assert_if_key_in_dict_then_args_are_none, + read_from_dict_default_to_args, +) if TYPE_CHECKING: from ray.rllib.evaluation import RolloutWorker @@ -37,22 +40,25 @@ ADD_OPPONENT_ACTION = "add_opponent_action" ADD_OPPONENT_NEG_REWARD = "add_opponent_neg_reward" -ADD_WELFARE_CONFIG_KEYS = (ADD_UTILITARIAN_WELFARE, - ADD_INEQUITY_AVERSION_WELFARE) +ADD_WELFARE_CONFIG_KEYS = ( + ADD_UTILITARIAN_WELFARE, + ADD_INEQUITY_AVERSION_WELFARE, +) def welfares_postprocessing_fn( - add_utilitarian_welfare: bool = None, - add_egalitarian_welfare: bool = None, - add_nash_welfare: bool = None, - add_opponent_action: bool = None, - add_opponent_neg_reward: bool = None, - add_inequity_aversion_welfare: bool = None, - inequity_aversion_alpha: float = None, - inequity_aversion_beta: float = None, - inequity_aversion_gamma: float = None, - inequity_aversion_lambda: float = None, - additional_fn: list = []): + add_utilitarian_welfare: bool = None, + add_egalitarian_welfare: bool = None, + add_nash_welfare: bool = None, + add_opponent_action: bool = None, + add_opponent_neg_reward: bool = None, + add_inequity_aversion_welfare: bool = None, + inequity_aversion_alpha: float = None, + inequity_aversion_beta: float = None, + inequity_aversion_gamma: float = None, + inequity_aversion_lambda: float = None, + additional_fn: list = [], +): """ Generate a postprocess_fn that first add a welfare if you chose so and then call a list of additional postprocess_fn to further modify the @@ -60,7 +66,10 @@ def welfares_postprocessing_fn( The parameters used to add a welfare can be given as arguments or will be read in the policy config dict (this should be preferred since this - allows 
for hyperparameter search over these parameters with Tune). + allows for hyperparameter search over these parameters with Tune).When + read from the config, they are read from the keys defined by + ADD_INEQUITY_AVERSION_WELFARE and similars. The value oassociated with + this key must be a tuple if you provide several parameters. :param add_utilitarian_welfare: :param add_egalitarian_welfare: @@ -80,121 +89,202 @@ def welfares_postprocessing_fn( :return: """ + # TODO refactor into instanciating an object (from a new class) and + # returning one of it method OR use named tuple def postprocess_fn(policy, sample_batch, other_agent_batches, episode): + if other_agent_batches is None: + logger.warning( + "no other_agent_batches given for welfare " "postprocessing" + ) + return sample_batch _assert_using_config_xor_args(policy) parameters = _read_parameters_from_config_default_to_args(policy) + parameters = _get_values_from_any_scheduler(parameters, policy) sample_batch = _add_welfare_to_own_batch( - sample_batch, other_agent_batches, episode, policy, *parameters) + sample_batch, other_agent_batches, episode, policy, *parameters + ) sample_batch = _call_list_of_additional_fn( - additional_fn, sample_batch, other_agent_batches, episode, policy) + additional_fn, sample_batch, other_agent_batches, episode, policy + ) return sample_batch def _assert_using_config_xor_args(policy): assert_if_key_in_dict_then_args_are_none( - policy.config, "add_utilitarian_welfare", add_utilitarian_welfare) + policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare + ) assert_if_key_in_dict_then_args_are_none( - policy.config, "add_inequity_aversion_welfare", - add_inequity_aversion_welfare, inequity_aversion_alpha, - inequity_aversion_beta, inequity_aversion_gamma, - inequity_aversion_lambda) + policy.config, + ADD_INEQUITY_AVERSION_WELFARE, + add_inequity_aversion_welfare, + inequity_aversion_alpha, + inequity_aversion_beta, + inequity_aversion_gamma, + inequity_aversion_lambda, + ) assert_if_key_in_dict_then_args_are_none( - policy.config, "add_nash_welfare", add_nash_welfare) + policy.config, ADD_NASH_WELFARE, add_nash_welfare + ) assert_if_key_in_dict_then_args_are_none( - policy.config, "add_egalitarian_welfare", add_egalitarian_welfare) + policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare + ) assert_if_key_in_dict_then_args_are_none( - policy.config, "add_opponent_action", add_opponent_action) + policy.config, ADD_OPPONENT_ACTION, add_opponent_action + ) assert_if_key_in_dict_then_args_are_none( - policy.config, "add_opponent_neg_reward", add_opponent_neg_reward) + policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward + ) def _read_parameters_from_config_default_to_args(policy): add_utilitarian_w = read_from_dict_default_to_args( - policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare) - add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda = \ - read_from_dict_default_to_args( - policy.config, ADD_INEQUITY_AVERSION_WELFARE, - add_inequity_aversion_welfare, inequity_aversion_alpha, - inequity_aversion_beta, inequity_aversion_gamma, - inequity_aversion_lambda) + policy.config, ADD_UTILITARIAN_WELFARE, add_utilitarian_welfare + ) + ( + add_ia_w, + ia_alpha, + ia_beta, + ia_gamma, + ia_lambda, + ) = read_from_dict_default_to_args( + policy.config, + ADD_INEQUITY_AVERSION_WELFARE, + add_inequity_aversion_welfare, + inequity_aversion_alpha, + inequity_aversion_beta, + inequity_aversion_gamma, + inequity_aversion_lambda, + ) add_nash_w = read_from_dict_default_to_args( - 
policy.config, ADD_NASH_WELFARE, add_nash_welfare) + policy.config, ADD_NASH_WELFARE, add_nash_welfare + ) add_egalitarian_w = read_from_dict_default_to_args( - policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare) + policy.config, ADD_EGALITARIAN_WELFARE, add_egalitarian_welfare + ) add_opponent_a = read_from_dict_default_to_args( - policy.config, ADD_OPPONENT_ACTION, add_opponent_action) + policy.config, ADD_OPPONENT_ACTION, add_opponent_action + ) add_opponent_neg_r = read_from_dict_default_to_args( - policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward) - - return add_utilitarian_w, \ - add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda, \ - add_nash_w, add_egalitarian_w, \ - add_opponent_a, add_opponent_neg_r + policy.config, ADD_OPPONENT_NEG_REWARD, add_opponent_neg_reward + ) + + return ( + add_utilitarian_w, + add_ia_w, + ia_alpha, + ia_beta, + ia_gamma, + ia_lambda, + add_nash_w, + add_egalitarian_w, + add_opponent_a, + add_opponent_neg_r, + ) + + def _get_values_from_any_scheduler(parameters, policy): + logger.debug("_get_values_from_any_scheduler") + new_parameters = [] + for param in parameters: + if isinstance(param, Schedule): + value_from_scheduler = param.value(policy.global_timestep) + new_parameters.append(value_from_scheduler) + else: + new_parameters.append(param) + return new_parameters def _add_welfare_to_own_batch( - sample_batch, other_agent_batches, episode, policy, *parameters): - - add_utilitarian_w, \ - add_ia_w, ia_alpha, ia_beta, ia_gamma, ia_lambda, \ - add_nash_w, add_egalitarian_w, \ - add_opponent_a, add_opponent_neg_r = parameters - - assert len(set(sample_batch[sample_batch.EPS_ID])) == 1, \ - "design to work on one complete episode" - assert sample_batch[sample_batch.DONES][-1], \ - "design to work on one complete episode, dones: " \ - f"{sample_batch[sample_batch.DONES]}" + sample_batch, other_agent_batches, episode, policy, *parameters + ): + + ( + add_utilitarian_w, + add_ia_w, + ia_alpha, + ia_beta, + ia_gamma, + ia_lambda, + add_nash_w, + add_egalitarian_w, + add_opponent_a, + add_opponent_neg_r, + ) = parameters + + _assert_working_on_one_full_epi(sample_batch) if add_utilitarian_w: - logger.debug(f"add utilitarian welfare to batch of policy" - f" {policy}") + logger.debug( + f"add utilitarian welfare to batch of policy" f" {policy}" + ) opp_batches = [v[1] for v in other_agent_batches.values()] sample_batch = _add_utilitarian_welfare_to_batch( - sample_batch, opp_batches, policy) + sample_batch, opp_batches, policy + ) if add_ia_w: - logger.debug(f"add inequity aversion welfare to batch of policy" - f" {policy}") + logger.debug( + f"add inequity aversion welfare to batch of policy" + f" {policy}" + ) _assert_two_players_env(other_agent_batches) opp_batch = _get_opp_batch(other_agent_batches) sample_batch = _add_inequity_aversion_welfare_to_batch( - sample_batch, opp_batch, + sample_batch, + opp_batch, alpha=ia_alpha, beta=ia_beta, gamma=ia_gamma, lambda_=ia_lambda, - policy=policy) + policy=policy, + ) if add_nash_w: _assert_two_players_env(other_agent_batches) opp_batch = _get_opp_batch(other_agent_batches) sample_batch = _add_nash_welfare_to_batch( - sample_batch, opp_batch, policy) + sample_batch, opp_batch, policy + ) if add_egalitarian_w: _assert_two_players_env(other_agent_batches) opp_batch = _get_opp_batch(other_agent_batches) sample_batch = _add_egalitarian_welfare_to_batch( - sample_batch, opp_batch, policy) + sample_batch, opp_batch, policy + ) if add_opponent_a: _assert_two_players_env(other_agent_batches) 
opp_batch = _get_opp_batch(other_agent_batches) sample_batch = _add_opponent_action_to_batch( - sample_batch, opp_batch, policy) + sample_batch, opp_batch, policy + ) if add_opponent_neg_r: _assert_two_players_env(other_agent_batches) opp_batch = _get_opp_batch(other_agent_batches) sample_batch = _add_opponent_neg_reward_to_batch( - sample_batch, opp_batch, policy) + sample_batch, opp_batch, policy + ) return sample_batch return postprocess_fn -def _call_list_of_additional_fn(additional_fn, - sample_batch, other_agent_batches, episode, - policy): +def _assert_working_on_one_full_epi(sample_batch): + assert ( + len(set(sample_batch[sample_batch.EPS_ID])) == 1 + ), "designed to work on one complete episode" + assert ( + not any(sample_batch[sample_batch.DONES][:-1]) + or sample_batch[sample_batch.DONES][-1] + ), ( + "welfare postprocessing is designed to work on one complete episode, " + f"dones: {sample_batch[sample_batch.DONES]}" + ) + + +def _call_list_of_additional_fn( + additional_fn, sample_batch, other_agent_batches, episode, policy +): for postprocessing_function in additional_fn: sample_batch = postprocessing_function( - sample_batch, other_agent_batches, episode, policy) + sample_batch, other_agent_batches, episode, policy + ) return sample_batch @@ -208,19 +298,20 @@ def _get_opp_batch(other_agent_batches): def _add_utilitarian_welfare_to_batch( - sample_batch: SampleBatch, - opp_ag_batchs: List[SampleBatch], - policy=None + sample_batch: SampleBatch, opp_ag_batchs: List[SampleBatch], policy=None ) -> SampleBatch: - all_batchs_rewards = ([sample_batch[sample_batch.REWARDS]] + - [opp_batch[opp_batch.REWARDS] for opp_batch in - opp_ag_batchs]) - sample_batch.data[WELFARE_UTILITARIAN] = np.array( - [sum(reward_points) for reward_points in zip(*all_batchs_rewards)]) - - _ = _log_in_policy(np.sum(sample_batch.data[WELFARE_UTILITARIAN]), - f"sum_over_epi_{WELFARE_UTILITARIAN}", - policy) + all_batchs_rewards = [sample_batch[sample_batch.REWARDS]] + [ + opp_batch[opp_batch.REWARDS] for opp_batch in opp_ag_batchs + ] + sample_batch[WELFARE_UTILITARIAN] = np.array( + [sum(reward_points) for reward_points in zip(*all_batchs_rewards)] + ) + + _ = _log_in_policy( + np.sum(sample_batch[WELFARE_UTILITARIAN]), + f"sum_over_epi_{WELFARE_UTILITARIAN}", + policy, + ) return sample_batch @@ -234,31 +325,40 @@ def _log_in_policy(value, name_value, policy=None): def _add_opponent_action_to_batch( - sample_batch: SampleBatch, - opp_ag_batch: SampleBatch, - policy=None) -> SampleBatch: - sample_batch.data[OPPONENT_ACTIONS] = opp_ag_batch[opp_ag_batch.ACTIONS] - _ = _log_in_policy(sample_batch.data[OPPONENT_ACTIONS][-1], - f"last_{OPPONENT_ACTIONS}", policy) + sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None +) -> SampleBatch: + sample_batch[OPPONENT_ACTIONS] = opp_ag_batch[opp_ag_batch.ACTIONS] + _ = _log_in_policy( + sample_batch[OPPONENT_ACTIONS][-1], + f"last_{OPPONENT_ACTIONS}", + policy, + ) return sample_batch def _add_opponent_neg_reward_to_batch( - sample_batch: SampleBatch, - opp_ag_batch: SampleBatch, - policy=None) -> SampleBatch: - sample_batch.data[OPPONENT_NEGATIVE_REWARD] = np.array( - [- opp_r for opp_r in opp_ag_batch[opp_ag_batch.REWARDS]]) - _ = _log_in_policy(np.sum(sample_batch.data[OPPONENT_NEGATIVE_REWARD]), - f"sum_over_epi_{OPPONENT_NEGATIVE_REWARD}", policy) + sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None +) -> SampleBatch: + sample_batch[OPPONENT_NEGATIVE_REWARD] = np.array( + [-opp_r for opp_r in opp_ag_batch[opp_ag_batch.REWARDS]] + ) + _ 
= _log_in_policy( + np.sum(sample_batch[OPPONENT_NEGATIVE_REWARD]), + f"sum_over_epi_{OPPONENT_NEGATIVE_REWARD}", + policy, + ) return sample_batch def _add_inequity_aversion_welfare_to_batch( - sample_batch: SampleBatch, opp_ag_batch: SampleBatch, - alpha: float, beta: float, gamma: float, - lambda_: float, - policy=None) -> SampleBatch: + sample_batch: SampleBatch, + opp_ag_batch: SampleBatch, + alpha: float, + beta: float, + gamma: float, + lambda_: float, + policy=None, +) -> SampleBatch: """ :param sample_batch: SampleBatch to mutate :param opp_ag_batchs: @@ -274,90 +374,115 @@ def _add_inequity_aversion_welfare_to_batch( opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS]) own_rewards = np.flip(own_rewards) opp_rewards = np.flip(opp_rewards) - delta = (discount(own_rewards, gamma * lambda_) - - discount(opp_rewards, gamma * lambda_)) + delta = discount_cumsum(own_rewards, gamma * lambda_) - discount_cumsum( + opp_rewards, gamma * lambda_ + ) delta = np.flip(delta) disvalue_lower_than_opp = alpha * (-delta) disvalue_higher_than_opp = beta * delta disvalue_lower_than_opp[disvalue_lower_than_opp < 0] = 0 disvalue_higher_than_opp[disvalue_higher_than_opp < 0] = 0 - welfare = sample_batch[sample_batch.REWARDS] - \ - disvalue_lower_than_opp - disvalue_higher_than_opp + welfare = ( + sample_batch[sample_batch.REWARDS] + - disvalue_lower_than_opp + - disvalue_higher_than_opp + ) - sample_batch.data[WELFARE_INEQUITY_AVERSION] = welfare + sample_batch[WELFARE_INEQUITY_AVERSION] = welfare policy = _log_in_policy( - np.sum(sample_batch.data[WELFARE_INEQUITY_AVERSION]), - f"sum_over_epi_{WELFARE_INEQUITY_AVERSION}", policy) + np.sum(sample_batch[WELFARE_INEQUITY_AVERSION]), + f"sum_over_epi_{WELFARE_INEQUITY_AVERSION}", + policy, + ) return sample_batch def _add_nash_welfare_to_batch( - sample_batch: SampleBatch, opp_ag_batch: SampleBatch, - policy=None) -> SampleBatch: + sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None +) -> SampleBatch: own_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS]) opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS]) own_rewards_under_defection = np.array( - opp_ag_batch.data[REWARDS_UNDER_DEFECTION]) + opp_ag_batch[REWARDS_UNDER_DEFECTION] + ) opp_rewards_under_defection = np.array( - opp_ag_batch.data[REWARDS_UNDER_DEFECTION]) + opp_ag_batch[REWARDS_UNDER_DEFECTION] + ) - own_delta = (sum(own_rewards) - sum(own_rewards_under_defection)) - opp_delta = (sum(opp_rewards) - sum(opp_rewards_under_defection)) + own_delta = sum(own_rewards) - sum(own_rewards_under_defection) + opp_delta = sum(opp_rewards) - sum(opp_rewards_under_defection) nash_welfare = own_delta * opp_delta - sample_batch.data[WELFARE_NASH] = ([0.0] * ( - len(sample_batch[sample_batch.REWARDS]) - 1)) + [nash_welfare] - policy = _log_in_policy(np.sum(sample_batch.data[WELFARE_NASH]), - f"sum_over_epi_{WELFARE_NASH}", policy) + sample_batch[WELFARE_NASH] = ( + [0.0] * (len(sample_batch[sample_batch.REWARDS]) - 1) + ) + [nash_welfare] + policy = _log_in_policy( + np.sum(sample_batch[WELFARE_NASH]), + f"sum_over_epi_{WELFARE_NASH}", + policy, + ) return sample_batch def _add_egalitarian_welfare_to_batch( - sample_batch: SampleBatch, opp_ag_batch: SampleBatch, - policy=None) -> SampleBatch: + sample_batch: SampleBatch, opp_ag_batch: SampleBatch, policy=None +) -> SampleBatch: own_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS]) opp_rewards = np.array(opp_ag_batch[opp_ag_batch.REWARDS]) own_rewards_under_defection = np.array( - opp_ag_batch.data[REWARDS_UNDER_DEFECTION]) + 
opp_ag_batch[REWARDS_UNDER_DEFECTION] + ) opp_rewards_under_defection = np.array( - opp_ag_batch.data[REWARDS_UNDER_DEFECTION]) + opp_ag_batch[REWARDS_UNDER_DEFECTION] + ) - own_delta = (sum(own_rewards) - sum(own_rewards_under_defection)) - opp_delta = (sum(opp_rewards) - sum(opp_rewards_under_defection)) + own_delta = sum(own_rewards) - sum(own_rewards_under_defection) + opp_delta = sum(opp_rewards) - sum(opp_rewards_under_defection) egalitarian_welfare = min(own_delta, opp_delta) - sample_batch.data[WELFARE_EGALITARIAN] = ([0.0] * ( - len(sample_batch[sample_batch.REWARDS]) - 1)) + [ - egalitarian_welfare] - policy = _log_in_policy(np.sum(sample_batch.data[WELFARE_EGALITARIAN]), - f"sum_over_epi_{WELFARE_EGALITARIAN}", - policy) + sample_batch[WELFARE_EGALITARIAN] = ( + [0.0] * (len(sample_batch[sample_batch.REWARDS]) - 1) + ) + [egalitarian_welfare] + policy = _log_in_policy( + np.sum(sample_batch[WELFARE_EGALITARIAN]), + f"sum_over_epi_{WELFARE_EGALITARIAN}", + policy, + ) return sample_batch class OverwriteRewardWtWelfareCallback(DefaultCallbacks): - def on_postprocess_trajectory( - self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, - agent_id: AgentID, policy_id: PolicyID, - policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch, - original_batches: Dict[AgentID, SampleBatch], **kwargs): - - assert sum([k in WELFARES for k in - postprocessed_batch.data.keys()]) <= 1, \ - "only one welfare must be available" + self, + *, + worker: "RolloutWorker", + episode: MultiAgentEpisode, + agent_id: AgentID, + policy_id: PolicyID, + policies: Dict[PolicyID, Policy], + postprocessed_batch: SampleBatch, + original_batches: Dict[AgentID, SampleBatch], + **kwargs, + ): + + assert ( + sum([k in WELFARES for k in postprocessed_batch.keys()]) <= 1 + ), "only one welfare must be available" for welfare_key in WELFARES: - if welfare_key in postprocessed_batch.data.keys(): - postprocessed_batch[postprocessed_batch.REWARDS] = \ - postprocessed_batch.data[welfare_key] - msg = f"Overwrite the reward of agent_id {agent_id} " \ - f"with the value from the" \ - f" welfare_key {welfare_key}" + if welfare_key in postprocessed_batch.keys(): + postprocessed_batch[ + postprocessed_batch.REWARDS + ] = postprocessed_batch[welfare_key] + msg = ( + f"Overwrite the reward of agent_id {agent_id} " + f"with the value from the" + f" welfare_key {welfare_key}" + ) print(msg) logger.debug(msg) break @@ -366,7 +491,8 @@ def on_postprocess_trajectory( def apply_preprocessors(worker, raw_observation, policy_id): - prep_obs = _get_or_raise( - worker.preprocessors, policy_id).transform(raw_observation) + prep_obs = _get_or_raise(worker.preprocessors, policy_id).transform( + raw_observation + ) filtered_obs = _get_or_raise(worker.filters, policy_id)(prep_obs) return filtered_obs diff --git a/marltoolbox/utils/restore.py b/marltoolbox/utils/restore.py index 519925f..be00915 100644 --- a/marltoolbox/utils/restore.py +++ b/marltoolbox/utils/restore.py @@ -1,13 +1,18 @@ import logging import os import pickle +from typing import List + +from marltoolbox import utils +from marltoolbox.utils import path +from ray.tune.analysis import ExperimentAnalysis logger = logging.getLogger(__name__) LOAD_FROM_CONFIG_KEY = "checkpoint_to_load_from" -def after_init_load_policy_checkpoint( +def before_loss_init_load_policy_checkpoint( policy, observation_space=None, action_space=None, trainer_config=None ): """ @@ -27,7 +32,7 @@ def after_init_load_policy_checkpoint( Example: determining the checkpoint to load conditional on the 
current seed (when doing a grid_search over random seeds and with a
    multistage training)
    """
-    checkpoint_path, policy_id = policy.config.pop(
+    checkpoint_path, policy_id = policy.config.get(
         LOAD_FROM_CONFIG_KEY, (None, None)
     )
@@ -40,8 +45,7 @@
             f"marltoolbox restore: checkpoint found for policy_id: "
             f"{policy_id}"
         )
-        logger.info(msg)
-        print(msg)
+        logger.debug(msg)
     else:
         msg = (
             f"marltoolbox restore: NO checkpoint found for policy_id:"
@@ -49,7 +53,6 @@
             f"Not found under the config key: {LOAD_FROM_CONFIG_KEY}"
         )
         logger.warning(msg)
-        print(msg)


 def load_one_policy_checkpoint(
@@ -75,7 +78,7 @@
         policy.load_checkpoint(checkpoint_tuple=(checkpoint_path, policy_id))
     else:
         checkpoint_path = os.path.expanduser(checkpoint_path)
-        logger.info(f"checkpoint_path {checkpoint_path}")
+        logger.debug(f"checkpoint_path {checkpoint_path}")
         checkpoint = pickle.load(open(checkpoint_path, "rb"))
         assert "worker" in checkpoint.keys()
         assert "optimizer" not in checkpoint.keys()
@@ -86,7 +89,7 @@
         found_policy_id = False
         for p_id, state in objs["state"].items():
             if p_id == policy_id:
-                print(
+                logger.debug(
                     f"going to load policy {policy_id} "
                     f"from checkpoint {checkpoint_path}"
                 )
@@ -94,8 +97,114 @@
                 found_policy_id = True
                 break
         if not found_policy_id:
-            print(
+            logger.debug(
                 f"policy_id {policy_id} not in "
                 f'checkpoint["worker"]["state"].keys() '
                 f'{objs["state"].keys()}'
             )
+
+
+def extract_checkpoints_from_experiment_analysis(
+    tune_experiment_analysis: ExperimentAnalysis,
+) -> List[str]:
+    """
+    Extract all the best checkpoints from a tune analysis object. This tune
+    analysis can contain several trials. Each trial can contain several
+    checkpoints; only the best checkpoint per trial is returned.
+
+    :param tune_experiment_analysis:
+    :return: list of all the unique best checkpoints for each trial in the
+        tune analysis.
+    """
+    logger.info("start extract_checkpoints")
+
+    for trial in tune_experiment_analysis.trials:
+        checkpoints = tune_experiment_analysis.get_trial_checkpoints_paths(
+            trial, tune_experiment_analysis.default_metric
+        )
+        assert len(checkpoints) > 0
+
+    all_best_checkpoints_per_trial = [
+        tune_experiment_analysis.get_best_checkpoint(
+            trial,
+            metric=tune_experiment_analysis.default_metric,
+            mode=tune_experiment_analysis.default_mode,
+        )
+        for trial in tune_experiment_analysis.trials
+    ]
+
+    for checkpoint in all_best_checkpoints_per_trial:
+        assert checkpoint is not None
+
+    logger.info("end extract_checkpoints")
+    return all_best_checkpoints_per_trial
+
+
+def get_checkpoint_for_each_replicates(
+    all_replicates_save_dir: List[str],
+) -> List[str]:
+    """
+    Get the list of paths to the checkpoint files inside an experiment dir of
+    RLLib/Tune (which can contain several trials).
+    Works for an experiment whose trials each contain a unique checkpoint.
+
+    :param all_replicates_save_dir: list of trial dirs
+    :return: list of paths to checkpoint files
+    """
+    ckpt_dir_per_replicate = []
+    for replicate_dir_path in all_replicates_save_dir:
+        ckpt_dir_path = get_ckpt_dir_for_one_replicate(replicate_dir_path)
+        ckpt_path = get_ckpt_from_ckpt_dir(ckpt_dir_path)
+        ckpt_dir_per_replicate.append(ckpt_path)
+    return ckpt_dir_per_replicate
+
+
+def get_ckpt_dir_for_one_replicate(replicate_dir_path: str) -> str:
+    """
+    Get the path to the unique checkpoint dir inside a trial dir of RLLib/Tune.
+ + :param replicate_dir_path: trial dir + :return: path to checkpoint dir + """ + partialy_filtered_ckpt_dir = ( + utils.path.get_children_paths_wt_selecting_filter( + replicate_dir_path, _filter="checkpoint_" + ) + ) + ckpt_dir = [ + file_path + for file_path in partialy_filtered_ckpt_dir + if ".is_checkpoint" not in file_path + ] + assert len(ckpt_dir) == 1, f"{ckpt_dir}" + return ckpt_dir[0] + + +def get_ckpt_from_ckpt_dir(ckpt_dir_path: str) -> str: + """ + Get the path to the unique checkpoint file inside a checkpoint dir of + RLLib/Tune + :param ckpt_dir_path: checkpoint dir + :return: path to checkpoint file + """ + partialy_filtered_ckpt_path = ( + utils.path.get_children_paths_wt_discarding_filter( + ckpt_dir_path, _filter="tune_metadata" + ) + ) + filters = [ + # For Tune/RLLib + ".is_checkpoint", + # For TensorFlow + "ckpt.index", + "ckpt.data-", + "ckpt.meta", + ".json", + ] + ckpt_path = filter( + lambda el: all(filter_ not in el for filter_ in filters), + partialy_filtered_ckpt_path, + ) + ckpt_path = list(ckpt_path) + assert len(ckpt_path) == 1, f"{ckpt_path}" + return ckpt_path[0] diff --git a/marltoolbox/utils/rollout.py b/marltoolbox/utils/rollout.py index 84be755..6c8708e 100644 --- a/marltoolbox/utils/rollout.py +++ b/marltoolbox/utils/rollout.py @@ -5,18 +5,24 @@ import collections import copy +import logging from typing import List from gym import wrappers as gym_wrappers from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.rollout import DefaultMapping, default_policy_agent_mapping, \ - RolloutSaver +from ray.rllib.rollout import ( + DefaultMapping, + default_policy_agent_mapping, + RolloutSaver, +) from ray.rllib.utils.framework import TensorStructType from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray from ray.rllib.utils.typing import EnvInfoDict, PolicyID +logger = logging.getLogger(__name__) + class RolloutManager(RolloutSaver): """ @@ -43,26 +49,29 @@ def append_step(self, obs, action, next_obs, reward, done, info): """Add a step to the current rollout, if we are saving them""" if self._save_info: self._current_rollout.append( - [obs, action, next_obs, reward, done, info]) + [obs, action, next_obs, reward, done, info] + ) else: - self._current_rollout.append( - [obs, action, next_obs, reward, done]) + self._current_rollout.append([obs, action, next_obs, reward, done]) self._total_steps += 1 -def internal_rollout(worker, - num_steps, - policy_map=None, - policy_agent_mapping=None, - reset_env_before=True, - num_episodes=0, - last_obs=None, - saver=None, - no_render=True, - video_dir=None, - seed=None, - explore=None, - ): +def internal_rollout( + worker, + num_steps, + policy_map=None, + policy_agent_mapping=None, + reset_env_before=True, + num_episodes=0, + last_obs=None, + saver=None, + no_render=True, + video_dir=None, + seed=None, + explore=None, + last_rnn_states=None, + base_env=None, +): """ Can perform rollouts on the environment from inside a worker_rollout or from a policy. 
Can perform rollouts during the evaluation rollouts ran @@ -85,6 +94,7 @@ def internal_rollout(worker, :param video_dir: (optional) :param seed: (optional) random seed to set for the environment by calling env.seed(seed) + :param last_rnn_states: map of policy_id to rnn_states :return: an instance of a RolloutManager, which contains the data about the rollouts performed """ @@ -95,15 +105,21 @@ def internal_rollout(worker, if saver is None: saver = RolloutManager() - env = copy.deepcopy(worker.env) - if hasattr(env, "seed") and callable(env.seed): - env.seed(seed) + if base_env is None: + env = worker.env + else: + env = base_env.get_unwrapped()[0] + + # if hasattr(env, "seed") and callable(env.seed): + # env.seed(seed) + env = copy.deepcopy(env) multiagent = isinstance(env, MultiAgentEnv) if policy_agent_mapping is None: if worker.multiagent: policy_agent_mapping = worker.policy_config["multiagent"][ - "policy_mapping_fn"] + "policy_mapping_fn" + ] else: policy_agent_mapping = default_policy_agent_mapping @@ -123,40 +139,54 @@ def internal_rollout(worker, env=env, directory=video_dir, video_callable=lambda x: True, - force=True) + force=True, + ) random_policy_id = list(policy_map.keys())[0] virtual_global_timestep = worker.get_policy( - random_policy_id).global_timestep + random_policy_id + ).global_timestep steps = 0 episodes = 0 while _keep_going(steps, num_steps, episodes, num_episodes): + # logger.info(f"Starting epsiode {episodes} in rollout") + # print(f"Starting epsiode {episodes} in rollout") mapping_cache = {} # in case policy_agent_mapping is stochastic saver.begin_rollout() - if reset_env_before or episodes > 0: - obs = env.reset() - else: - obs = last_obs - agent_states = DefaultMapping( - lambda agent_id_: state_init[mapping_cache[agent_id_]]) + obs, agent_states = _get_first_obs( + env, + reset_env_before, + episodes, + last_obs, + mapping_cache, + state_init, + last_rnn_states, + ) prev_actions = DefaultMapping( - lambda agent_id_: action_init[mapping_cache[agent_id_]]) - prev_rewards = collections.defaultdict(lambda: 0.) 
+ lambda agent_id_: action_init[mapping_cache[agent_id_]] + ) + prev_rewards = collections.defaultdict(lambda: 0.0) done = False reward_total = 0.0 - while not done and _keep_going(steps, num_steps, episodes, - num_episodes): - + while not done and _keep_going( + steps, num_steps, episodes, num_episodes + ): multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs} action_dict = {} virtual_global_timestep += 1 for agent_id, a_obs in multi_obs.items(): if a_obs is not None: policy_id = mapping_cache.setdefault( - agent_id, policy_agent_mapping(agent_id)) + agent_id, policy_agent_mapping(agent_id) + ) p_use_lstm = use_lstm[policy_id] - # print("rollout") + # print("p_use_lstm", p_use_lstm) + # print( + # agent_id, + # "agent_states[agent_id]", + # agent_states[agent_id], + # ) if p_use_lstm: a_action, p_state, _ = _worker_compute_action( worker, @@ -166,8 +196,12 @@ def internal_rollout(worker, prev_action=prev_actions[agent_id], prev_reward=prev_rewards[agent_id], policy_id=policy_id, - explore=explore + explore=explore, ) + # print( + # "after rollout _worker_compute_action p_state", + # p_state, + # ) agent_states[agent_id] = p_state else: a_action = _worker_compute_action( @@ -177,7 +211,7 @@ def internal_rollout(worker, prev_action=prev_actions[agent_id], prev_reward=prev_rewards[agent_id], policy_id=policy_id, - explore=explore + explore=explore, ) a_action = flatten_to_single_ndarray(a_action) action_dict[agent_id] = a_action @@ -196,7 +230,8 @@ def internal_rollout(worker, if multiagent: done = done["__all__"] reward_total += sum( - r for r in reward.values() if r is not None) + r for r in reward.values() if r is not None + ) else: reward_total += reward if not no_render: @@ -228,24 +263,55 @@ def _keep_going(steps, num_steps, episodes, num_episodes): return True -def _worker_compute_action(worker, timestep, - observation: TensorStructType, - state: List[TensorStructType] = None, - prev_action: TensorStructType = None, - prev_reward: float = None, - info: EnvInfoDict = None, - policy_id: PolicyID = DEFAULT_POLICY_ID, - full_fetch: bool = False, - explore: bool = None) -> TensorStructType: +def _get_first_obs( + env, + reset_env_before, + episodes, + last_obs, + mapping_cache, + state_init, + last_rnn_states, +): + if reset_env_before or episodes > 0: + obs = env.reset() + agent_states = DefaultMapping( + lambda agent_id_: state_init[mapping_cache[agent_id_]] + ) + else: + obs = last_obs + if last_rnn_states is not None: + agent_states = DefaultMapping( + lambda agent_id_: last_rnn_states[mapping_cache[agent_id_]] + ) + else: + agent_states = DefaultMapping( + lambda agent_id_: state_init[mapping_cache[agent_id_]] + ) + return obs, agent_states + + +def _worker_compute_action( + worker, + timestep, + observation: TensorStructType, + state: List[TensorStructType] = None, + prev_action: TensorStructType = None, + prev_reward: float = None, + info: EnvInfoDict = None, + policy_id: PolicyID = DEFAULT_POLICY_ID, + full_fetch: bool = False, + explore: bool = None, +) -> TensorStructType: """ Modified version of the Trainer compute_action method """ if state is None: state = [] - preprocessed = worker.preprocessors[ - policy_id].transform(observation) - filtered_obs = worker.filters[policy_id]( - preprocessed, update=False) + # Check the preprocessor and preprocess, if necessary. 
+ pp = worker.preprocessors[policy_id] + if type(pp).__name__ != "NoPreprocessor": + observation = pp.transform(observation) + filtered_obs = worker.filters[policy_id](observation, update=False) result = worker.get_policy(policy_id).compute_single_action( filtered_obs, state, @@ -254,7 +320,8 @@ def _worker_compute_action(worker, timestep, info, clip_actions=worker.policy_config["clip_actions"], explore=explore, - timestep=timestep) + timestep=timestep, + ) if state or full_fetch: return result diff --git a/setup.py b/setup.py index 29e4617..55e716a 100644 --- a/setup.py +++ b/setup.py @@ -15,13 +15,14 @@ def read(fname): long_description=read("README.md"), license="MIT", install_requires=[ - "ray[rllib]==1.0.0", + "ray[rllib]>=1.2.0", "gym==0.17.3", "torch>=1.6.0,<=1.7.0", "tensorboard==1.15.0", "numba>=0.51.2", "matplotlib>=3.3.2", "wandb", + "ordered-set", "seaborn==0.9.0", "tqdm", ], diff --git a/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py b/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py index ac1a3d4..637ea22 100644 --- a/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py +++ b/tests/marltoolbox/algos/amTFT/test_amTFTRolloutsTorchPolicy.py @@ -18,10 +18,7 @@ from marltoolbox.envs.matrix_sequential_social_dilemma import ( IteratedPrisonersDilemma, ) -from marltoolbox.experiments.rllib_api.amtft_various_env import ( - get_rllib_config, - get_hyperparameters, -) +from marltoolbox.experiments.rllib_api import amtft_various_env from marltoolbox.utils import postprocessing, log from test_base_policy import init_amTFT, generate_fake_discrete_actions @@ -46,14 +43,22 @@ def test_compute_actions_overwrite(): (fake_actions, fake_state_out, fake_extra_fetches), (fake_actions_2nd, fake_state_out_2nd, fake_extra_fetches_2nd), ] - actions, state_out, extra_fetches = am_tft_policy.compute_actions( - observations[env.players_ids[0]] + actions, state_out, extra_fetches = am_tft_policy._compute_action_helper( + observations[env.players_ids[0]], + state_batches=None, + seq_lens=1, + explore=True, + timestep=0, ) assert actions == fake_actions assert state_out == fake_state_out assert extra_fetches == fake_extra_fetches - actions, state_out, extra_fetches = am_tft_policy.compute_actions( - observations[env.players_ids[0]] + actions, state_out, extra_fetches = am_tft_policy._compute_action_helper( + observations[env.players_ids[0]], + state_batches=None, + seq_lens=1, + explore=True, + timestep=0, ) assert actions == fake_actions_2nd assert state_out == fake_state_out_2nd @@ -65,28 +70,38 @@ def test__select_algo_to_use_in_eval(): policy_class=amTFT.AmTFTRolloutsTorchPolicy ) - def assert_(working_state_idx, active_algo_idx): + def assert_active_algo_idx(working_state_idx, active_algo_idx): am_tft_policy.working_state = base_policy.WORKING_STATES[ working_state_idx ] - am_tft_policy._select_witch_algo_to_use() - assert am_tft_policy.active_algo_idx == active_algo_idx + am_tft_policy._select_witch_algo_to_use(None) + assert ( + am_tft_policy.active_algo_idx == active_algo_idx + ), f"{am_tft_policy.active_algo_idx} == {active_algo_idx}" am_tft_policy.use_opponent_policies = False am_tft_policy.n_steps_to_punish = 0 - assert_(working_state_idx=2, active_algo_idx=base.OWN_COOP_POLICY_IDX) + assert_active_algo_idx( + working_state_idx=2, active_algo_idx=base.OWN_COOP_POLICY_IDX + ) am_tft_policy.use_opponent_policies = False am_tft_policy.n_steps_to_punish = 1 - assert_(working_state_idx=2, active_algo_idx=base.OWN_SELFISH_POLICY_IDX) + assert_active_algo_idx( + 
working_state_idx=2, active_algo_idx=base.OWN_SELFISH_POLICY_IDX + ) am_tft_policy.use_opponent_policies = True am_tft_policy.performing_rollouts = True am_tft_policy.n_steps_to_punish_opponent = 0 - assert_(working_state_idx=2, active_algo_idx=base.OPP_COOP_POLICY_IDX) + assert_active_algo_idx( + working_state_idx=2, active_algo_idx=base.OPP_COOP_POLICY_IDX + ) am_tft_policy.use_opponent_policies = True am_tft_policy.performing_rollouts = True am_tft_policy.n_steps_to_punish_opponent = 1 - assert_(working_state_idx=2, active_algo_idx=base.OPP_SELFISH_POLICY_IDX) + assert_active_algo_idx( + working_state_idx=2, active_algo_idx=base.OPP_SELFISH_POLICY_IDX + ) def test__duration_found_or_continue_search(): @@ -156,12 +171,22 @@ def step(self, actions: dict): return observations, rewards, epi_is_done, info -def make_FakePolicyWtDefinedActions(list_actions_to_play, ParentPolicyCLass): +def make_fake_policy_class_wt_defined_actions( + list_actions_to_play, ParentPolicyCLass +): class FakePolicyWtDefinedActions(ParentPolicyCLass): - def compute_actions(self, *args, **kwargs): + def _compute_action_helper(self, *args, **kwargs): + print("len", len(list_actions_to_play)) action = list_actions_to_play.pop(0) return np.array([action]), [], {} + def _initialize_loss_from_dummy_batch( + self, + auto_remove_unneeded_view_reqs: bool = True, + stats_fn=None, + ) -> None: + pass + return FakePolicyWtDefinedActions @@ -178,20 +203,23 @@ def init_worker( debug = True exp_name, _ = log.log_in_current_day_dir("testing") - hparams = get_hyperparameters( + hparams = amtft_various_env.get_hyperparameters( debug, train_n_replicates, filter_utilitarian=False, env="IteratedPrisonersDilemma", ) - _, _, rllib_config = get_rllib_config( + stop, env_config, rllib_config = amtft_various_env.get_rllib_config( hparams, welfare_fn=postprocessing.WELFARE_UTILITARIAN ) rllib_config["env"] = FakeEnvWtActionAsReward rllib_config["env_config"]["max_steps"] = max_steps - rllib_config["seed"] = int(time.time()) + rllib_config = _remove_dynamic_values_from_config( + rllib_config, hparams, env_config, stop + ) + for policy_id in FakeEnvWtActionAsReward({}).players_ids: policy_to_modify = list( rllib_config["multiagent"]["policies"][policy_id] @@ -202,30 +230,32 @@ def init_worker( if actions_list_0 is not None: policy_to_modify[3]["nested_policies"][0][ "Policy_class" - ] = make_FakePolicyWtDefinedActions( + ] = make_fake_policy_class_wt_defined_actions( copy.deepcopy(actions_list_0), DEFAULT_NESTED_POLICY_COOP ) if actions_list_1 is not None: policy_to_modify[3]["nested_policies"][1][ "Policy_class" - ] = make_FakePolicyWtDefinedActions( + ] = make_fake_policy_class_wt_defined_actions( copy.deepcopy(actions_list_1), DEFAULT_NESTED_POLICY_SELFISH ) if actions_list_2 is not None: policy_to_modify[3]["nested_policies"][2][ "Policy_class" - ] = make_FakePolicyWtDefinedActions( + ] = make_fake_policy_class_wt_defined_actions( copy.deepcopy(actions_list_2), DEFAULT_NESTED_POLICY_COOP ) if actions_list_3 is not None: policy_to_modify[3]["nested_policies"][3][ "Policy_class" - ] = make_FakePolicyWtDefinedActions( + ] = make_fake_policy_class_wt_defined_actions( copy.deepcopy(actions_list_3), DEFAULT_NESTED_POLICY_SELFISH ) rllib_config["multiagent"]["policies"][policy_id] = tuple( policy_to_modify ) + rllib_config["lr_schedule"] = None + rllib_config["exploration_config"]["temperature_schedule"] = None dqn_trainer = DQNTrainer( rllib_config, logger_creator=_get_logger_creator(exp_name) @@ -236,10 +266,35 @@ def init_worker( 
am_tft_policy_col = worker.get_policy("player_col") am_tft_policy_row.working_state = WORKING_STATES[2] am_tft_policy_col.working_state = WORKING_STATES[2] + print("env setup") return worker, am_tft_policy_row, am_tft_policy_col +def _remove_dynamic_values_from_config( + rllib_config, hparams, env_config, stop +): + rllib_config["seed"] = int(time.time()) + rllib_config["learning_starts"] = int( + rllib_config["env_config"]["max_steps"] + * rllib_config["env_config"]["bs_epi_mul"] + ) + rllib_config["buffer_size"] = int( + env_config["max_steps"] + * env_config["buf_frac"] + * stop["episodes_total"] + ) + rllib_config["train_batch_size"] = int( + env_config["max_steps"] * env_config["bs_epi_mul"] + ) + rllib_config["training_intensity"] = int( + rllib_config["num_envs_per_worker"] + * rllib_config["num_workers"] + * hparams["training_intensity"] + ) + return rllib_config + + def _get_logger_creator(exp_name): logdir_prefix = exp_name + "/" tail, head = os.path.split(exp_name) @@ -266,10 +321,12 @@ def default_logger_creator(config): def test__compute_debit_using_rollouts(): - def assert_(worker_, am_tft_policy, last_obs, opp_action, assert_debit): + def assert_debit_value_computed( + worker_, am_tft_policy, last_obs, opp_action, assert_debit + ): worker_.foreach_env(lambda env: env.reset()) debit = am_tft_policy._compute_debit_using_rollouts( - last_obs, opp_action, worker_ + last_obs, opp_action, worker_, worker.async_env ) assert debit == assert_debit @@ -291,14 +348,14 @@ def init_no_extra_reward(max_steps_): worker, am_tft_policy_row, am_tft_policy_col = init_no_extra_reward( max_steps ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_row, {"player_row": 0, "player_col": 0}, opp_action=0, assert_debit=0, ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_col, {"player_row": 1, "player_col": 0}, @@ -309,14 +366,14 @@ def init_no_extra_reward(max_steps_): worker, am_tft_policy_row, am_tft_policy_col = init_no_extra_reward( max_steps ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_row, {"player_row": 1, "player_col": 0}, opp_action=1, assert_debit=1, ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_col, {"player_row": 1, "player_col": 1}, @@ -342,14 +399,14 @@ def init_selfish_opp_advantaged(max_steps): worker, am_tft_policy_row, am_tft_policy_col = init_selfish_opp_advantaged( max_steps ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_row, {"player_row": 0, "player_col": 0}, opp_action=0, assert_debit=0, ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_col, {"player_row": 1, "player_col": 0}, @@ -375,14 +432,14 @@ def init_coop_opp_advantaged(max_steps): worker, am_tft_policy_row, am_tft_policy_col = init_coop_opp_advantaged( max_steps ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_row, {"player_row": 1, "player_col": 0}, opp_action=1, assert_debit=0, ) - assert_( + assert_debit_value_computed( worker, am_tft_policy_col, {"player_row": 1, "player_col": 1}, diff --git a/tests/marltoolbox/algos/amTFT/test_base_policy.py b/tests/marltoolbox/algos/amTFT/test_base_policy.py index 8d7143b..898d76f 100644 --- a/tests/marltoolbox/algos/amTFT/test_base_policy.py +++ b/tests/marltoolbox/algos/amTFT/test_base_policy.py @@ -30,31 +30,31 @@ def init_amTFT( def test__select_witch_algo_to_use(): am_tft_policy, env = init_amTFT() - def assert_(working_state_idx, active_algo_idx): + def assert_active_algo_idx(working_state_idx, active_algo_idx): am_tft_policy.working_state = 
base_policy.WORKING_STATES[ working_state_idx ] - am_tft_policy._select_witch_algo_to_use() + am_tft_policy._select_witch_algo_to_use(None) assert am_tft_policy.active_algo_idx == active_algo_idx - assert_( + assert_active_algo_idx( working_state_idx=0, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX ) - assert_( + assert_active_algo_idx( working_state_idx=1, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX ) am_tft_policy.n_steps_to_punish = 0 - assert_( + assert_active_algo_idx( working_state_idx=2, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX ) am_tft_policy.n_steps_to_punish = 1 - assert_( + assert_active_algo_idx( working_state_idx=2, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX ) - assert_( + assert_active_algo_idx( working_state_idx=3, active_algo_idx=base_policy.OWN_SELFISH_POLICY_IDX ) - assert_( + assert_active_algo_idx( working_state_idx=4, active_algo_idx=base_policy.OWN_COOP_POLICY_IDX ) @@ -123,21 +123,20 @@ def test_lr_update(): ) one_policy_batch = multiagent_batch.policy_batches[env.players_ids[0]] - am_tft_policy.on_global_var_update({"timestep": 0}) - am_tft_policy.learn_on_batch(one_policy_batch) - for algo in am_tft_policy.algorithms: - assert algo.cur_lr == base_lr - for opt in algo._optimizers: - for p in opt.param_groups: - assert p["lr"] == algo.cur_lr + def _assert_lr_equals(policy, lr): + for algo in policy.algorithms: + assert algo.cur_lr == lr + for opt in algo._optimizers: + for p in opt.param_groups: + assert p["lr"] == lr - am_tft_policy.on_global_var_update({"timestep": interm_global_timestep}) - am_tft_policy.learn_on_batch(one_policy_batch) - for algo in am_tft_policy.algorithms: - assert algo.cur_lr == final_lr - for opt in algo._optimizers: - for p in opt.param_groups: - assert p["lr"] == algo.cur_lr + def _fake_n_step_assert_lr(policy, n_step, lr): + policy.on_global_var_update({"timestep": n_step}) + policy.learn_on_batch(one_policy_batch) + _assert_lr_equals(policy, lr) + + _fake_n_step_assert_lr(am_tft_policy, 0, base_lr) + _fake_n_step_assert_lr(am_tft_policy, interm_global_timestep, final_lr) def test__is_punishment_planned(): @@ -148,12 +147,21 @@ def test__is_punishment_planned(): assert am_tft_policy._is_punishment_planned() +from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv + + def test_on_episode_end(): am_tft_policy, env = init_amTFT( {"working_state": base_policy.WORKING_STATES[2]} ) + base_env = _MultiAgentEnvToBaseEnv( + make_env=lambda _: env, existing_envs=[], num_envs=1 + ) am_tft_policy.total_debit = 0 am_tft_policy.n_steps_to_punish = 0 - am_tft_policy.on_episode_end() + am_tft_policy.observed_n_step_in_current_epi = base_env.get_unwrapped()[ + 0 + ].max_steps + am_tft_policy.on_episode_end(base_env=base_env) assert am_tft_policy.total_debit == 0 assert am_tft_policy.n_steps_to_punish == 0 diff --git a/tests/marltoolbox/algos/exploiters/evader_utils.py b/tests/marltoolbox/algos/exploiters/evader_utils.py index 9d9be6d..bcd110a 100644 --- a/tests/marltoolbox/algos/exploiters/evader_utils.py +++ b/tests/marltoolbox/algos/exploiters/evader_utils.py @@ -9,7 +9,7 @@ from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.agents.dqn.dqn_torch_policy import postprocess_nstep_and_prio from ray.rllib.agents.pg.pg_torch_policy import post_process_advantages -from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae +from ray.rllib.agents.ppo.ppo_torch_policy import compute_gae_for_sample_batch from ray.rllib.evaluation.sample_batch_builder import ( MultiAgentSampleBatchBuilder, ) @@ -25,7 +25,7 @@ 
from marltoolbox.utils import postprocessing, miscellaneous TEST_POLICIES = ( - (ppo.PPOTorchPolicy, postprocess_ppo_gae, ppo.DEFAULT_CONFIG), + (ppo.PPOTorchPolicy, compute_gae_for_sample_batch, ppo.DEFAULT_CONFIG), (dqn.DQNTorchPolicy, postprocess_nstep_and_prio, dqn.DEFAULT_CONFIG), (A3CTorchPolicy, add_advantages, a3c.DEFAULT_CONFIG), (pg.PGTorchPolicy, post_process_advantages, pg.DEFAULT_CONFIG), diff --git a/tests/marltoolbox/algos/test_welfare_coordination.py b/tests/marltoolbox/algos/test_welfare_coordination.py new file mode 100644 index 0000000..046e1ca --- /dev/null +++ b/tests/marltoolbox/algos/test_welfare_coordination.py @@ -0,0 +1,330 @@ +import random +from marltoolbox.algos.welfare_coordination import MetaGameSolver + +BEST_PAYOFF = 1.5 +BEST_WELFARE = "best_welfare" +WORST_PAYOFF = -0.5 +WORST_WELFARE = "worst_welfare" + + +def test_end_to_end_wt_best_welfare_fn(): + for _ in range(10): + meta_game_solver = _given_meta_game_with_a_clear_extrem_welfare_fn( + best=True + ) + for _ in range(10): + _when_solving_meta_game(meta_game_solver) + _assert_best_welfare_in_announced_set(meta_game_solver) + + +def test_end_to_end_wt_worst_welfare_fn(): + for _ in range(10): + meta_game_solver = _given_meta_game_with_a_clear_extrem_welfare_fn( + worst=True + ) + for _ in range(10): + _when_solving_meta_game(meta_game_solver) + _assert_best_welfare_not_in_announced_set(meta_game_solver) + + +def _given_meta_game_with_a_clear_extrem_welfare_fn(best=False, worst=False): + assert best or worst + assert not (best and worst) + meta_game_solver = MetaGameSolver() + n_welfares = _get_random_number_of_welfare_fn() + own_player_idx = _get_random_position_of_players() + opp_player_idx = (own_player_idx + 1) % 2 + welfares = ["welfare_" + str(el) for el in list(range(n_welfares - 1))] + if best: + welfares.append(BEST_WELFARE) + elif worst: + welfares.append(WORST_WELFARE) + + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_extrem_payoffs_for_i( + welfares=welfares, + own_player_idx=own_player_idx, + best_welfare=BEST_WELFARE if best else None, + worst_welfare=WORST_WELFARE if worst else None, + ) + ) + + meta_game_solver.setup_meta_game( + all_welfare_pairs_wt_payoffs, + own_player_idx=own_player_idx, + opp_player_idx=opp_player_idx, + own_default_welfare_fn=welfares[own_player_idx], + opp_default_welfare_fn=welfares[opp_player_idx], + ) + return meta_game_solver + + +def _when_solving_meta_game(meta_game_solver): + meta_game_solver.solve_meta_game(_get_random_tau()) + + +def _assert_best_welfare_in_announced_set(meta_game_solver): + assert BEST_WELFARE in meta_game_solver.welfare_set_to_annonce + + +def _assert_best_welfare_not_in_announced_set(meta_game_solver): + assert BEST_WELFARE not in meta_game_solver.welfare_set_to_annonce + + +def _get_random_tau(): + return random.random() + + +def _get_random_number_of_welfare_fn(): + return random.randint(2, 4) + + +def _get_random_position_of_players(): + return random.randint(0, 1) + + +def _get_all_welfare_pairs_wt_extrem_payoffs_for_i( + welfares, + own_player_idx, + best_welfare: str = None, + worst_welfare: str = None, +): + all_welfare_pairs_wt_payoffs = {} + for welfare_p1 in welfares: + for welfare_p2 in welfares: + welfare_pair_name = ( + MetaGameSolver.from_pair_of_welfare_names_to_key( + welfare_p1, welfare_p2 + ) + ) + + all_welfare_pairs_wt_payoffs[welfare_pair_name] = [ + random.random(), + random.random(), + ] + if best_welfare is not None and best_welfare == welfare_p1: + all_welfare_pairs_wt_payoffs[welfare_pair_name][ 
+ own_player_idx + ] = BEST_PAYOFF + elif worst_welfare is not None and worst_welfare == welfare_p1: + all_welfare_pairs_wt_payoffs[welfare_pair_name][ + own_player_idx + ] = WORST_PAYOFF + return all_welfare_pairs_wt_payoffs + + +def test__compute_meta_payoff(): + for _ in range(100): + ( + welfares, + all_welfare_pairs_wt_payoffs, + own_welfare_set, + opp_welfare_set, + payoff, + payoff_default, + own_default_welfare_fn, + opp_default_welfare_fn, + own_player_idx, + opp_player_idx, + ) = _given_this_all_welfare_pairs_wt_payoffs() + + meta_payoff = _when_computing_meta_game_payoff( + all_welfare_pairs_wt_payoffs, + own_player_idx, + opp_player_idx, + own_default_welfare_fn, + opp_default_welfare_fn, + own_welfare_set, + opp_welfare_set, + ) + + _assert_get_the_right_payoffs_or_default_payoff( + own_welfare_set, + opp_welfare_set, + own_player_idx, + meta_payoff, + payoff, + payoff_default, + ) + + +def _when_computing_meta_game_payoff( + all_welfare_pairs_wt_payoffs, + own_player_idx, + opp_player_idx, + own_default_welfare_fn, + opp_default_welfare_fn, + own_welfare_set, + opp_welfare_set, +): + meta_game_solver = MetaGameSolver() + meta_game_solver.setup_meta_game( + all_welfare_pairs_wt_payoffs, + own_player_idx=own_player_idx, + opp_player_idx=opp_player_idx, + own_default_welfare_fn=own_default_welfare_fn, + opp_default_welfare_fn=opp_default_welfare_fn, + ) + meta_payoff = meta_game_solver._compute_meta_payoff( + own_welfare_set, opp_welfare_set + ) + return meta_payoff + + +def _given_this_all_welfare_pairs_wt_payoffs(): + n_welfares = _get_random_number_of_welfare_fn() + welfares = ["welfare_" + str(el) for el in list(range(n_welfares))] + own_player_idx = _get_random_position_of_players() + opp_player_idx = (own_player_idx + 1) % 2 + all_welfare_pairs_wt_payoffs = {} + + own_welfare_set, opp_welfare_set, payoff = _add_nominal_case( + welfares, all_welfare_pairs_wt_payoffs, own_player_idx + ) + + ( + own_default_welfare_fn, + opp_default_welfare_fn, + payoff_default, + ) = _add_default_case( + welfares, all_welfare_pairs_wt_payoffs, own_player_idx + ) + + if ( + own_default_welfare_fn + == list(own_welfare_set)[0] + == list(opp_welfare_set)[0] + == opp_default_welfare_fn + ): + payoff = payoff_default + return ( + welfares, + all_welfare_pairs_wt_payoffs, + own_welfare_set, + opp_welfare_set, + payoff, + payoff_default, + own_default_welfare_fn, + opp_default_welfare_fn, + own_player_idx, + opp_player_idx, + ) + + +def _assert_get_the_right_payoffs_or_default_payoff( + own_welfare_set, + opp_welfare_set, + own_player_idx, + meta_payoff, + payoff, + payoff_default, +): + if len(own_welfare_set & opp_welfare_set) > 0: + assert meta_payoff[own_player_idx] == payoff + else: + assert meta_payoff[own_player_idx] == payoff_default + + +def _add_nominal_case(welfares, all_welfare_pairs_wt_payoffs, own_player_idx): + own_welfare_set = set(random.sample(welfares, 1)) + opp_welfare_set = set(random.sample(welfares, 1)) + welfare_pair_name = MetaGameSolver.from_pair_of_welfare_names_to_key( + list(own_welfare_set)[0], list(opp_welfare_set)[0] + ) + payoff = random.random() + all_welfare_pairs_wt_payoffs[welfare_pair_name] = [-1, -1] + all_welfare_pairs_wt_payoffs[welfare_pair_name][own_player_idx] = payoff + + return own_welfare_set, opp_welfare_set, payoff + + +def _add_default_case(welfares, all_welfare_pairs_wt_payoffs, own_player_idx): + own_default_welfare_fn = random.sample(welfares, 1)[0] + opp_default_welfare_fn = random.sample(welfares, 1)[0] + welfare_default_pair_name = ( + 
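+        # key under which the payoff reached when both players fall back to
+        # their default welfare fns is stored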
MetaGameSolver.from_pair_of_welfare_names_to_key( + own_default_welfare_fn, opp_default_welfare_fn + ) + ) + payoff_default = random.random() + + all_welfare_pairs_wt_payoffs[welfare_default_pair_name] = [-1, -1] + all_welfare_pairs_wt_payoffs[welfare_default_pair_name][ + own_player_idx + ] = payoff_default + + return own_default_welfare_fn, opp_default_welfare_fn, payoff_default + + +def test__list_all_set_of_welfare_fn(): + for _ in range(100): + ( + own_player_idx, + opp_player_idx, + welfares, + all_welfare_pairs_wt_payoffs, + ) = _given_n_welfare_fn() + meta_game_solver = _when_setting_the_game( + all_welfare_pairs_wt_payoffs, own_player_idx, opp_player_idx + ) + _assert_right_number_of_sets_and_presence_of_single_and_pairs( + meta_game_solver, welfares + ) + + +def _given_n_welfare_fn(): + n_welfares = _get_random_number_of_welfare_fn() + welfares = ["welfare_" + str(el) for el in list(range(n_welfares))] + own_player_idx = _get_random_position_of_players() + opp_player_idx = (own_player_idx + 1) % 2 + + all_welfare_pairs_wt_payoffs = ( + _get_all_welfare_pairs_wt_extrem_payoffs_for_i( + welfares=welfares, + own_player_idx=own_player_idx, + best_welfare=welfares[0], + ) + ) + return ( + own_player_idx, + opp_player_idx, + welfares, + all_welfare_pairs_wt_payoffs, + ) + + +def _when_setting_the_game( + all_welfare_pairs_wt_payoffs, own_player_idx, opp_player_idx +): + meta_game_solver = MetaGameSolver() + meta_game_solver.setup_meta_game( + all_welfare_pairs_wt_payoffs, + own_player_idx=own_player_idx, + opp_player_idx=opp_player_idx, + own_default_welfare_fn="welfare_0", + opp_default_welfare_fn="welfare_1", + ) + return meta_game_solver + + +def _assert_right_number_of_sets_and_presence_of_single_and_pairs( + meta_game_solver, welfares +): + meta_game_solver._list_all_set_of_welfare_fn() + if len(welfares) == 2: + assert len(meta_game_solver.welfare_fn_sets) == 3 + elif len(welfares) == 3: + assert len(meta_game_solver.welfare_fn_sets) == 3 + 3 + 1 + elif len(welfares) == 4: + print( + "meta_game_solver.welfare_fn_sets", + meta_game_solver.welfare_fn_sets, + ) + assert len(meta_game_solver.welfare_fn_sets) == 4 + 6 + 4 + 1 + for welfare in welfares: + assert frozenset([welfare]) in meta_game_solver.welfare_fn_sets + for welfare_2 in welfares: + assert ( + frozenset([welfare, welfare_2]) + in meta_game_solver.welfare_fn_sets + ) diff --git a/tests/marltoolbox/envs/coin_game_tests_utils.py b/tests/marltoolbox/envs/coin_game_tests_utils.py new file mode 100644 index 0000000..6fe771f --- /dev/null +++ b/tests/marltoolbox/envs/coin_game_tests_utils.py @@ -0,0 +1,440 @@ +import random +import numpy as np + + +def init_several_envs(classes, **kwargs): + return [init_env(env_class=class_, **kwargs) for class_ in classes] + + +def init_env( + env_class, + max_steps, + seed=None, + grid_size=3, + players_can_pick_same_coin=True, + same_obs_for_each_player=False, + batch_size=None, +): + config = { + "max_steps": max_steps, + "grid_size": grid_size, + "both_players_can_pick_the_same_coin": players_can_pick_same_coin, + "same_obs_for_each_player": same_obs_for_each_player, + } + if batch_size is not None: + config["batch_size"] = batch_size + env = env_class(config) + env.seed(seed) + return env + + +def check_custom_obs( + obs, + grid_size, + batch_size=None, + n_in_0=1.0, + n_in_1=1.0, + n_in_2_and_above=1.0, + n_layers=4, +): + assert len(obs) == 2, "two players" + for player_obs in obs.values(): + if batch_size is None: + check_single_obs( + player_obs, + grid_size, + n_layers, + n_in_0, + 
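+                # n_in_*: expected sum of each observation channel
+                # (red player, blue player, then the coin channels)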
n_in_1, + n_in_2_and_above, + ) + else: + for i in range(batch_size): + check_single_obs( + player_obs[i, ...], + grid_size, + n_layers, + n_in_0, + n_in_1, + n_in_2_and_above, + ) + + +def check_single_obs( + player_obs, grid_size, n_layers, n_in_0, n_in_1, n_in_2_and_above +): + assert player_obs.shape == (grid_size, grid_size, n_layers) + assert ( + player_obs[..., 0].sum() == n_in_0 + ), f"observe 1 player red in grid: {player_obs[..., 0]}" + assert ( + player_obs[..., 1].sum() == n_in_1 + ), f"observe 1 player blue in grid: {player_obs[..., 1]}" + assert ( + player_obs[..., 2:].sum() == n_in_2_and_above + ), f"observe 1 coin in grid: {player_obs[..., 2:]}" + + +def assert_logger_buffer_size(env, n_steps): + assert_attributes_len_equals_value( + env, + n_steps, + ) + + +def assert_attributes_len_equals_value( + object_, + value, + attributes=("red_pick", "red_pick_own", "blue_pick", "blue_pick_own"), +): + for attribute in attributes: + assert len(getattr(object_, attribute)) == value + + +def helper_test_reset(envs, check_obs_fn, **kwargs): + for env in envs: + obs = env.reset() + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=0) + + +def helper_test_step(envs, check_obs_fn, **kwargs): + for env in envs: + obs = env.reset() + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=0) + + actions = _get_random_action(env, **kwargs) + obs, reward, done, info = env.step(actions) + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=1) + assert not done["__all__"] + + +def _get_random_action(env, **kwargs): + if "batch_size" in kwargs.keys(): + actions = _get_random_action_batch(env, kwargs["batch_size"]) + else: + actions = _get_random_single_action(env) + return actions + + +def _get_random_single_action(env): + actions = { + policy_id: random.randint(0, env.NUM_ACTIONS - 1) + for policy_id in env.players_ids + } + return actions + + +def _get_random_action_batch(env, batch_size): + actions = { + policy_id: [ + random.randint(0, env.NUM_ACTIONS - 1) for _ in range(batch_size) + ] + for policy_id in env.players_ids + } + return actions + + +def helper_test_multiple_steps(envs, n_steps, check_obs_fn, **kwargs): + for env in envs: + obs = env.reset() + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=0) + + for step_i in range(1, n_steps, 1): + actions = _get_random_action(env, **kwargs) + obs, reward, done, info = env.step(actions) + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=step_i) + assert not done["__all__"] + + +def helper_test_multi_ple_episodes( + envs, + n_steps, + max_steps, + check_obs_fn, + **kwargs, +): + for env in envs: + obs = env.reset() + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=0) + + step_i = 0 + for _ in range(n_steps): + step_i += 1 + actions = _get_random_action(env, **kwargs) + obs, reward, done, info = env.step(actions) + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=step_i) + assert not done["__all__"] or ( + step_i == max_steps and done["__all__"] + ) + if done["__all__"]: + obs = env.reset() + check_obs_fn(obs, **kwargs) + assert_logger_buffer_size(env, n_steps=0) + step_i = 0 + + +def helper_assert_info(repetitions=10, **kwargs): + if "batch_size" in kwargs.keys(): + for _ in range(repetitions): + batch_deltas = np.random.randint( + 0, kwargs["max_steps"] - 1, size=kwargs["batch_size"] + ) + helper_assert_info_one_time(batch_deltas=batch_deltas, **kwargs) + else: + helper_assert_info_one_time(batch_deltas=None, 
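+        # non-batched env: batch_deltas would shift each sample of a batch to
+        # a different point of the scripted episode, so none are needed here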
**kwargs) + + +def helper_assert_info_one_time( + n_steps, + p_red_act, + p_blue_act, + envs, + max_steps, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + red_speed, + blue_speed, + red_own, + blue_own, + check_obs_fn, + overwrite_pos_fn, + c_red_coin=None, + batch_deltas=None, + blue_coop_fraction=None, + red_coop_fraction=None, + red_coop_speed=None, + blue_coop_speed=None, + delta_err=0.01, + **check_obs_kwargs, +): + for env_i, env in enumerate(envs): + step_i = 0 + obs = env.reset() + check_obs_fn(obs, **check_obs_kwargs) + assert_logger_buffer_size(env, n_steps=0) + _overwrite_pos_helper( + batch_deltas, + overwrite_pos_fn, + step_i, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin, + ) + + for _ in range(n_steps): + actions = _read_actions( + p_red_act, + p_blue_act, + step_i, + batch_deltas, + n_steps_in_epi=max_steps, + ) + step_i += 1 + obs, reward, done, info = env.step(actions) + check_obs_fn(obs, **check_obs_kwargs) + assert_logger_buffer_size(env, n_steps=step_i) + assert not done["__all__"] or ( + step_i == max_steps and done["__all__"] + ) + + if done["__all__"]: + print("info", info) + print("step_i", step_i) + print("env", env) + print("env_i", env_i) + _assert_close_enough( + info["player_red"]["pick_speed"], red_speed, delta_err + ) + _assert_close_enough( + info["player_blue"]["pick_speed"], blue_speed, delta_err + ) + assert_not_present_in_dict_or_close_to( + "pick_own_color", red_own, info, "player_red", delta_err + ) + assert_not_present_in_dict_or_close_to( + "pick_own_color", blue_own, info, "player_blue", delta_err + ) + _assert_ssdmmcg_cooperation_items( + red_coop_fraction, + blue_coop_fraction, + red_coop_speed, + blue_coop_speed, + info, + delta_err, + ) + + obs = env.reset() + check_obs_fn(obs, **check_obs_kwargs) + assert_logger_buffer_size(env, n_steps=0) + step_i = 0 + + _overwrite_pos_helper( + batch_deltas, + overwrite_pos_fn, + step_i, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin, + ) + + +def assert_not_present_in_dict_or_close_to( + key, value, info, player, delta_err +): + if value is None: + assert key not in info[player] + else: + _assert_close_enough(info[player][key], value, delta_err) + + +def _assert_close_enough(value, target, delta_err): + assert abs(value - target) < delta_err, ( + f"{abs(value - target)} <" f" {delta_err}" + ) + + +def _read_actions( + p_red_act, p_blue_act, step_i, batch_deltas=None, n_steps_in_epi=None +): + if batch_deltas is not None: + return _read_actions_batch( + p_red_act, p_blue_act, step_i, batch_deltas, n_steps_in_epi + ) + else: + return _read_single_action(p_red_act, p_blue_act, step_i) + + +def _read_actions_batch( + p_red_act, p_blue_act, step_i, batch_deltas, n_steps_in_epi +): + actions = { + "player_red": [ + p_red_act[(step_i + delta) % n_steps_in_epi] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % n_steps_in_epi] + for delta in batch_deltas + ], + } + return actions + + +def _read_single_action(p_red_act, p_blue_act, step_i): + actions = { + "player_red": p_red_act[step_i - 1], + "player_blue": p_blue_act[step_i - 1], + } + return actions + + +def _assert_ssdmmcg_cooperation_items( + red_coop_fraction, + blue_coop_fraction, + red_coop_speed, + blue_coop_speed, + info, + delta_err, +): + if _is_using_ssdmmcg( + blue_coop_fraction, + red_coop_fraction, + red_coop_speed, + blue_coop_speed, + ): + + assert_not_present_in_dict_or_close_to( + "blue_coop_fraction", + blue_coop_fraction, 
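+            # check the blue player's logged episode info
+            # (or assert the key is absent when the expected value is None)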
+ info, + "player_blue", + delta_err, + ) + assert_not_present_in_dict_or_close_to( + "red_coop_fraction", + red_coop_fraction, + info, + "player_red", + delta_err, + ) + assert_not_present_in_dict_or_close_to( + "red_coop_speed", + red_coop_speed, + info, + "player_red", + delta_err, + ) + assert_not_present_in_dict_or_close_to( + "blue_coop_speed", + blue_coop_speed, + info, + "player_blue", + delta_err, + ) + + +def _is_using_ssdmmcg( + blue_coop_fraction, red_coop_fraction, red_coop_speed, blue_coop_speed +): + return ( + blue_coop_fraction is not None + or red_coop_fraction is not None + or red_coop_speed is not None + or blue_coop_speed is not None + ) + + +def _overwrite_pos_helper( + batch_deltas, + overwrite_pos_fn, + step_i, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin, +): + if batch_deltas is not None: + overwrite_pos_fn( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin=c_red_coin, + ) + else: + overwrite_pos_fn( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + c_red_coin=c_red_coin, + ) + + +def shift_consistently(list_, step_i, n_steps_in_epi, batch_deltas): + return [list_[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas] diff --git a/tests/marltoolbox/envs/test_coin_game.py b/tests/marltoolbox/envs/test_coin_game.py index c62f338..4ab5264 100644 --- a/tests/marltoolbox/envs/test_coin_game.py +++ b/tests/marltoolbox/envs/test_coin_game.py @@ -5,137 +5,76 @@ from marltoolbox.envs.coin_game import CoinGame, AsymCoinGame - # TODO add tests for grid_size != 3 +from coin_game_tests_utils import ( + check_custom_obs, + assert_logger_buffer_size, + helper_test_reset, + helper_test_step, + init_several_envs, + helper_test_multiple_steps, + helper_test_multi_ple_episodes, + helper_assert_info, +) + + +def init_my_envs( + max_steps, + grid_size, + players_can_pick_same_coin=True, + same_obs_for_each_player=True, +): + return init_several_envs( + (CoinGame, AsymCoinGame), + max_steps=max_steps, + grid_size=grid_size, + players_can_pick_same_coin=players_can_pick_same_coin, + same_obs_for_each_player=same_obs_for_each_player, + ) + def test_reset(): max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - -def init_several_env(max_steps, grid_size, - players_can_pick_same_coin=True, - same_obs_for_each_player=True): - coin_game = init_env( - max_steps, - CoinGame, - grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player) - asymm_coin_game = init_env( - max_steps, - AsymCoinGame, - grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player) - return [coin_game, asymm_coin_game] - - -def init_env(max_steps, - env_class, - seed=None, - grid_size=3, - players_can_pick_same_coin=True, - same_obs_for_each_player=False): - config = { - "max_steps": max_steps, - "grid_size": grid_size, - "both_players_can_pick_the_same_coin": players_can_pick_same_coin, - "same_obs_for_each_player": same_obs_for_each_player, - } - env = env_class(config) - env.seed(seed) - return env + envs = init_my_envs(max_steps, grid_size) + helper_test_reset(envs, check_obs, grid_size=grid_size) def check_obs(obs, grid_size): - assert len(obs) == 2, "two players" - for key, player_obs in obs.items(): - 
assert player_obs.shape == (grid_size, grid_size, 4) - assert player_obs[..., 0].sum() == 1.0, \ - f"observe 1 player red in grid: {player_obs[..., 0]}" - assert player_obs[..., 1].sum() == 1.0, \ - f"observe 1 player blue in grid: {player_obs[..., 1]}" - assert player_obs[..., 2:].sum() == 1.0, \ - f"observe 1 coin in grid: {player_obs[..., 0]}" - - -def assert_logger_buffer_size(env, n_steps): - assert len(env.red_pick) == n_steps - assert len(env.red_pick_own) == n_steps - assert len(env.blue_pick) == n_steps - assert len(env.blue_pick_own) == n_steps + check_custom_obs(obs, grid_size) def test_step(): max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=1) - assert not done["__all__"] + envs = init_my_envs(max_steps, grid_size) + helper_test_step(envs, check_obs, grid_size=grid_size) def test_multiple_steps(): max_steps, grid_size = 20, 3 n_steps = int(max_steps * 0.75) - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - for step_i in range(1, n_steps, 1): - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] + envs = init_my_envs(max_steps, grid_size) + helper_test_multiple_steps( + envs, + n_steps, + check_obs, + grid_size=grid_size, + ) def test_multiple_episodes(): max_steps, grid_size = 20, 3 n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, grid_size) + envs = init_my_envs(max_steps, grid_size) + helper_test_multi_ple_episodes( + envs, + n_steps, + max_steps, + check_obs, + grid_size=grid_size, + ) - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or \ - (step_i == max_steps and done["__all__"]) - if done["__all__"]: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - -def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos): +def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, **kwargs): assert c_red_pos is None or c_blue_pos is None if c_red_pos is None: env.red_coin = 0 @@ -154,46 +93,6 @@ def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos): env.red_coin = np.array(env.red_coin) -def assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed, blue_speed, red_own, blue_own): - step_i = 0 - delta_err = 0.01 - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not 
done["__all__"] or (step_i == max_steps and done["__all__"]) - - if done["__all__"]: - assert abs(info["player_red"]["pick_speed"] - red_speed) \ - < delta_err - assert abs(info["player_blue"]["pick_speed"] - blue_speed) \ - < delta_err - - if red_own is None: - assert "pick_own_color" not in info["player_red"] - else: - assert abs(info["player_red"]["pick_own_color"] - red_own) \ - < delta_err - if blue_own is None: - assert "pick_own_color" not in info["player_blue"] - else: - assert abs(info["player_blue"]["pick_own_color"] - blue_own) \ - < delta_err - - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - def test_logged_info_no_picking(): p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] @@ -203,32 +102,51 @@ def test_logged_info_no_picking(): c_blue_pos = [None, None, None, None] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) + envs = init_my_envs(max_steps, grid_size) for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) def test_logged_info__red_pick_red_all_the_time(): @@ -240,32 +158,47 @@ def test_logged_info__red_pick_red_all_the_time(): c_blue_pos = [None, None, None, None] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) - - for 
env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=0.0, + red_own=1.0, + blue_own=None, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=0.0, + red_own=1.0, + blue_own=None, + ) def test_logged_info__blue_pick_red_all_the_time(): @@ -277,32 +210,47 @@ def test_logged_info__blue_pick_red_all_the_time(): c_blue_pos = [None, None, None, None] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=0.0, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=0.0, + ) def test_logged_info__blue_pick_blue_all_the_time(): @@ -314,32 +262,47 @@ def test_logged_info__blue_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) 
- - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + ) def test_logged_info__red_pick_blue_all_the_time(): @@ -351,32 +314,47 @@ def test_logged_info__red_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + blue_own=None, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + 
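+        # same scripted episode, rerun with players_can_pick_same_coin=False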
p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + blue_own=None, + ) def test_logged_info__both_pick_blue_all_the_time(): @@ -388,18 +366,26 @@ def test_logged_info__both_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=1.0, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__both_pick_red_all_the_time(): @@ -411,18 +397,26 @@ def test_logged_info__both_pick_red_all_the_time(): c_blue_pos = [None, None, None, None] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + ) def test_logged_info__both_pick_red_half_the_time(): @@ -434,18 +428,26 @@ def test_logged_info__both_pick_red_half_the_time(): c_blue_pos = [None, None, None, None] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.5, + blue_speed=0.5, + red_own=1.0, + blue_own=0.0, + ) def test_logged_info__both_pick_blue_half_the_time(): @@ -457,18 
+459,26 @@ def test_logged_info__both_pick_blue_half_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.5, + blue_speed=0.5, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__both_pick_blue(): @@ -480,18 +490,26 @@ def test_logged_info__both_pick_blue(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.25, blue_speed=0.5, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.25, + blue_speed=0.5, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__pick_half_the_time_half_blue_half_red(): @@ -503,201 +521,420 @@ def test_logged_info__pick_half_the_time_half_blue_half_red(): c_blue_pos = [None, [1, 1], None, [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=0.5, blue_own=0.5) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.5, + blue_speed=0.5, + red_own=0.5, + blue_own=0.5, + ) def test_observations_are_invariant_to_the_player_trained_in_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + 
[2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=False) + envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=False) for env_i, env in enumerate(envs): obs = env.reset() assert_obs_is_symmetrical(obs, env) step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) for _ in range(n_steps): step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} + actions = { + "player_red": p_red_act[step_i - 1], + "player_blue": p_blue_act[step_i - 1], + } _, _, _, _ = env.step(actions) if step_i == max_steps: break - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) def assert_obs_is_symmetrical(obs, env): - assert np.all(obs[env.players_ids[0]][..., 0] == - obs[env.players_ids[1]][..., 1]) - assert np.all(obs[env.players_ids[1]][..., 0] == - obs[env.players_ids[0]][..., 1]) - assert np.all(obs[env.players_ids[0]][..., 2] == - obs[env.players_ids[1]][..., 3]) - assert np.all(obs[env.players_ids[1]][..., 2] == - obs[env.players_ids[0]][..., 3]) + assert np.all( + obs[env.players_ids[0]][..., 0] == obs[env.players_ids[1]][..., 1] + ) + assert np.all( + obs[env.players_ids[1]][..., 0] == obs[env.players_ids[0]][..., 1] + ) + assert np.all( + obs[env.players_ids[0]][..., 2] == obs[env.players_ids[1]][..., 3] + ) + assert np.all( + obs[env.players_ids[1]][..., 2] == obs[env.players_ids[0]][..., 3] + ) def test_observations_are_invariant_to_the_player_trained_in_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=False) + envs = init_my_envs(max_steps, 
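+        # same_obs_for_each_player=False: each player observes the grid from
+        # its own viewpoint, so the two observations should be
+        # channel-swapped copies of each other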
grid_size, same_obs_for_each_player=False) for env_i, env in enumerate(envs): _ = env.reset() step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) for _ in range(n_steps): step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} + actions = { + "player_red": p_red_act[step_i - 1], + "player_blue": p_blue_act[step_i - 1], + } obs, reward, done, info = env.step(actions) # assert observations are symmetrical respective to the actions if step_i % 2 == 1: obs_step_odd = obs elif step_i % 2 == 0: - assert np.all(obs[env.players_ids[0]] == - obs_step_odd[env.players_ids[1]]) - assert np.all(obs[env.players_ids[1]] == - obs_step_odd[env.players_ids[0]]) + assert np.all( + obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]] + ) + assert np.all( + obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]] + ) assert_obs_is_symmetrical(obs, env) if step_i == max_steps: break - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) def test_observations_are_not_invariant_to_the_player_trained_in_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=True) + envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True) for env_i, env in enumerate(envs): obs = env.reset() assert_obs_is_not_symmetrical(obs, env) step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) for _ in range(n_steps): step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} + actions = { + "player_red": p_red_act[step_i - 1], + "player_blue": p_blue_act[step_i - 1], + } _, _, _, _ = env.step(actions) if step_i == max_steps: break - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) def assert_obs_is_not_symmetrical(obs, env): - assert np.all(obs[env.players_ids[0]] == - obs[env.players_ids[1]]) + assert np.all(obs[env.players_ids[0]] == obs[env.players_ids[1]]) def 
test_observations_are_not_invariant_to_the_player_trained_in_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=True) + envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True) for env_i, env in enumerate(envs): _ = env.reset() step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) for _ in range(n_steps): step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} + actions = { + "player_red": p_red_act[step_i - 1], + "player_blue": p_blue_act[step_i - 1], + } obs, reward, done, info = env.step(actions) # assert observations are symmetrical respective to the actions if step_i % 2 == 1: obs_step_odd = obs elif step_i % 2 == 0: - assert np.any(obs[env.players_ids[0]] != - obs_step_odd[env.players_ids[1]]) - assert np.any(obs[env.players_ids[1]] != - obs_step_odd[env.players_ids[0]]) + assert np.any( + obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]] + ) + assert np.any( + obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]] + ) assert_obs_is_not_symmetrical(obs, env) if step_i == max_steps: break - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) + overwrite_pos( + env, + p_red_pos[step_i], + p_blue_pos[step_i], + c_red_pos[step_i], + c_blue_pos[step_i], + ) @flaky(max_runs=4, min_passes=1) def test_who_pick_is_random(): - size = 100 + size = 1000 p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size p_red_act = [0, 0, 0, 0] * size @@ -706,29 +943,45 @@ def test_who_pick_is_random(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] * size max_steps, grid_size = int(4 * size), 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0) - - envs = init_several_env(max_steps, grid_size, - players_can_pick_same_coin=False) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, 
p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=1.0, + blue_speed=1.0, + red_own=0.0, + blue_own=1.0, + ) + + envs = init_my_envs(max_steps, grid_size, players_can_pick_same_coin=False) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + grid_size=grid_size, + red_speed=0.5, + blue_speed=0.5, + red_own=0.0, + blue_own=1.0, + delta_err=0.05, + ) diff --git a/tests/marltoolbox/envs/test_mixed_motive_coin_game.py b/tests/marltoolbox/envs/test_mixed_motive_coin_game.py deleted file mode 100644 index 6cd0ecf..0000000 --- a/tests/marltoolbox/envs/test_mixed_motive_coin_game.py +++ /dev/null @@ -1,606 +0,0 @@ -import random - -import numpy as np - -from marltoolbox.envs.mixed_motive_coin_game import MixedMotiveCoinGame - - -# TODO add tests for grid_size != 3 - -def test_reset(): - max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - -def init_several_env(max_steps, grid_size, - players_can_pick_same_coin=True, - same_obs_for_each_player=True): - mixed_motive_coin_game = init_env( - max_steps, - MixedMotiveCoinGame, - grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player) - return [mixed_motive_coin_game] - - -def init_env(max_steps, - env_class, - seed=None, - grid_size=3, - players_can_pick_same_coin=True, - same_obs_for_each_player=False): - config = { - "max_steps": max_steps, - "grid_size": grid_size, - "both_players_can_pick_the_same_coin": players_can_pick_same_coin, - "same_obs_for_each_player": same_obs_for_each_player, - } - env = env_class(config) - env.seed(seed) - return env - - -def check_obs(obs, grid_size): - assert len(obs) == 2, "two players" - for key, player_obs in obs.items(): - assert player_obs.shape == (grid_size, grid_size, 4) - assert player_obs[..., 0].sum() == 1.0, \ - f"observe 1 player red in grid: {player_obs[..., 0]}" - assert player_obs[..., 1].sum() == 1.0, \ - f"observe 1 player blue in grid: {player_obs[..., 1]}" - assert player_obs[..., 2:].sum() == 2.0, \ - f"observe 1 coin in grid: {player_obs[..., 0]}" - - -def assert_logger_buffer_size(env, n_steps): - assert len(env.red_pick) == n_steps - assert len(env.red_pick_own) == n_steps - assert len(env.blue_pick) == n_steps - assert len(env.blue_pick_own) == n_steps - - -def test_step(): - max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, 
grid_size) - assert_logger_buffer_size(env, n_steps=1) - assert not done["__all__"] - - -def test_multiple_steps(): - max_steps, grid_size = 20, 3 - n_steps = int(max_steps * 0.75) - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - for step_i in range(1, n_steps, 1): - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] - - -def test_multiple_episodes(): - max_steps, grid_size = 20, 3 - n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = {policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or \ - (step_i == max_steps and done["__all__"]) - if done["__all__"]: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - -def overwrite_pos(env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos): - env.red_pos = p_red_pos - env.blue_pos = p_blue_pos - env.red_coin_pos = c_red_pos - env.blue_coin_pos = c_blue_pos - - env.red_pos = np.array(env.red_pos) - env.blue_pos = np.array(env.blue_pos) - env.red_coin_pos = np.array(env.red_coin_pos) - env.blue_coin_pos = np.array(env.blue_coin_pos) - - -def assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed, blue_speed, red_own, blue_own): - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or (step_i == max_steps and done["__all__"]) - - if done["__all__"]: - assert info["player_red"]["pick_speed"] == red_speed - assert info["player_blue"]["pick_speed"] == blue_speed - - if red_own is None: - assert "pick_own_color" not in info["player_red"] - else: - assert info["player_red"]["pick_own_color"] == red_own - if blue_own is None: - assert "pick_own_color" not in info["player_blue"] - else: - assert info["player_blue"]["pick_own_color"] == blue_own - - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - -def test_logged_info_no_picking(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - 
p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__red_pick_red_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__blue_pick_red_all_the_time(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__blue_pick_blue_all_the_time(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__red_pick_blue_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__both_pick_blue_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], 
[0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0) - - -def test_logged_info__both_pick_red_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0) - - -def test_logged_info__both_pick_red_half_the_time(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__both_pick_blue_half_the_time(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__both_pick_blue(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], 
c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_logged_info__pick_half_the_time_half_blue_half_red(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [1, 1], [0, 0]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]] - max_steps, grid_size = 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0]) - - assert_info(n_steps, p_red_act, p_blue_act, env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - -def test_observations_are_invariant_to_the_player_trained_in_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0], - [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2], - [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]] - max_steps, grid_size = 10, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=False) - - for env_i, env in enumerate(envs): - obs = env.reset() - assert_obs_is_symmetrical(obs, env) - step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - _, _, _, _ = env.step(actions) - - if step_i == max_steps: - break - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - -def assert_obs_is_symmetrical(obs, env): - assert np.all(obs[env.players_ids[0]][..., 0] == - obs[env.players_ids[1]][..., 1]) - assert np.all(obs[env.players_ids[1]][..., 0] == - obs[env.players_ids[0]][..., 1]) - assert np.all(obs[env.players_ids[0]][..., 2] == - obs[env.players_ids[1]][..., 3]) - assert np.all(obs[env.players_ids[1]][..., 2] == - obs[env.players_ids[0]][..., 3]) - - -def test_observations_are_invariant_to_the_player_trained_in_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0], - [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2], - [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]] - max_steps, grid_size = 10, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=False) - - for env_i, env in enumerate(envs): - _ = env.reset() - step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": 
p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - obs, reward, done, info = env.step(actions) - - # assert observations are symmetrical respective to the actions - if step_i % 2 == 1: - obs_step_odd = obs - elif step_i % 2 == 0: - assert np.all(obs[env.players_ids[0]] == - obs_step_odd[env.players_ids[1]]) - assert np.all(obs[env.players_ids[1]] == - obs_step_odd[env.players_ids[0]]) - assert_obs_is_symmetrical(obs, env) - - if step_i == max_steps: - break - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - -def test_observations_are_not_invariant_to_the_player_trained_in_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0], - [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2], - [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]] - max_steps, grid_size = 10, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=True) - - for env_i, env in enumerate(envs): - obs = env.reset() - assert_obs_is_not_symmetrical(obs, env) - step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - _, _, _, _ = env.step(actions) - - if step_i == max_steps: - break - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - -def assert_obs_is_not_symmetrical(obs, env): - assert np.all(obs[env.players_ids[0]] == - obs[env.players_ids[1]]) - - -def test_observations_are_not_invariant_to_the_player_trained_in_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [0, 1], [0, 0], [0, 0], - [2, 2], [0, 0], [0, 0], [0, 0], [2, 1]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [0, 1], [2, 2], - [0, 0], [0, 0], [0, 0], [2, 1], [0, 0]] - max_steps, grid_size = 10, 3 - n_steps = max_steps - envs = init_several_env(max_steps, grid_size, - same_obs_for_each_player=True) - - for env_i, env in enumerate(envs): - _ = env.reset() - step_i = 0 - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) - - for _ in range(n_steps): - step_i += 1 - actions = {"player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1]} - obs, reward, done, info = env.step(actions) - - # assert observations are symmetrical respective to the actions - if step_i % 2 == 1: - obs_step_odd = obs - elif step_i % 2 == 0: - assert np.any(obs[env.players_ids[0]] != - obs_step_odd[env.players_ids[1]]) - assert np.any(obs[env.players_ids[1]] != - obs_step_odd[env.players_ids[0]]) - assert_obs_is_not_symmetrical(obs, env) - - if step_i == max_steps: - break - - overwrite_pos(env, p_red_pos[step_i], p_blue_pos[step_i], - c_red_pos[step_i], c_blue_pos[step_i]) diff --git a/tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py 
b/tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py deleted file mode 100644 index 6dc6107..0000000 --- a/tests/marltoolbox/envs/test_mixed_motive_vectorized_coin_game.py +++ /dev/null @@ -1,1424 +0,0 @@ -import copy -import random - -import numpy as np - -from marltoolbox.envs.vectorized_mixed_motive_coin_game import ( - VectMixedMotiveCG, -) -from test_coin_game import ( - assert_obs_is_symmetrical, - assert_obs_is_not_symmetrical, -) - - -# TODO add tests for grid_size != 3 - - -def test_reset(): - max_steps, batch_size, grid_size = 20, 5, 3 - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - -def init_several_env( - max_steps, - batch_size, - grid_size, - players_can_pick_same_coin=True, - same_obs_for_each_player=True, -): - mixed_motive_coin_game = init_env( - max_steps, - batch_size, - VectMixedMotiveCG, - grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player, - ) - return [mixed_motive_coin_game] - - -def init_env( - max_steps, - batch_size, - env_class, - seed=None, - grid_size=3, - players_can_pick_same_coin=True, - same_obs_for_each_player=False, -): - config = { - "max_steps": max_steps, - "batch_size": batch_size, - "grid_size": grid_size, - "same_obs_for_each_player": same_obs_for_each_player, - "both_players_can_pick_the_same_coin": players_can_pick_same_coin, - } - env = env_class(config) - env.seed(seed) - return env - - -def check_obs(obs, batch_size, grid_size): - assert len(obs) == 2, "two players" - for i in range(batch_size): - for key, player_obs in obs.items(): - assert player_obs.shape == (batch_size, grid_size, grid_size, 4) - assert ( - player_obs[i, ..., 0].sum() == 1.0 - ), f"observe 1 player red in grid: {player_obs[i, ..., 0]}" - assert ( - player_obs[i, ..., 1].sum() == 1.0 - ), f"observe 1 player blue in grid: {player_obs[i, ..., 1]}" - assert ( - player_obs[i, ..., 2:].sum() == 2.0 - ), f"observe 1 coin in grid: {player_obs[i, ..., 0]}" - - -def assert_logger_buffer_size(env, n_steps): - assert len(env.red_pick) == n_steps - assert len(env.red_pick_own) == n_steps - assert len(env.blue_pick) == n_steps - assert len(env.blue_pick_own) == n_steps - - -def test_step(): - max_steps, batch_size, grid_size = 20, 5, 3 - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - actions = { - policy_id: [ - random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size) - ] - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=1) - assert not done["__all__"] - - -def test_multiple_steps(): - max_steps, batch_size, grid_size = 20, 5, 3 - n_steps = int(max_steps * 0.75) - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - for step_i in range(1, n_steps, 1): - actions = { - policy_id: [ - random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size) - ] - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] - - -def 
test_multiple_episodes(): - max_steps, batch_size, grid_size = 20, 100, 3 - n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = { - policy_id: [ - random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size) - ] - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or ( - step_i == max_steps and done["__all__"] - ) - if done["__all__"]: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - -def overwrite_pos( - step_i, - batch_deltas, - n_steps_in_epi, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, -): - assert len(p_red_pos) == n_steps_in_epi - assert len(p_blue_pos) == n_steps_in_epi - assert len(c_red_pos) == n_steps_in_epi - assert len(c_blue_pos) == n_steps_in_epi - - env.red_pos = [ - p_red_pos[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas - ] - env.blue_pos = [ - p_blue_pos[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas - ] - env.red_coin_pos = [ - c_red_pos[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas - ] - env.blue_coin_pos = [ - c_blue_pos[(step_i + delta) % n_steps_in_epi] for delta in batch_deltas - ] - - env.red_pos = np.array(env.red_pos) - env.blue_pos = np.array(env.blue_pos) - env.red_coin_pos = np.array(env.red_coin_pos) - env.blue_coin_pos = np.array(env.blue_coin_pos) - - -def assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - n_steps_in_epi, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed, - blue_speed, - red_own, - blue_own, -): - step_i = 0 - - for _ in range(n_steps): - overwrite_pos( - step_i, - batch_deltas, - n_steps_in_epi, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - ) - actions = { - "player_red": [ - p_red_act[(step_i + delta) % n_steps_in_epi] - for delta in batch_deltas - ], - "player_blue": [ - p_blue_act[(step_i + delta) % n_steps_in_epi] - for delta in batch_deltas - ], - } - step_i += 1 - - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or ( - step_i == n_steps_in_epi and done["__all__"] - ) - - if done["__all__"]: - assert info["player_red"]["pick_speed"] == red_speed - assert info["player_blue"]["pick_speed"] == blue_speed - - if red_own is None: - assert "pick_own_color" not in info["player_red"] - else: - assert info["player_red"]["pick_own_color"] == red_own - if blue_own is None: - assert "pick_own_color" not in info["player_blue"] - else: - assert info["player_blue"]["pick_own_color"] == blue_own - - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - -def test_logged_info_no_picking(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = 
np.random.randint(0, max_steps - 1, size=batch_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__red_pick_red_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__blue_pick_red_all_the_time(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__blue_pick_blue_all_the_time(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__red_pick_blue_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = 
np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[0, 0], [0, 0], [0, 1], [0, 1]] - p_red_act = [0, 1, 2, 3] - p_blue_act = [0, 1, 2, 3] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 2], [2, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 4, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__both_pick_blue_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=1.0, - blue_speed=1.0, - red_own=0.0, - blue_own=1.0, - ) - - -def test_logged_info__both_pick_red_all_the_time(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=1.0, - blue_speed=1.0, - red_own=1.0, - blue_own=0.0, - ) - - -def test_logged_info__both_pick_red_half_the_time(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, 
grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__both_pick_blue_half_the_time(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__both_pick_blue(): - p_red_pos = [[0, 0], [0, 0], [0, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__pick_half_the_time_half_blue_half_red(): - p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [0, 0], [1, 1], [0, 0]] - c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - ) - - -def test_logged_info__pick_slowly_red_coin(): - p_red_pos = [[1, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, 
grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.25, - blue_speed=0.25, - red_own=1.0, - blue_own=0.0, - ) - - -def test_logged_info__pick_slowly_blue_coin(): - p_red_pos = [[1, 0], [0, 0], [0, 0], [0, 0]] - p_blue_pos = [[1, 0], [0, 0], [0, 0], [0, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.25, - blue_speed=0.25, - red_own=0.0, - blue_own=1.0, - ) - - -def test_logged_info__pick_quickly_red_coin(): - p_red_pos = [[1, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [0, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.75, - blue_speed=0.75, - red_own=1.0, - blue_own=0.0, - ) - - -def test_logged_info__pick_quickly_blue_coin(): - p_red_pos = [[1, 0], [0, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [0, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] - c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.75, - blue_speed=0.75, - red_own=0.0, - blue_own=1.0, - ) - - -def test_logged_info__pick_slowly_mixed_coin(): - p_red_pos = [[1, 0], [0, 0], [0, 0], [1, 0]] - p_blue_pos = [[1, 0], [0, 0], [0, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [0, 0], [0, 0]] - c_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - 
batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.50, - blue_speed=0.50, - red_own=0.5, - blue_own=0.5, - ) - - -def test_logged_info__pick_quickly_mixed_coin(): - p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] - p_red_act = [0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0] - c_red_pos = [[1, 1], [1, 1], [0, 0], [0, 0]] - c_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1]] - max_steps, batch_size, grid_size = 4, 28, 3 - n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info( - batch_deltas, - n_steps, - batch_size, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=1.0, - blue_speed=1.0, - red_own=0.5, - blue_own=0.5, - ) - - -def test_get_and_set_env_state(): - max_steps, batch_size, grid_size = 20, 100, 3 - n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - initial_env_state = env._save_env() - initial_env_state_saved = copy.deepcopy(initial_env_state) - env_initial = copy.deepcopy(env) - - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = { - policy_id: [ - random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size) - ] - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - - assert all( - [ - v == initial_env_state_saved[k] - if not isinstance(v, np.ndarray) - else (v == initial_env_state_saved[k]).all() - for k, v in initial_env_state.items() - ] - ) - env_state_after_step = env._save_env() - env_after_step = copy.deepcopy(env) - - env._load_env(initial_env_state) - env_vars, env_initial_vars = vars(env), vars(env_initial) - env_vars.pop("np_random", None) - env_initial_vars.pop("np_random", None) - assert all( - [ - v == env_initial_vars[k] - if not isinstance(v, np.ndarray) - else (v == env_initial_vars[k]).all() - for k, v in env_vars.items() - ] - ) - - env._load_env(env_state_after_step) - env_vars, env_after_step_vars = vars(env), vars(env_after_step) - env_vars.pop("np_random", None) - env_after_step_vars.pop("np_random", None) - assert all( - [ - v == env_after_step_vars[k] - if not isinstance(v, np.ndarray) - else (v == env_after_step_vars[k]).all() - for k, v in env_vars.items() - ] - ) - - if done["__all__"]: - obs = env.reset() - step_i = 0 - - -def test_observations_are_invariant_to_the_player_trained_wt_step(): - p_red_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [0, 0], - [1, 1], - [2, 0], - [0, 1], - [2, 2], - [1, 2], - ] - p_blue_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [1, 1], - [0, 0], - [0, 1], - [2, 0], - [1, 2], - [2, 2], - ] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [ - [1, 1], - [0, 0], - [0, 1], - [0, 0], - [0, 0], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - ] - c_blue_pos = [ - [0, 0], - [1, 1], - [0, 0], - [0, 1], - [2, 2], - [0, 0], - [0, 0], 
- [0, 0], - [2, 1], - [0, 0], - ] - max_steps, batch_size, grid_size = 10, 52, 3 - n_steps = max_steps - envs = init_several_env( - max_steps, batch_size, grid_size, same_obs_for_each_player=False - ) - - batch_deltas = [ - i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size) - ] - - for env_i, env in enumerate(envs): - _ = env.reset() - step_i = 0 - - for _ in range(n_steps): - overwrite_pos( - step_i, - batch_deltas, - max_steps, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - ) - actions = { - "player_red": [ - p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - "player_blue": [ - p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - } - obs, reward, done, info = env.step(actions) - - step_i += 1 - # assert that observations are symmetrical respective to the actions - if step_i % 2 == 1: - obs_step_odd = obs - elif step_i % 2 == 0: - assert np.all( - obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]] - ) - assert np.all( - obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]] - ) - assert_obs_is_symmetrical(obs, env) - - if step_i == max_steps: - break - - -def test_observations_are_invariant_to_the_player_trained_wt_reset(): - p_red_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [0, 0], - [1, 1], - [2, 0], - [0, 1], - [2, 2], - [1, 2], - ] - p_blue_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [1, 1], - [0, 0], - [0, 1], - [2, 0], - [1, 2], - [2, 2], - ] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [ - [1, 1], - [0, 0], - [0, 1], - [0, 0], - [0, 0], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - ] - c_blue_pos = [ - [0, 0], - [1, 1], - [0, 0], - [0, 1], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - [0, 0], - ] - max_steps, batch_size, grid_size = 10, 52, 3 - n_steps = max_steps - envs = init_several_env( - max_steps, batch_size, grid_size, same_obs_for_each_player=False - ) - - batch_deltas = [ - i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size) - ] - - for env_i, env in enumerate(envs): - obs = env.reset() - assert_obs_is_symmetrical(obs, env) - step_i = 0 - - for _ in range(n_steps): - overwrite_pos( - step_i, - batch_deltas, - max_steps, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - ) - actions = { - "player_red": [ - p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - "player_blue": [ - p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - } - _, _, _, _ = env.step(actions) - - step_i += 1 - - if step_i == max_steps: - break - - -def test_observations_are_not_invariant_to_the_player_trained_wt_step(): - p_red_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [0, 0], - [1, 1], - [2, 0], - [0, 1], - [2, 2], - [1, 2], - ] - p_blue_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [1, 1], - [0, 0], - [0, 1], - [2, 0], - [1, 2], - [2, 2], - ] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [ - [1, 1], - [0, 0], - [0, 1], - [0, 0], - [0, 0], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - ] - c_blue_pos = [ - [0, 0], - [1, 1], - [0, 0], - [0, 1], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - [0, 0], - ] - max_steps, batch_size, grid_size = 10, 52, 3 - n_steps = max_steps - envs = init_several_env( - max_steps, batch_size, grid_size, same_obs_for_each_player=True - ) - - batch_deltas = [ - i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in 
range(batch_size) - ] - - for env_i, env in enumerate(envs): - _ = env.reset() - step_i = 0 - - for _ in range(n_steps): - overwrite_pos( - step_i, - batch_deltas, - max_steps, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - ) - actions = { - "player_red": [ - p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - "player_blue": [ - p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - } - obs, reward, done, info = env.step(actions) - - step_i += 1 - # assert that observations are not - # symmetrical respective to the - # actions - if step_i % 2 == 1: - obs_step_odd = obs - elif step_i % 2 == 0: - assert np.any( - obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]] - ) - assert np.any( - obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]] - ) - assert_obs_is_not_symmetrical(obs, env) - - if step_i == max_steps: - break - - -def test_observations_are_not_invariant_to_the_player_trained_wt_reset(): - p_red_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [0, 0], - [1, 1], - [2, 0], - [0, 1], - [2, 2], - [1, 2], - ] - p_blue_pos = [ - [0, 0], - [0, 0], - [1, 1], - [1, 1], - [1, 1], - [0, 0], - [0, 1], - [2, 0], - [1, 2], - [2, 2], - ] - p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [ - [1, 1], - [0, 0], - [0, 1], - [0, 0], - [0, 0], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - ] - c_blue_pos = [ - [0, 0], - [1, 1], - [0, 0], - [0, 1], - [2, 2], - [0, 0], - [0, 0], - [0, 0], - [2, 1], - [0, 0], - ] - max_steps, batch_size, grid_size = 10, 52, 3 - n_steps = max_steps - envs = init_several_env( - max_steps, batch_size, grid_size, same_obs_for_each_player=True - ) - - batch_deltas = [ - i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size) - ] - - for env_i, env in enumerate(envs): - obs = env.reset() - assert_obs_is_not_symmetrical(obs, env) - step_i = 0 - - for _ in range(n_steps): - overwrite_pos( - step_i, - batch_deltas, - max_steps, - env, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - ) - actions = { - "player_red": [ - p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - "player_blue": [ - p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas - ], - } - _, _, _, _ = env.step(actions) - - step_i += 1 - - if step_i == max_steps: - break diff --git a/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py b/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py index 6d0c04f..63697b7 100644 --- a/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py +++ b/tests/marltoolbox/envs/test_ssd_mixed_motive_coin_game.py @@ -3,152 +3,86 @@ import numpy as np from marltoolbox.envs.ssd_mixed_motive_coin_game import SSDMixedMotiveCoinGame +from coin_game_tests_utils import ( + check_custom_obs, + assert_logger_buffer_size, + helper_test_reset, + helper_test_step, + init_several_envs, + helper_test_multiple_steps, + helper_test_multi_ple_episodes, + helper_assert_info, +) # TODO add tests for grid_size != 3 -def test_reset(): - max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - -def init_several_env( +def init_my_envs( max_steps, grid_size, players_can_pick_same_coin=True, same_obs_for_each_player=True, ): - mixed_motive_coin_game = init_env( - max_steps, - SSDMixedMotiveCoinGame, - grid_size, + return init_several_envs( + (SSDMixedMotiveCoinGame,), 
+ max_steps=max_steps, + grid_size=grid_size, players_can_pick_same_coin=players_can_pick_same_coin, same_obs_for_each_player=same_obs_for_each_player, ) - return [mixed_motive_coin_game] - - -def init_env( - max_steps, - env_class, - seed=None, - grid_size=3, - players_can_pick_same_coin=True, - same_obs_for_each_player=False, -): - config = { - "max_steps": max_steps, - "grid_size": grid_size, - "both_players_can_pick_the_same_coin": players_can_pick_same_coin, - "same_obs_for_each_player": same_obs_for_each_player, - } - env = env_class(config) - env.seed(seed) - return env def check_obs(obs, grid_size): - assert len(obs) == 2, "two players" - for key, player_obs in obs.items(): - assert player_obs.shape == (grid_size, grid_size, 6) - assert ( - player_obs[..., 0].sum() == 1.0 - ), f"observe 1 player red in grid: {player_obs[..., 0]}" - assert ( - player_obs[..., 1].sum() == 1.0 - ), f"observe 1 player blue in grid: {player_obs[..., 1]}" - assert ( - player_obs[..., 2:].sum() == 2.0 - ), f"observe 1 coin in grid: {player_obs[..., 0]}" - - -def assert_logger_buffer_size(env, n_steps): - assert len(env.red_pick) == n_steps - assert len(env.red_pick_own) == n_steps - assert len(env.blue_pick) == n_steps - assert len(env.blue_pick_own) == n_steps + check_custom_obs(obs, grid_size, n_in_2_and_above=2.0, n_layers=6) -def test_step(): +def test_reset(): max_steps, grid_size = 20, 3 - envs = init_several_env(max_steps, grid_size) + envs = init_my_envs(max_steps, grid_size) + helper_test_reset(envs, check_obs, grid_size=grid_size) - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - actions = { - policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=1) - assert not done["__all__"] +def test_step(): + max_steps, grid_size = 20, 3 + envs = init_my_envs(max_steps, grid_size) + helper_test_step(envs, check_obs, grid_size=grid_size) def test_multiple_steps(): max_steps, grid_size = 20, 3 n_steps = int(max_steps * 0.75) - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - for step_i in range(1, n_steps, 1): - actions = { - policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] + envs = init_my_envs(max_steps, grid_size) + helper_test_multiple_steps( + envs, + n_steps, + check_obs, + grid_size=grid_size, + ) def test_multiple_episodes(): max_steps, grid_size = 20, 3 n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = { - policy_id: random.randint(0, env.NUM_ACTIONS - 1) - for policy_id in env.players_ids - } - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or ( - step_i == max_steps and done["__all__"] - ) - if done["__all__"]: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 + envs = init_my_envs(max_steps, grid_size) + 
helper_test_multi_ple_episodes( + envs, + n_steps, + max_steps, + check_obs, + grid_size=grid_size, + ) def overwrite_pos( - env, p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, c_red_coin=True + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + **kwargs, ): - env.red_coin = c_red_coin + env.red_coin = True env.red_pos = p_red_pos env.blue_pos = p_blue_pos env.red_coin_pos = c_red_pos @@ -160,75 +94,6 @@ def overwrite_pos( env.blue_coin_pos = np.array(env.blue_coin_pos) -def assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed, - blue_speed, - red_own, - blue_own, - blue_coop_fraction=None, - red_coop_fraction=None, - red_coop_speed=None, - blue_coop_speed=None, -): - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = { - "player_red": p_red_act[step_i - 1], - "player_blue": p_blue_act[step_i - 1], - } - obs, reward, done, info = env.step(actions) - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or (step_i == max_steps and done["__all__"]) - - if done["__all__"]: - assert info["player_red"]["pick_speed"] == red_speed - assert info["player_blue"]["pick_speed"] == blue_speed - - assert_not_present_in_dict_or_equal( - "pick_own_color", red_own, info, "player_red" - ) - assert_not_present_in_dict_or_equal( - "pick_own_color", blue_own, info, "player_blue" - ) - assert_not_present_in_dict_or_equal( - "blue_coop_fraction", blue_coop_fraction, info, "player_blue" - ) - assert_not_present_in_dict_or_equal( - "red_coop_fraction", red_coop_fraction, info, "player_red" - ) - assert_not_present_in_dict_or_equal( - "red_coop_speed", red_coop_speed, info, "player_red" - ) - assert_not_present_in_dict_or_equal( - "blue_coop_speed", blue_coop_speed, info, "player_blue" - ) - - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - overwrite_pos( - env, - p_red_pos[step_i], - p_blue_pos[step_i], - c_red_pos[step_i], - c_blue_pos[step_i], - ) - - def assert_not_present_in_dict_or_equal(key, value, info, player): if value is None: assert key not in info[player] @@ -245,36 +110,30 @@ def test_logged_info_no_picking(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=None, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=None, + ) def test_logged_info__red_pick_red_all_the_time(): @@ -286,36 +145,30 @@ def 
test_logged_info__red_pick_red_all_the_time(): c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=None, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=None, + ) def test_logged_info__blue_pick_red_all_the_time(): @@ -327,36 +180,30 @@ def test_logged_info__blue_pick_red_all_the_time(): c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=None, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=None, + ) def test_logged_info__blue_pick_blue_all_the_time(): @@ -368,36 +215,30 @@ def test_logged_info__blue_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=1.0, - red_own=None, - blue_own=1.0, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=0.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + 
p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=0.0, + ) def test_logged_info__red_pick_blue_all_the_time(): @@ -409,36 +250,30 @@ def test_logged_info__red_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=None, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=None, + ) def test_logged_info__both_pick_blue_all_the_time(): @@ -450,36 +285,30 @@ def test_logged_info__both_pick_blue_all_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=1.0, - red_own=None, - blue_own=1.0, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=0.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=0.0, + ) def test_logged_info__both_pick_red_all_the_time(): @@ -491,36 +320,30 @@ def test_logged_info__both_pick_red_all_the_time(): c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - 
c_red_pos, - c_blue_pos, - red_speed=1.0, - blue_speed=1.0, - red_own=1.0, - blue_own=0.0, - red_coop_speed=1.0, - blue_coop_speed=0.0, - red_coop_fraction=1.0, - blue_coop_fraction=1.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + red_coop_speed=1.0, + blue_coop_speed=0.0, + red_coop_fraction=1.0, + blue_coop_fraction=1.0, + ) def test_logged_info__both_pick_red_half_the_time(): @@ -532,36 +355,30 @@ def test_logged_info__both_pick_red_half_the_time(): c_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.0, - red_own=None, - blue_own=None, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=None, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=None, + ) def test_logged_info__both_pick_blue_half_the_time(): @@ -573,36 +390,30 @@ def test_logged_info__both_pick_blue_half_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.5, - red_own=None, - blue_own=1.0, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=0.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.5, + red_own=None, + blue_own=1.0, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=0.0, + ) def test_logged_info__both_pick_blue(): @@ -614,36 +425,30 @@ def test_logged_info__both_pick_blue(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = 
init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.5, - red_own=None, - blue_own=1.0, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=0.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.5, + red_own=None, + blue_own=1.0, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=0.0, + ) def test_logged_info__pick_half_the_time_half_blue_half_red(): @@ -655,36 +460,30 @@ def test_logged_info__pick_half_the_time_half_blue_half_red(): c_blue_pos = [[0, 0], [1, 1], [0, 0], [1, 1]] max_steps, grid_size = 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, grid_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, grid_size) - assert_logger_buffer_size(env, n_steps=0) - overwrite_pos( - env, p_red_pos[0], p_blue_pos[0], c_red_pos[0], c_blue_pos[0] - ) - - assert_info( - n_steps, - p_red_act, - p_blue_act, - env, - grid_size, - max_steps, - p_red_pos, - p_blue_pos, - c_red_pos, - c_blue_pos, - red_speed=0.0, - blue_speed=0.25, - red_own=None, - blue_own=1.0, - red_coop_speed=0.0, - blue_coop_speed=0.0, - red_coop_fraction=None, - blue_coop_fraction=0.0, - ) + envs = init_my_envs(max_steps, grid_size) + + helper_assert_info( + n_steps=n_steps, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.25, + red_own=None, + blue_own=1.0, + red_coop_speed=0.0, + blue_coop_speed=0.0, + red_coop_fraction=None, + blue_coop_fraction=0.0, + ) def test_observations_are_not_invariant_to_the_player_trained_in_reset(): @@ -740,9 +539,7 @@ def test_observations_are_not_invariant_to_the_player_trained_in_reset(): ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env( - max_steps, grid_size, same_obs_for_each_player=True - ) + envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True) for env_i, env in enumerate(envs): obs = env.reset() @@ -833,9 +630,7 @@ def test_observations_are_not_invariant_to_the_player_trained_in_step(): ] max_steps, grid_size = 10, 3 n_steps = max_steps - envs = init_several_env( - max_steps, grid_size, same_obs_for_each_player=True - ) + envs = init_my_envs(max_steps, grid_size, same_obs_for_each_player=True) for env_i, env in enumerate(envs): _ = env.reset() diff --git a/tests/marltoolbox/envs/test_vectorized_coin_game.py b/tests/marltoolbox/envs/test_vectorized_coin_game.py index 011ab7c..080e42c 100644 --- a/tests/marltoolbox/envs/test_vectorized_coin_game.py +++ b/tests/marltoolbox/envs/test_vectorized_coin_game.py @@ -4,160 +4,128 @@ import numpy as np from flaky import flaky -from 
marltoolbox.envs.vectorized_coin_game import VectorizedCoinGame, \ - AsymVectorizedCoinGame -from test_coin_game import \ - assert_obs_is_symmetrical, assert_obs_is_not_symmetrical +from coin_game_tests_utils import ( + check_custom_obs, + helper_test_reset, + helper_test_step, + init_several_envs, + helper_test_multiple_steps, + helper_test_multi_ple_episodes, + helper_assert_info, + shift_consistently, +) +from marltoolbox.envs.vectorized_coin_game import ( + VectorizedCoinGame, + AsymVectorizedCoinGame, +) +from test_coin_game import ( + assert_obs_is_symmetrical, + assert_obs_is_not_symmetrical, +) # TODO add tests for grid_size != 3 -def test_reset(): - max_steps, batch_size, grid_size = 20, 5, 3 - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - -def init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=True, - same_obs_for_each_player=True): - coin_game = init_env(max_steps, batch_size, VectorizedCoinGame, grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player) - asymm_coin_game = \ - init_env(max_steps, batch_size, AsymVectorizedCoinGame, grid_size, - players_can_pick_same_coin=players_can_pick_same_coin, - same_obs_for_each_player=same_obs_for_each_player) - return [coin_game, asymm_coin_game] - - -def init_env(max_steps, batch_size, env_class, seed=None, grid_size=3, - players_can_pick_same_coin=True, - same_obs_for_each_player=False): - config = { - "max_steps": max_steps, - "batch_size": batch_size, - "grid_size": grid_size, - "same_obs_for_each_player": same_obs_for_each_player, - "both_players_can_pick_the_same_coin": players_can_pick_same_coin, - } - env = env_class(config) - env.seed(seed) - return env +def init_my_envs( + max_steps, + batch_size, + grid_size, + players_can_pick_same_coin=True, + same_obs_for_each_player=True, +): + return init_several_envs( + classes=(VectorizedCoinGame, AsymVectorizedCoinGame), + max_steps=max_steps, + grid_size=grid_size, + batch_size=batch_size, + players_can_pick_same_coin=players_can_pick_same_coin, + same_obs_for_each_player=same_obs_for_each_player, + ) def check_obs(obs, batch_size, grid_size): - assert len(obs) == 2, "two players" - for i in range(batch_size): - for key, player_obs in obs.items(): - assert player_obs.shape == (batch_size, grid_size, grid_size, 4) - assert player_obs[i, ..., 0].sum() == 1.0, \ - f"observe 1 player red in grid: {player_obs[i, ..., 0]}" - assert player_obs[i, ..., 1].sum() == 1.0, \ - f"observe 1 player blue in grid: {player_obs[i, ..., 1]}" - assert player_obs[i, ..., 2:].sum() == 1.0, \ - f"observe 1 coin in grid: {player_obs[i, ..., 0]}" + check_custom_obs(obs, grid_size, batch_size=batch_size) -def assert_logger_buffer_size(env, n_steps): - assert len(env.red_pick) == n_steps - assert len(env.red_pick_own) == n_steps - assert len(env.blue_pick) == n_steps - assert len(env.blue_pick_own) == n_steps +def test_reset(): + max_steps, batch_size, grid_size = 20, 5, 3 + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_reset( + envs, check_obs, grid_size=grid_size, batch_size=batch_size + ) def test_step(): max_steps, batch_size, grid_size = 20, 5, 3 - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - actions = {policy_id: 
[random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size)] - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=1) - assert not done["__all__"] + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_step( + envs, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) def test_multiple_steps(): max_steps, batch_size, grid_size = 20, 5, 3 n_steps = int(max_steps * 0.75) - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - for step_i in range(1, n_steps, 1): - actions = { - policy_id: [random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size)] - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_multiple_steps( + envs, + n_steps, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) def test_multiple_episodes(): max_steps, batch_size, grid_size = 20, 100, 3 n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, batch_size, grid_size) - - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - step_i = 0 - for _ in range(n_steps): - step_i += 1 - actions = { - policy_id: [random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size)] - for policy_id in env.players_ids} - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or ( - step_i == max_steps and done["__all__"]) - if done["__all__"]: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - -def overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env, p_red_pos, - p_blue_pos, c_red_pos, c_blue_pos): + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_multi_ple_episodes( + envs, + n_steps, + max_steps, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) + + +def overwrite_pos( + step_i, + batch_deltas, + n_steps_in_epi, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + **kwargs, +): assert len(p_red_pos) == n_steps_in_epi assert len(p_blue_pos) == n_steps_in_epi assert len(c_red_pos) == n_steps_in_epi assert len(c_blue_pos) == n_steps_in_epi - env.red_coin = [0 - if c_red_pos[(step_i + delta) % n_steps_in_epi] is None - else 1 - for delta in batch_deltas] - coin_pos = [c_blue_pos[(step_i + delta) % n_steps_in_epi] - if c_red_pos[(step_i + delta) % n_steps_in_epi] is None - else c_red_pos[(step_i + delta) % n_steps_in_epi] - for delta in batch_deltas] - - env.red_pos = [p_red_pos[(step_i + delta) % n_steps_in_epi] for delta in - batch_deltas] - env.blue_pos = [p_blue_pos[(step_i + delta) % n_steps_in_epi] for delta in - batch_deltas] + env.red_coin = [ + 0 if c_red_pos[(step_i + delta) % n_steps_in_epi] is None else 1 + for delta in batch_deltas + ] + coin_pos = [ + c_blue_pos[(step_i + delta) % n_steps_in_epi] + if c_red_pos[(step_i + delta) % n_steps_in_epi] is None + else c_red_pos[(step_i + delta) % n_steps_in_epi] + for delta in batch_deltas + ] + + env.red_pos = shift_consistently( + p_red_pos, step_i, n_steps_in_epi, 
batch_deltas + ) + env.blue_pos = shift_consistently( + p_blue_pos, step_i, n_steps_in_epi, batch_deltas + ) env.coin_pos = coin_pos env.red_pos = np.array(env.red_pos) @@ -166,52 +134,6 @@ def overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env, p_red_pos, env.red_coin = np.array(env.red_coin) -def assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, env, - grid_size, n_steps_in_epi, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed, blue_speed, red_own, blue_own): - step_i = 0 - delta_err = 0.01 - - for _ in range(n_steps): - overwrite_pos(step_i, batch_deltas, n_steps_in_epi, env, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos) - actions = {"player_red": [p_red_act[(step_i + delta) % n_steps_in_epi] - for delta in batch_deltas], - "player_blue": [ - p_blue_act[(step_i + delta) % n_steps_in_epi] - for delta in batch_deltas]} - step_i += 1 - - obs, reward, done, info = env.step(actions) - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=step_i) - assert not done["__all__"] or ( - step_i == n_steps_in_epi and done["__all__"]) - - if done["__all__"]: - assert abs(info["player_red"]["pick_speed"] - red_speed) \ - < delta_err - assert abs(info["player_blue"]["pick_speed"] - blue_speed) \ - < delta_err - - if red_own is None: - assert "pick_own_color" not in info["player_red"] - else: - assert abs(info["player_red"]["pick_own_color"] - red_own) \ - < delta_err - if blue_own is None: - assert "pick_own_color" not in info["player_blue"] - else: - assert abs(info["player_blue"]["pick_own_color"] - blue_own) \ - < delta_err - - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - step_i = 0 - - def test_logged_info_no_picking(): p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] @@ -222,30 +144,51 @@ def test_logged_info_no_picking(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env in envs: - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=0.0, red_own=None, blue_own=None) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + 
batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) def test_logged_info__red_pick_red_all_the_time(): @@ -258,30 +201,51 @@ def test_logged_info__red_pick_red_all_the_time(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=1.0, blue_own=None) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=1.0, + blue_own=None, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=1.0, + blue_own=None, + ) def test_logged_info__blue_pick_red_all_the_time(): @@ -294,30 +258,51 @@ def test_logged_info__blue_pick_red_all_the_time(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, 
p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=0.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=0.0, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=0.0, + ) def test_logged_info__blue_pick_blue_all_the_time(): @@ -330,30 +315,51 @@ def test_logged_info__blue_pick_blue_all_the_time(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.0, blue_speed=1.0, red_own=None, blue_own=1.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + ) def test_logged_info__red_pick_blue_all_the_time(): @@ -366,30 +372,51 @@ def test_logged_info__red_pick_blue_all_the_time(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, 
max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + blue_own=None, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + blue_own=None, + ) def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions(): @@ -402,30 +429,51 @@ def test_logged_info__red_pick_blue_all_the_time_wt_difference_in_actions(): max_steps, batch_size, grid_size = 4, 4, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=0.0, red_own=0.0, blue_own=None) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + 
blue_own=None, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=0.0, + blue_own=None, + ) def test_logged_info__both_pick_blue_all_the_time(): @@ -438,17 +486,27 @@ def test_logged_info__both_pick_blue_all_the_time(): max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__both_pick_red_all_the_time(): @@ -460,19 +518,27 @@ def test_logged_info__both_pick_red_all_the_time(): c_blue_pos = [None, None, None, None] max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + ) def test_logged_info__both_pick_red_half_the_time(): @@ -484,19 +550,27 @@ def test_logged_info__both_pick_red_half_the_time(): c_blue_pos = [None, None, None, None] max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0) + envs = 
init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.5, + blue_speed=0.5, + red_own=1.0, + blue_own=0.0, + ) def test_logged_info__both_pick_blue_half_the_time(): @@ -508,19 +582,27 @@ def test_logged_info__both_pick_blue_half_the_time(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.5, + blue_speed=0.5, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__both_pick_blue(): @@ -532,19 +614,27 @@ def test_logged_info__both_pick_blue(): c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.25, blue_speed=0.5, red_own=0.0, blue_own=1.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.25, + blue_speed=0.5, + red_own=0.0, + blue_own=1.0, + ) def test_logged_info__pick_half_the_time_half_blue_half_red(): @@ -556,25 +646,33 @@ def test_logged_info__pick_half_the_time_half_blue_half_red(): c_blue_pos = [None, [1, 1], None, [1, 1]] max_steps, batch_size, grid_size = 4, 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=0.5, blue_own=0.5) + envs = 
init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.5, + blue_speed=0.5, + red_own=0.5, + blue_own=0.5, + ) def test_get_and_set_env_state(): max_steps, batch_size, grid_size = 20, 100, 3 n_steps = int(max_steps * 8.25) - envs = init_several_env(max_steps, batch_size, grid_size) + envs = init_my_envs(max_steps, batch_size, grid_size) for env in envs: obs = env.reset() @@ -585,15 +683,23 @@ def test_get_and_set_env_state(): step_i = 0 for _ in range(n_steps): step_i += 1 - actions = {policy_id: [random.randint(0, env.NUM_ACTIONS - 1) - for _ in range(batch_size)] - for policy_id in env.players_ids} + actions = { + policy_id: [ + random.randint(0, env.NUM_ACTIONS - 1) + for _ in range(batch_size) + ] + for policy_id in env.players_ids + } obs, reward, done, info = env.step(actions) - assert all([v == initial_env_state_saved[k] - if not isinstance(v, np.ndarray) - else (v == initial_env_state_saved[k]).all() - for k, v in initial_env_state.items()]) + assert all( + [ + v == initial_env_state_saved[k] + if not isinstance(v, np.ndarray) + else (v == initial_env_state_saved[k]).all() + for k, v in initial_env_state.items() + ] + ) env_state_after_step = env._save_env() env_after_step = copy.deepcopy(env) @@ -601,19 +707,27 @@ def test_get_and_set_env_state(): env_vars, env_initial_vars = vars(env), vars(env_initial) env_vars.pop("np_random", None) env_initial_vars.pop("np_random", None) - assert all([v == env_initial_vars[k] - if not isinstance(v, np.ndarray) - else (v == env_initial_vars[k]).all() - for k, v in env_vars.items()]) + assert all( + [ + v == env_initial_vars[k] + if not isinstance(v, np.ndarray) + else (v == env_initial_vars[k]).all() + for k, v in env_vars.items() + ] + ) env._load_env(env_state_after_step) env_vars, env_after_step_vars = vars(env), vars(env_after_step) env_vars.pop("np_random", None) env_after_step_vars.pop("np_random", None) - assert all([v == env_after_step_vars[k] - if not isinstance(v, np.ndarray) - else (v == env_after_step_vars[k]).all() - for k, v in env_vars.items()]) + assert all( + [ + v == env_after_step_vars[k] + if not isinstance(v, np.ndarray) + else (v == env_after_step_vars[k]).all() + for k, v in env_vars.items() + ] + ) if done["__all__"]: obs = env.reset() @@ -621,36 +735,92 @@ def test_get_and_set_env_state(): def test_observations_are_invariant_to_the_player_trained_wt_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + 
None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, batch_size, grid_size = 10, 52, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size, - same_obs_for_each_player=False) + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=False + ) - batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size)] + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] for env_i, env in enumerate(envs): _ = env.reset() step_i = 0 for _ in range(n_steps): - overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos, - p_blue_pos, - c_red_pos, c_blue_pos) - actions = {"player_red": [p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas], - "player_blue": [p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas]} + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } obs, reward, done, info = env.step(actions) step_i += 1 @@ -659,11 +829,11 @@ def test_observations_are_invariant_to_the_player_trained_wt_step(): obs_step_odd = obs elif step_i % 2 == 0: assert np.all( - obs[env.players_ids[0]] == obs_step_odd[ - env.players_ids[1]]) + obs[env.players_ids[0]] == obs_step_odd[env.players_ids[1]] + ) assert np.all( - obs[env.players_ids[1]] == obs_step_odd[ - env.players_ids[0]]) + obs[env.players_ids[1]] == obs_step_odd[env.players_ids[0]] + ) assert_obs_is_symmetrical(obs, env) if step_i == max_steps: @@ -671,23 +841,66 @@ def test_observations_are_invariant_to_the_player_trained_wt_step(): def test_observations_are_invariant_to_the_player_trained_wt_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, batch_size, grid_size = 10, 52, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size, - same_obs_for_each_player=False) + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=False + ) - batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size)] + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] for env_i, env in enumerate(envs): obs = env.reset() @@ -695,12 +908,26 @@ def test_observations_are_invariant_to_the_player_trained_wt_reset(): step_i = 0 for _ in range(n_steps): - overwrite_pos(step_i, batch_deltas, max_steps, env, 
p_red_pos, - p_blue_pos, c_red_pos, c_blue_pos) - actions = {"player_red": [p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas], - "player_blue": [p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas]} + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } _, _, _, _ = env.step(actions) step_i += 1 @@ -710,36 +937,92 @@ def test_observations_are_invariant_to_the_player_trained_wt_reset(): def test_observations_are_not_invariant_to_the_player_trained_wt_step(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, batch_size, grid_size = 10, 52, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size, - same_obs_for_each_player=True) + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=True + ) - batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size)] + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] for env_i, env in enumerate(envs): _ = env.reset() step_i = 0 for _ in range(n_steps): - overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos, - p_blue_pos, - c_red_pos, c_blue_pos) - actions = {"player_red": [p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas], - "player_blue": [p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas]} + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } obs, reward, done, info = env.step(actions) step_i += 1 @@ -750,11 +1033,11 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_step(): obs_step_odd = obs elif step_i % 2 == 0: assert np.any( - obs[env.players_ids[0]] != obs_step_odd[env.players_ids[ - 1]]) + obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]] + ) assert np.any( - obs[env.players_ids[1]] != obs_step_odd[env.players_ids[ - 0]]) + obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]] + ) assert_obs_is_not_symmetrical(obs, env) if step_i == max_steps: @@ -762,23 +1045,66 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_step(): def 
test_observations_are_not_invariant_to_the_player_trained_wt_reset(): - p_red_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0], - [1, 1], [2, 0], [0, 1], [2, 2], [1, 2]] - p_blue_pos = [[0, 0], [0, 0], [1, 1], [1, 1], [1, 1], - [0, 0], [0, 1], [2, 0], [1, 2], [2, 2]] + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - c_red_pos = [[1, 1], None, [0, 1], None, None, - [2, 2], [0, 0], None, None, [2, 1]] - c_blue_pos = [None, [1, 1], None, [0, 1], [2, 2], - None, None, [0, 0], [2, 1], None] + c_red_pos = [ + [1, 1], + None, + [0, 1], + None, + None, + [2, 2], + [0, 0], + None, + None, + [2, 1], + ] + c_blue_pos = [ + None, + [1, 1], + None, + [0, 1], + [2, 2], + None, + None, + [0, 0], + [2, 1], + None, + ] max_steps, batch_size, grid_size = 10, 52, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size, - same_obs_for_each_player=True) + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=True + ) - batch_deltas = [i % max_steps if i % 2 == 0 else i % max_steps - 1 - for i in range(batch_size)] + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] for env_i, env in enumerate(envs): obs = env.reset() @@ -786,12 +1112,26 @@ def test_observations_are_not_invariant_to_the_player_trained_wt_reset(): step_i = 0 for _ in range(n_steps): - overwrite_pos(step_i, batch_deltas, max_steps, env, p_red_pos, - p_blue_pos, c_red_pos, c_blue_pos) - actions = {"player_red": [p_red_act[(step_i + delta) % max_steps] - for delta in batch_deltas], - "player_blue": [p_blue_act[(step_i + delta) % max_steps] - for delta in batch_deltas]} + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } _, _, _, _ = env.step(actions) step_i += 1 @@ -812,27 +1152,50 @@ def test_who_pick_is_random(): max_steps, batch_size, grid_size = int(4 * size), 28, 3 n_steps = max_steps - envs = init_several_env(max_steps, batch_size, grid_size) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=1.0, blue_speed=1.0, red_own=1.0, blue_own=0.0) - - envs = init_several_env(max_steps, batch_size, grid_size, - players_can_pick_same_coin=False) - batch_deltas = np.random.randint(0, max_steps - 1, size=batch_size) - for env_i, env in enumerate(envs): - obs = env.reset() - check_obs(obs, batch_size, grid_size) - assert_logger_buffer_size(env, n_steps=0) - - assert_info(batch_deltas, n_steps, batch_size, p_red_act, p_blue_act, - env, grid_size, max_steps, - p_red_pos, p_blue_pos, c_red_pos, c_blue_pos, - red_speed=0.5, blue_speed=0.5, red_own=1.0, blue_own=0.0) + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + 
p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + ) + + envs = init_my_envs( + max_steps, batch_size, grid_size, players_can_pick_same_coin=False + ) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.5, + blue_speed=0.5, + red_own=1.0, + blue_own=0.0, + repetitions=1, + delta_err=0.05, + ) diff --git a/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py b/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py new file mode 100644 index 0000000..e2654fe --- /dev/null +++ b/tests/marltoolbox/envs/test_vectorized_ssd_mm_coin_game.py @@ -0,0 +1,862 @@ +import copy +import random + +import numpy as np +from flaky import flaky + +from coin_game_tests_utils import ( + check_custom_obs, + helper_test_reset, + helper_test_step, + init_several_envs, + helper_test_multiple_steps, + helper_test_multi_ple_episodes, + helper_assert_info, + shift_consistently, +) +from marltoolbox.envs.vectorized_ssd_mm_coin_game import ( + VectSSDMixedMotiveCG, +) +from test_coin_game import ( + assert_obs_is_symmetrical, + assert_obs_is_not_symmetrical, +) + + +# TODO add tests for grid_size != 3 + + +def init_my_envs( + max_steps, + batch_size, + grid_size, + players_can_pick_same_coin=True, + same_obs_for_each_player=True, +): + return init_several_envs( + classes=(VectSSDMixedMotiveCG,), + max_steps=max_steps, + grid_size=grid_size, + batch_size=batch_size, + players_can_pick_same_coin=players_can_pick_same_coin, + same_obs_for_each_player=same_obs_for_each_player, + ) + + +def check_obs(obs, batch_size, grid_size): + check_custom_obs( + obs, grid_size, batch_size=batch_size, n_layers=6, n_in_2_and_above=2.0 + ) + + +def test_reset(): + max_steps, batch_size, grid_size = 20, 5, 3 + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_reset( + envs, check_obs, grid_size=grid_size, batch_size=batch_size + ) + + +def test_step(): + max_steps, batch_size, grid_size = 20, 5, 3 + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_step( + envs, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) + + +def test_multiple_steps(): + max_steps, batch_size, grid_size = 20, 5, 3 + n_steps = int(max_steps * 0.75) + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_multiple_steps( + envs, + n_steps, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) + + +def test_multiple_episodes(): + max_steps, batch_size, grid_size = 20, 100, 3 + n_steps = int(max_steps * 8.25) + envs = init_my_envs(max_steps, batch_size, grid_size) + helper_test_multi_ple_episodes( + envs, + n_steps, + max_steps, + check_obs, + grid_size=grid_size, + batch_size=batch_size, + ) + + +def overwrite_pos( + step_i, + batch_deltas, + n_steps_in_epi, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin, + **kwargs, +): + assert len(c_red_coin) == n_steps_in_epi + assert len(p_red_pos) == n_steps_in_epi + assert len(p_blue_pos) == n_steps_in_epi + assert len(c_red_pos) == n_steps_in_epi + assert 
len(c_blue_pos) == n_steps_in_epi + + env.red_coin = shift_consistently( + c_red_coin, step_i, n_steps_in_epi, batch_deltas + ) + env.red_pos = shift_consistently( + p_red_pos, step_i, n_steps_in_epi, batch_deltas + ) + env.blue_pos = shift_consistently( + p_blue_pos, step_i, n_steps_in_epi, batch_deltas + ) + env.red_coin_pos = shift_consistently( + c_red_pos, step_i, n_steps_in_epi, batch_deltas + ) + env.blue_coin_pos = shift_consistently( + c_blue_pos, step_i, n_steps_in_epi, batch_deltas + ) + + env.red_coin = np.array(env.red_coin, dtype=np.int8) + env.red_pos = np.array(env.red_pos) + env.blue_pos = np.array(env.blue_pos) + env.red_coin_pos = np.array(env.red_coin_pos) + env.blue_coin_pos = np.array(env.blue_coin_pos) + + +def test_logged_info_no_picking(): + p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] + p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) + + +def test_logged_info__red_pick_red_all_the_time(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [0, 0, 0, 0] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=0.0, + red_own=1.0, + blue_own=None, + ) + + +def test_logged_info__blue_pick_red_all_the_time(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + ) + + +def test_logged_info__blue_pick_blue_all_the_time(): + p_red_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_red_act = [0, 0, 0, 0] + 
p_blue_act = [0, 0, 0, 0] + c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=1.0, + red_own=None, + blue_own=1.0, + ) + + +def test_logged_info__red_cant_pick_selfish_blue(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[0, 0], [0, 0], [0, 0], [0, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) + + +def test_logged_info__both_pick_coop_blue_all_the_time_wt_difference_in_actions(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_red_act = [0, 1, 2, 3] + p_blue_act = [0, 1, 2, 3] + c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_blue_pos = [[1, 1], [1, 2], [2, 0], [0, 0]] + c_red_coin = [0, 0, 0, 0] + max_steps, batch_size, grid_size = 4, 4, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=0.0, + blue_own=1.0, + ) + + +def test_logged_info__both_pick_blue_all_the_time(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_red_coin = [0, 0, 0, 0] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=0.0, + blue_own=1.0, + ) + + +def test_logged_info__both_pick_red_all_the_time(): + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] + p_red_act = [0, 0, 
0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + ) + + +def test_logged_info__both_pick_coop_red_half_the_time(): + p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] + p_blue_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.5, + blue_speed=0.5, + red_own=1.0, + blue_own=0.0, + ) + + +def test_logged_info__both_pick_selfish_blue_half_the_time(): + p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_blue_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.5, + red_own=None, + blue_own=1.0, + ) + + +def test_logged_info__both_dont_pick_coop_red(): + p_red_pos = [[0, 0], [0, 0], [0, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] + p_red_act = [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] + c_red_coin = [1, 1, 1, 1] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.0, + blue_speed=0.0, + red_own=None, + blue_own=None, + ) + + +def test_logged_info__pick_half_the_time_half_selfish_blue_half_selfish_red(): + p_red_pos = [[0, 0], [0, 0], [1, 0], [1, 0]] + p_blue_pos = [[1, 0], [1, 0], [0, 0], [0, 0]] + p_red_act 
= [0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0] + c_red_pos = [[1, 1], [2, 2], [1, 1], [2, 2]] + c_blue_pos = [[2, 2], [1, 1], [2, 2], [1, 1]] + c_red_coin = [1, 1, 0, 0] + max_steps, batch_size, grid_size = 4, 28, 3 + n_steps = max_steps + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=0.25, + blue_speed=0.25, + red_own=1.0, + blue_own=1.0, + ) + + +# def test_get_and_set_env_state(): +# max_steps, batch_size, grid_size = 20, 100, 3 +# n_steps = int(max_steps * 8.25) +# envs = init_my_envs(max_steps, batch_size, grid_size) +# +# for env in envs: +# obs = env.reset() +# initial_env_state = env._save_env() +# initial_env_state_saved = copy.deepcopy(initial_env_state) +# env_initial = copy.deepcopy(env) +# +# step_i = 0 +# for _ in range(n_steps): +# step_i += 1 +# actions = { +# policy_id: [ +# random.randint(0, env.NUM_ACTIONS - 1) +# for _ in range(batch_size) +# ] +# for policy_id in env.players_ids +# } +# obs, reward, done, info = env.step(actions) +# +# assert all( +# [ +# v == initial_env_state_saved[k] +# if not isinstance(v, np.ndarray) +# else (v == initial_env_state_saved[k]).all() +# for k, v in initial_env_state.items() +# ] +# ) +# env_state_after_step = env._save_env() +# env_after_step = copy.deepcopy(env) +# +# env._load_env(initial_env_state) +# env_vars, env_initial_vars = vars(env), vars(env_initial) +# env_vars.pop("np_random", None) +# env_initial_vars.pop("np_random", None) +# assert all( +# [ +# v == env_initial_vars[k] +# if not isinstance(v, np.ndarray) +# else (v == env_initial_vars[k]).all() +# for k, v in env_vars.items() +# ] +# ) +# +# env._load_env(env_state_after_step) +# env_vars, env_after_step_vars = vars(env), vars(env_after_step) +# env_vars.pop("np_random", None) +# env_after_step_vars.pop("np_random", None) +# assert all( +# [ +# v == env_after_step_vars[k] +# if not isinstance(v, np.ndarray) +# else (v == env_after_step_vars[k]).all() +# for k, v in env_vars.items() +# ] +# ) +# +# if done["__all__"]: +# obs = env.reset() +# step_i = 0 + + +def test_observations_are_not_invariant_to_the_player_trained_wt_step(): + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] + p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + c_red_pos = [ + [1, 1], + [2, 2], + [0, 1], + [2, 2], + [2, 2], + [2, 2], + [0, 0], + [2, 2], + [2, 2], + [2, 1], + ] + c_blue_pos = [ + [2, 2], + [1, 1], + [2, 2], + [0, 1], + [2, 2], + [2, 2], + [2, 2], + [0, 0], + [2, 1], + [2, 2], + ] + c_red_coin = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + max_steps, batch_size, grid_size = 10, 52, 3 + n_steps = max_steps + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=True + ) + + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] + + for env_i, env in enumerate(envs): + _ = env.reset() + step_i = 0 + + for _ in range(n_steps): + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + 
c_blue_pos, + c_red_coin, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } + obs, reward, done, info = env.step(actions) + + step_i += 1 + # assert that observations are not + # symmetrical respective to the + # actions + if step_i % 2 == 1: + obs_step_odd = obs + elif step_i % 2 == 0: + assert np.any( + obs[env.players_ids[0]] != obs_step_odd[env.players_ids[1]] + ) + assert np.any( + obs[env.players_ids[1]] != obs_step_odd[env.players_ids[0]] + ) + assert_obs_is_not_symmetrical(obs, env) + + if step_i == max_steps: + break + + +def test_observations_are_not_invariant_to_the_player_trained_wt_reset(): + p_red_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [0, 0], + [1, 1], + [2, 0], + [0, 1], + [2, 2], + [1, 2], + ] + p_blue_pos = [ + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [0, 0], + [0, 1], + [2, 0], + [1, 2], + [2, 2], + ] + p_red_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + p_blue_act = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + c_red_pos = [ + [1, 1], + [2, 2], + [0, 1], + [2, 2], + [2, 2], + [2, 2], + [0, 0], + [2, 2], + [2, 2], + [2, 1], + ] + c_blue_pos = [ + [2, 2], + [1, 1], + [2, 2], + [0, 1], + [2, 2], + [2, 2], + [2, 2], + [0, 0], + [2, 1], + [2, 2], + ] + c_red_coin = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + max_steps, batch_size, grid_size = 10, 52, 3 + n_steps = max_steps + envs = init_my_envs( + max_steps, batch_size, grid_size, same_obs_for_each_player=True + ) + + batch_deltas = [ + i % max_steps if i % 2 == 0 else i % max_steps - 1 + for i in range(batch_size) + ] + + for env_i, env in enumerate(envs): + obs = env.reset() + assert_obs_is_not_symmetrical(obs, env) + step_i = 0 + + for _ in range(n_steps): + overwrite_pos( + step_i, + batch_deltas, + max_steps, + env, + p_red_pos, + p_blue_pos, + c_red_pos, + c_blue_pos, + c_red_coin, + ) + actions = { + "player_red": [ + p_red_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + "player_blue": [ + p_blue_act[(step_i + delta) % max_steps] + for delta in batch_deltas + ], + } + _, _, _, _ = env.step(actions) + + step_i += 1 + + if step_i == max_steps: + break + + +@flaky(max_runs=4, min_passes=1) +def test_who_pick_is_random(): + size = 100 + p_red_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size + p_blue_pos = [[1, 0], [1, 0], [1, 0], [1, 0]] * size + p_red_act = [0, 0, 0, 0] * size + p_blue_act = [0, 0, 0, 0] * size + c_red_pos = [[1, 1], [1, 1], [1, 1], [1, 1]] * size + c_blue_pos = [[2, 2], [2, 2], [2, 2], [2, 2]] * size + c_red_coin = [1, 1, 1, 1] * size + max_steps, batch_size, grid_size = int(4 * size), 28, 3 + n_steps = max_steps + + envs = init_my_envs(max_steps, batch_size, grid_size) + + helper_assert_info( + n_steps=n_steps, + batch_size=batch_size, + p_red_act=p_red_act, + p_blue_act=p_blue_act, + envs=envs, + grid_size=grid_size, + max_steps=max_steps, + p_red_pos=p_red_pos, + p_blue_pos=p_blue_pos, + c_red_pos=c_red_pos, + c_blue_pos=c_blue_pos, + c_red_coin=c_red_coin, + check_obs_fn=check_obs, + overwrite_pos_fn=overwrite_pos, + red_speed=1.0, + blue_speed=1.0, + red_own=1.0, + blue_own=0.0, + repetitions=1, + ) diff --git a/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py b/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py index 6668b41..e7b0cb2 100644 --- a/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py +++ b/tests/marltoolbox/examples_and_experiments/manual_test_end_to_end.py @@ 
-4,256 +4,321 @@ import ray from marltoolbox.utils import postprocessing -from marltoolbox.utils.miscellaneous import check_learning_achieved +from marltoolbox.utils.exp_analysis import check_learning_achieved -def print_metrics_available(tune_analysis): - print("metric available in tune_analysis:", - tune_analysis.results_df.columns.tolist()) +def print_metrics_available(experiment_analysis): + print( + "metric available in experiment_analysis:", + experiment_analysis.results_df.columns.tolist(), + ) def test_pg_ipd(): from marltoolbox.examples.rllib_api.pg_ipd import main + # Restart Ray defensively in case the ray connection is lost. ray.shutdown() - tune_analysis = main(debug=False) - print_metrics_available(tune_analysis) + experiment_analysis = main(debug=False) + print_metrics_available(experiment_analysis) + check_learning_achieved(tune_results=experiment_analysis, max_=-75) check_learning_achieved( - tune_results=tune_analysis, - max_=-75) + tune_results=experiment_analysis, + min_=0.9, + metric="custom_metrics.DD_freq/player_row_mean", + ) check_learning_achieved( - tune_results=tune_analysis, + tune_results=experiment_analysis, min_=0.9, - metric="custom_metrics.DD_freq/player_row_mean") + metric="custom_metrics.DD_freq/player_col_mean", + ) + + +def test_r2d2_ipd(): + from marltoolbox.examples.rllib_api.r2d2_ipd import main + + # Restart Ray defensively in case the ray connection is lost. + ray.shutdown() + experiment_analysis = main(debug=False) + print_metrics_available(experiment_analysis) + check_learning_achieved(tune_results=experiment_analysis, max_=-75) check_learning_achieved( - tune_results=tune_analysis, + tune_results=experiment_analysis, + min_=0.9, + metric="custom_metrics.DD_freq/player_row_mean", + ) + check_learning_achieved( + tune_results=experiment_analysis, min_=0.9, - metric="custom_metrics.DD_freq/player_col_mean") + metric="custom_metrics.DD_freq/player_col_mean", + ) def test_ltft_ipd(): from marltoolbox.experiments.rllib_api.ltft_various_env import main + ray.shutdown() - tune_analysis_self_play, tune_analysis_against_opponent = main( + experiment_analysis_self_play, experiment_analysis_against_opponent = main( debug=False, env="IteratedPrisonersDilemma", train_n_replicates=1, - against_naive_opp=True) - print_metrics_available(tune_analysis_self_play) + against_naive_opp=True, + ) + print_metrics_available(experiment_analysis_self_play) check_learning_achieved( - tune_results=tune_analysis_self_play, - min_=-42) + tune_results=experiment_analysis_self_play, min_=-42 + ) check_learning_achieved( - tune_results=tune_analysis_self_play, + tune_results=experiment_analysis_self_play, min_=0.9, - metric="custom_metrics.CC_freq/player_row_mean") + metric="custom_metrics.CC_freq/player_row_mean", + ) check_learning_achieved( - tune_results=tune_analysis_self_play, + tune_results=experiment_analysis_self_play, min_=0.9, - metric="custom_metrics.CC_freq/player_col_mean") - print_metrics_available(tune_analysis_against_opponent) + metric="custom_metrics.CC_freq/player_col_mean", + ) + print_metrics_available(experiment_analysis_against_opponent) check_learning_achieved( - tune_results=tune_analysis_against_opponent, - max_=-75) + tune_results=experiment_analysis_against_opponent, max_=-75 + ) check_learning_achieved( - tune_results=tune_analysis_against_opponent, + tune_results=experiment_analysis_against_opponent, min_=0.9, - metric="custom_metrics.DD_freq/player_row_mean") + metric="custom_metrics.DD_freq/player_row_mean", + ) check_learning_achieved( -
tune_results=tune_analysis_against_opponent, + tune_results=experiment_analysis_against_opponent, min_=0.9, - metric="custom_metrics.DD_freq/player_col_mean") + metric="custom_metrics.DD_freq/player_col_mean", + ) def test_amtft_ipd(): from marltoolbox.experiments.rllib_api.amtft_various_env import main + ray.shutdown() - tune_analysis_per_welfare, analysis_metrics_per_mode = main( - debug=False, train_n_replicates=1, filter_utilitarian=False, - env="IteratedPrisonersDilemma") - for welfare_name, tune_analysis in tune_analysis_per_welfare.items(): + experiment_analysis_per_welfare, analysis_metrics_per_mode = main( + debug=False, + train_n_replicates=1, + filter_utilitarian=False, + env="IteratedPrisonersDilemma", + ) + for ( + welfare_name, + experiment_analysis, + ) in experiment_analysis_per_welfare.items(): print("welfare_name", welfare_name) - print_metrics_available(tune_analysis) + print_metrics_available(experiment_analysis) + check_learning_achieved(tune_results=experiment_analysis, min_=-204) check_learning_achieved( - tune_results=tune_analysis, min_=-204) - check_learning_achieved( - tune_results=tune_analysis, + tune_results=experiment_analysis, min_=0.9, - metric="custom_metrics.CC_freq/player_row_mean" + metric="custom_metrics.CC_freq/player_row_mean", ) check_learning_achieved( - tune_results=tune_analysis, + tune_results=experiment_analysis, min_=0.9, - metric="custom_metrics.CC_freq/player_col_mean" + metric="custom_metrics.CC_freq/player_col_mean", ) def test_ppo_asym_coin_game(): from marltoolbox.examples.rllib_api.ppo_coin_game import main + ray.shutdown() tune_analysis = main(debug=False, stop_iters=200) print_metrics_available(tune_analysis) - check_learning_achieved( - tune_results=tune_analysis, min_=15) + check_learning_achieved(tune_results=tune_analysis, min_=15) check_learning_achieved( tune_results=tune_analysis, min_=0.30, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.30, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.40, max_=0.60, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.40, max_=0.60, - metric="custom_metrics.pick_own_color/player_blue_mean") + metric="custom_metrics.pick_own_color/player_blue_mean", + ) def test_dqn_coin_game(): from marltoolbox.examples.rllib_api.dqn_coin_game import main + ray.shutdown() tune_analysis = main(debug=False) print_metrics_available(tune_analysis) - check_learning_achieved( - tune_results=tune_analysis, max_=20) + check_learning_achieved(tune_results=tune_analysis, max_=20) check_learning_achieved( tune_results=tune_analysis, min_=0.5, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.5, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.40, max_=0.6, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.40, max_=0.6, - metric="custom_metrics.pick_own_color/player_blue_mean") 
+ metric="custom_metrics.pick_own_color/player_blue_mean", + ) def test_dqn_wt_utilitarian_welfare_coin_game(): from marltoolbox.examples.rllib_api.dqn_wt_welfare import main + ray.shutdown() tune_analysis = main(debug=False) print_metrics_available(tune_analysis) - check_learning_achieved( - tune_results=tune_analysis, min_=50) + check_learning_achieved(tune_results=tune_analysis, min_=50) check_learning_achieved( tune_results=tune_analysis, min_=0.3, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.3, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.95, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.95, - metric="custom_metrics.pick_own_color/player_blue_mean") + metric="custom_metrics.pick_own_color/player_blue_mean", + ) def test_dqn_wt_inequity_aversion_welfare_coin_game(): from marltoolbox.examples.rllib_api.dqn_wt_welfare import main + ray.shutdown() - tune_analysis = main(debug=False, - welfare=postprocessing.WELFARE_INEQUITY_AVERSION) + tune_analysis = main( + debug=False, welfare=postprocessing.WELFARE_INEQUITY_AVERSION + ) print_metrics_available(tune_analysis) - check_learning_achieved( - tune_results=tune_analysis, min_=50) + check_learning_achieved(tune_results=tune_analysis, min_=50) check_learning_achieved( tune_results=tune_analysis, min_=0.25, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.25, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.9, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis, min_=0.9, - metric="custom_metrics.pick_own_color/player_blue_mean") + metric="custom_metrics.pick_own_color/player_blue_mean", + ) def test_ltft_coin_game(): from marltoolbox.experiments.rllib_api.ltft_various_env import main + ray.shutdown() tune_analysis_self_play, tune_analysis_against_opponent = main( - debug=False, env="CoinGame", train_n_replicates=1, - against_naive_opp=True) + debug=False, + env="CoinGame", + train_n_replicates=1, + against_naive_opp=True, + ) print_metrics_available(tune_analysis_self_play) - check_learning_achieved( - tune_results=tune_analysis_self_play, - min_=50) + check_learning_achieved(tune_results=tune_analysis_self_play, min_=50) check_learning_achieved( tune_results=tune_analysis_self_play, min_=0.3, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis_self_play, min_=0.3, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis_self_play, min_=0.9, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis_self_play, min_=0.9, - 
metric="custom_metrics.pick_own_color/player_blue_mean") + metric="custom_metrics.pick_own_color/player_blue_mean", + ) print_metrics_available(tune_analysis_against_opponent) check_learning_achieved( - tune_results=tune_analysis_against_opponent, - max_=20) + tune_results=tune_analysis_against_opponent, max_=20 + ) check_learning_achieved( tune_results=tune_analysis_against_opponent, min_=0.3, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis_against_opponent, min_=0.3, - metric="custom_metrics.pick_speed/player_blue_mean") + metric="custom_metrics.pick_speed/player_blue_mean", + ) check_learning_achieved( tune_results=tune_analysis_against_opponent, max_=0.6, - metric="custom_metrics.pick_own_color/player_red_mean") + metric="custom_metrics.pick_own_color/player_red_mean", + ) check_learning_achieved( tune_results=tune_analysis_against_opponent, max_=0.6, - metric="custom_metrics.pick_own_color/player_blue_mean") + metric="custom_metrics.pick_own_color/player_blue_mean", + ) def test_amtft_coin_game(): from marltoolbox.experiments.rllib_api.amtft_various_env import main + ray.shutdown() tune_analysis_per_welfare, analysis_metrics_per_mode = main( - debug=False, train_n_replicates=1, filter_utilitarian=False, - env="CoinGame") + debug=False, + train_n_replicates=1, + filter_utilitarian=False, + env="CoinGame", + ) for welfare_name, tune_analysis in tune_analysis_per_welfare.items(): print("welfare_name", welfare_name) print_metrics_available(tune_analysis) - check_learning_achieved( - tune_results=tune_analysis, - min_=40) + check_learning_achieved(tune_results=tune_analysis, min_=40) check_learning_achieved( tune_results=tune_analysis, min_=0.25, - metric="custom_metrics.pick_speed/player_red_mean") + metric="custom_metrics.pick_speed/player_red_mean", + ) diff --git a/tests/marltoolbox/examples_and_experiments/test_end_to_end.py b/tests/marltoolbox/examples_and_experiments/test_end_to_end.py index 0cac310..531d961 100644 --- a/tests/marltoolbox/examples_and_experiments/test_end_to_end.py +++ b/tests/marltoolbox/examples_and_experiments/test_end_to_end.py @@ -5,23 +5,37 @@ def test_pg_ipd(): from marltoolbox.examples.rllib_api.pg_ipd import main ray.shutdown() # Restart Ray defensively in case the ray connection is lost. - main(stop_iters=10, tf=False, debug=True) + main(debug=True) + + +def test_r2d2_ipd(): + from marltoolbox.examples.rllib_api.r2d2_ipd import main + + ray.shutdown() # Restart Ray defensively in case the ray connection is lost. + main(debug=True) def test_ppo_asym_coin_game(): from marltoolbox.examples.rllib_api.ppo_coin_game import main ray.shutdown() - main(debug=True, stop_iters=3, tf=False) + main(debug=True) -def test_ppo_asym_coin_game(): +def test_dqn_coin_game(): from marltoolbox.examples.rllib_api.dqn_coin_game import main ray.shutdown() main(debug=True) +def test_r2d2_cion_game(): + from marltoolbox.examples.rllib_api.r2d2_coin_game import main + + ray.shutdown() # Restart Ray defensively in case the ray connection is lost. 
+ main(debug=True) + + def test_ltft_ipd(): from marltoolbox.experiments.rllib_api.ltft_various_env import main @@ -43,6 +57,13 @@ def test_amtft_ipd(): main(debug=True, env="IteratedPrisonersDilemma") +def test_amtft_ipd_with_r2d2(): + from marltoolbox.experiments.rllib_api.amtft_various_env import main + + ray.shutdown() + main(debug=True, env="IteratedPrisonersDilemma", use_r2d2=True) + + def test_amtft_iasymbos(): from marltoolbox.experiments.rllib_api.amtft_various_env import main @@ -183,3 +204,17 @@ def test_adaptive_mechanism_design_tune_class_api_wt_rllib_policy(): ray.shutdown() main(debug=True, use_rllib_policy=True) + + +def test_amtft_vs_exploiter(): + from marltoolbox.experiments.rllib_api.amtft_vs_lvl1_exploiter import main + + ray.shutdown() + main(debug=True) + + +def test_amtft_meta_game(): + from marltoolbox.experiments.rllib_api.amtft_meta_game import main + + ray.shutdown() + main(debug=True) diff --git a/tests/marltoolbox/utils/test_exploration.py b/tests/marltoolbox/utils/test_exploration.py index 7934f43..e496a33 100644 --- a/tests/marltoolbox/utils/test_exploration.py +++ b/tests/marltoolbox/utils/test_exploration.py @@ -7,10 +7,10 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.utils.schedules import PiecewiseSchedule -from marltoolbox.envs.coin_game import \ - CoinGame -from marltoolbox.envs.matrix_sequential_social_dilemma import \ - IteratedPrisonersDilemma +from marltoolbox.envs.coin_game import CoinGame +from marltoolbox.envs.matrix_sequential_social_dilemma import ( + IteratedPrisonersDilemma, +) from marltoolbox.utils import exploration ROUNDING_ERROR = 1e-3 @@ -23,32 +23,36 @@ def assert_equal_wt_some_epsilon(v1, v2): def test_clusterize_by_distance(): output = exploration.clusterize_by_distance( - torch.Tensor([0.0, 0.4, 1.0, 1.4, 1.8, 3.0]), 0.5) + torch.Tensor([0.0, 0.4, 1.0, 1.4, 1.8, 3.0]), 0.5 + ) assert_equal_wt_some_epsilon( - output, - torch.Tensor([0.2000, 0.2000, 1.4000, 1.4000, 1.4000, 3.0000])) + output, torch.Tensor([0.2000, 0.2000, 1.4000, 1.4000, 1.4000, 3.0000]) + ) output = exploration.clusterize_by_distance( - torch.Tensor([0.0, 0.5, 1.0, 1.4, 1.8, 3.0]), 0.5) + torch.Tensor([0.0, 0.5, 1.0, 1.4, 1.8, 3.0]), 0.5 + ) assert_equal_wt_some_epsilon( - output, - torch.Tensor([0.0000, 0.5000, 1.4000, 1.4000, 1.4000, 3.0000])) + output, torch.Tensor([0.0000, 0.5000, 1.4000, 1.4000, 1.4000, 3.0000]) + ) output = exploration.clusterize_by_distance( - torch.Tensor([-10.0, -9.8, 1.0, 1.4, 1.8, 3.0]), 0.5) + torch.Tensor([-10.0, -9.8, 1.0, 1.4, 1.8, 3.0]), 0.5 + ) assert_equal_wt_some_epsilon( output, - torch.Tensor([-9.9000, -9.9000, 1.4000, 1.4000, 1.4000, 3.0000])) + torch.Tensor([-9.9000, -9.9000, 1.4000, 1.4000, 1.4000, 3.0000]), + ) output = exploration.clusterize_by_distance( - torch.Tensor([-1.0, -0.51, -0.1, 0.0, 0.1, 0.51, 1.0]), 0.5) + torch.Tensor([-1.0, -0.51, -0.1, 0.0, 0.1, 0.51, 1.0]), 0.5 + ) assert_equal_wt_some_epsilon( - output, - torch.Tensor([0., 0., 0., 0., 0., 0., 0.])) + output, torch.Tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + ) class TestSoftQSchedule: - def set_class_to_test(self): self.class_to_test = exploration.SoftQSchedule @@ -56,58 +60,59 @@ def test__set_temperature_wt_explore(self): self.set_class_to_test() self.arrange_for_simple_ipd() - self.softqschedule._set_temperature( - explore=True, timestep=0) + self.softqschedule._set_temperature(explore=True, timestep=0) assert self.softqschedule.temperature == self.initial_temperature self.softqschedule._set_temperature( - 
explore=True, timestep=self.temperature_timesteps) + explore=True, timestep=self.temperature_timesteps + ) assert self.softqschedule.temperature == self.final_temperature self.softqschedule._set_temperature( - explore=True, timestep=self.temperature_timesteps // 2) - assert abs(self.softqschedule.temperature - - (self.initial_temperature - self.final_temperature) / 2) < \ - ROUNDING_ERROR + explore=True, timestep=self.temperature_timesteps // 2 + ) + assert ( + abs( + self.softqschedule.temperature + - (self.initial_temperature - self.final_temperature) / 2 + ) + < ROUNDING_ERROR + ) def test__set_temperature_wtout_explore(self): self.set_class_to_test() self.arrange_for_simple_ipd() - self.softqschedule._set_temperature( - explore=False, timestep=0) + self.softqschedule._set_temperature(explore=False, timestep=0) assert self.softqschedule.temperature == 1.0 self.softqschedule._set_temperature( - explore=False, timestep=self.temperature_timesteps) + explore=False, timestep=self.temperature_timesteps + ) assert self.softqschedule.temperature == 1.0 self.softqschedule._set_temperature( - explore=False, timestep=self.temperature_timesteps // 2) + explore=False, timestep=self.temperature_timesteps // 2 + ) assert self.softqschedule.temperature == 1.0 def test__set_temperature_wt_explore_wt_multi_steps_schedule(self): self.class_to_test = exploration.SoftQSchedule self.arrange_for_multi_step_wt_coin_game() - self.softqschedule._set_temperature( - explore=True, timestep=0) + self.softqschedule._set_temperature(explore=True, timestep=0) assert self.softqschedule.temperature == 2.0 - self.softqschedule._set_temperature( - explore=True, timestep=2000) + self.softqschedule._set_temperature(explore=True, timestep=2000) assert self.softqschedule.temperature == 0.1 - self.softqschedule._set_temperature( - explore=True, timestep=3000) + self.softqschedule._set_temperature(explore=True, timestep=3000) assert self.softqschedule.temperature == 0.1 - self.softqschedule._set_temperature( - explore=True, timestep=500) + self.softqschedule._set_temperature(explore=True, timestep=500) assert abs(self.softqschedule.temperature - 1.25) < ROUNDING_ERROR - self.softqschedule._set_temperature( - explore=True, timestep=1500) + self.softqschedule._set_temperature(explore=True, timestep=1500) assert abs(self.softqschedule.temperature - 0.3) < ROUNDING_ERROR def arrange_for_simple_ipd(self): @@ -122,24 +127,21 @@ def arrange_for_multi_step_wt_coin_game(self): self.final_temperature = 0.0 self.temperature_timesteps = 0.0 self.temperature_schedule = PiecewiseSchedule( - endpoints=[ - (0, 2.0), - (1000, 0.5), - (2000, 0.1)], + endpoints=[(0, 2.0), (1000, 0.5), (2000, 0.1)], outside_value=0.1, - framework="torch") + framework="torch", + ) self.init_coin_game_scheduler() def init_ipd_scheduler(self): self.softqschedule = self.init_scheduler( IteratedPrisonersDilemma.ACTION_SPACE, - IteratedPrisonersDilemma.OBSERVATION_SPACE + IteratedPrisonersDilemma.OBSERVATION_SPACE, ) def init_coin_game_scheduler(self): self.softqschedule = self.init_scheduler( - CoinGame.ACTION_SPACE, - CoinGame({}).OBSERVATION_SPACE + CoinGame.ACTION_SPACE, CoinGame({}).OBSERVATION_SPACE ) def init_scheduler(self, action_space, obs_space): @@ -158,8 +160,8 @@ def init_scheduler(self, action_space, obs_space): action_space=action_space, num_outputs=action_space.n, name="fc", - model_config=MODEL_DEFAULTS - ) + model_config=MODEL_DEFAULTS, + ), ) def test__apply_temperature(self): @@ -173,22 +175,30 @@ def test__apply_temperature(self): ) def 
apply_and_assert_apply_temperature(self, temperature, inputs): - action_distribution, action_dist_class = \ - self.set_temperature_and_get_args(temperature=temperature, - inputs=inputs) + ( + action_distribution, + action_dist_class, + ) = self.set_temperature_and_get_args( + temperature=temperature, inputs=inputs + ) new_action_distribution = self.softqschedule._apply_temperature( - copy.deepcopy(action_distribution), action_dist_class) + copy.deepcopy(action_distribution), action_dist_class + ) assert all( abs(n_v - v / self.softqschedule.temperature) < ROUNDING_ERROR - for v, n_v in zip(action_distribution.inputs, - new_action_distribution.inputs)) + for v, n_v in zip( + action_distribution.inputs, new_action_distribution.inputs + ) + ) def set_temperature_and_get_args(self, temperature, inputs): action_dist_class = TorchCategorical + inputs = torch.tensor(inputs) action_distribution = TorchCategorical( - inputs, self.softqschedule.model, temperature=1.0) + inputs, self.softqschedule.model, temperature=1.0 + ) self.softqschedule.temperature = temperature return action_distribution, action_dist_class @@ -196,8 +206,7 @@ def test_get_exploration_action_wtout_explore(self): self.helper_test_get_exploration_action_wt_explore(explore=False) def random_inputs(self): - return np.random.random( - size=(1, np.random.randint(1, 50, size=1)[0])) + return np.random.random(size=(1, np.random.randint(1, 50, size=1)[0])) def random_timestep(self): return np.random.randint(0, 10000, size=1)[0] @@ -206,23 +215,26 @@ def random_temperature(self): return np.random.random(size=1)[0] * 10 + 1e-9 def apply_and_assert_get_exploration_action( - self, inputs, explore, timestep): + self, inputs, explore, timestep + ): - initial_action_distribution, _ = \ - self.set_temperature_and_get_args(temperature=1.0, - inputs=inputs) + initial_action_distribution, _ = self.set_temperature_and_get_args( + temperature=1.0, inputs=inputs + ) action_distribution = copy.deepcopy(initial_action_distribution) _ = self.softqschedule.get_exploration_action( - action_distribution, - timestep=timestep, - explore=explore + action_distribution, timestep=timestep, explore=explore ) temperature = self.softqschedule.temperature if explore else 1.0 - errors = [abs(n_v - v / temperature) - for v, n_v in zip(initial_action_distribution.inputs[0], - action_distribution.inputs[0])] + errors = [ + abs(n_v - v / temperature) + for v, n_v in zip( + initial_action_distribution.inputs[0], + action_distribution.inputs[0], + ) + ] assert all(err < ROUNDING_ERROR for err in errors), f"errors: {errors}" def test_get_exploration_action_wt_explore(self): @@ -236,11 +248,11 @@ def helper_test_get_exploration_action_wt_explore(self, explore): self.apply_and_assert_get_exploration_action( inputs=self.random_inputs(), explore=explore, - timestep=self.random_timestep()) + timestep=self.random_timestep(), + ) class TestSoftQScheduleWtClustering(TestSoftQSchedule): - def set_class_to_test(self): self.class_to_test = exploration.SoftQScheduleWtClustering @@ -250,15 +262,14 @@ def helper_test_get_exploration_action_wt_explore(self, explore): for inputs in self.get_inputs_list(): self.apply_and_assert_get_exploration_action( - inputs=inputs, - explore=explore, - timestep=self.random_timestep()) + inputs=inputs, explore=explore, timestep=self.random_timestep() + ) def get_inputs_list(self): return [ [[1.0, 0.0]], [[5.0, -1.0]], [[1.0, 1.6]], - [[101, -2.3]], - [[65, 98, 13, 56, 123, 156, 84]], + [[101.0, -2.3]], + [[65.0, 98.0, 13.0, 56.0, 123.0, 156.0, 84.0]], ] 
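Aside on the reformatted TestSoftQSchedule assertions above: the asserted temperatures follow from RLlib's PiecewiseSchedule interpolating linearly between the endpoints set in arrange_for_multi_step_wt_coin_game and returning outside_value past the last endpoint. A minimal standalone sketch (not part of the patch; it only assumes the PiecewiseSchedule import already used in test_exploration.py) reproduces those values:

from ray.rllib.utils.schedules import PiecewiseSchedule

# Same endpoints as arrange_for_multi_step_wt_coin_game in test_exploration.py.
temperature_schedule = PiecewiseSchedule(
    endpoints=[(0, 2.0), (1000, 0.5), (2000, 0.1)],
    outside_value=0.1,
    framework="torch",
)

assert temperature_schedule.value(0) == 2.0
assert abs(temperature_schedule.value(500) - 1.25) < 1e-3  # halfway between 2.0 and 0.5
assert abs(temperature_schedule.value(1500) - 0.3) < 1e-3  # halfway between 0.5 and 0.1
assert temperature_schedule.value(3000) == 0.1  # outside_value past the last endpoint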
diff --git a/tests/marltoolbox/utils/test_log.py b/tests/marltoolbox/utils/test_log.py index 4b6ec9a..79149b6 100644 --- a/tests/marltoolbox/utils/test_log.py +++ b/tests/marltoolbox/utils/test_log.py @@ -1,30 +1,37 @@ import numpy as np +import torch -from marltoolbox.utils.log import _add_entropy_to_log +from marltoolbox.utils.log.log import add_entropy_to_log def test__add_entropy_to_log(): to_log = {} - train_batch = {"action_dist_inputs": np.array([[0.0, 1.0]])} - to_log = _add_entropy_to_log(train_batch, to_log) - assert_close(to_log[f"entropy_buffer_samples_avg"], 0.00, 0.001) - assert_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001) + train_batch = {"action_dist_inputs": torch.tensor([[0.0, 1.0]])} + to_log = add_entropy_to_log(train_batch, to_log) + assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.00, 0.001) + assert_are_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001) to_log = {} - train_batch = {"action_dist_inputs": np.array([[0.75, 0.25]])} - to_log = _add_entropy_to_log(train_batch, to_log) - assert_close(to_log[f"entropy_buffer_samples_avg"], 0.562335145, 0.001) - assert_close(to_log[f"entropy_buffer_samples_single"], 0.562335145, 0.001) + train_batch = {"action_dist_inputs": torch.tensor([[0.75, 0.25]])} + to_log = add_entropy_to_log(train_batch, to_log) + assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.562335145, 0.001) + assert_are_close( + to_log[f"entropy_buffer_samples_single"], 0.562335145, 0.001 + ) to_log = {} - train_batch = {"action_dist_inputs": np.array([[0.62, 0.12, 0.13, 0.13]])} - to_log = _add_entropy_to_log(train_batch, to_log) - assert_close(to_log[f"entropy_buffer_samples_avg"], 1.081271236, 0.001) - assert_close(to_log[f"entropy_buffer_samples_single"], 1.081271236, 0.001) + train_batch = { + "action_dist_inputs": torch.tensor([[0.62, 0.12, 0.13, 0.13]]) + } + to_log = add_entropy_to_log(train_batch, to_log) + assert_are_close(to_log[f"entropy_buffer_samples_avg"], 1.081271236, 0.001) + assert_are_close( + to_log[f"entropy_buffer_samples_single"], 1.081271236, 0.001 + ) to_log = {} train_batch = { - "action_dist_inputs": np.array( + "action_dist_inputs": torch.tensor( [ [0.62, 0.12, 0.13, 0.13], [0.75, 0.25, 0.0, 0.0], @@ -32,13 +39,13 @@ def test__add_entropy_to_log(): ] ) } - to_log = _add_entropy_to_log(train_batch, to_log) - assert_close(to_log[f"entropy_buffer_samples_avg"], 0.547868794, 0.001) - assert_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001) + to_log = add_entropy_to_log(train_batch, to_log) + assert_are_close(to_log[f"entropy_buffer_samples_avg"], 0.547868794, 0.001) + assert_are_close(to_log[f"entropy_buffer_samples_single"], 0.00, 0.001) return to_log -def assert_close(a, b, threshold): +def assert_are_close(a, b, threshold): abs_diff = np.abs(a - b) assert abs_diff < threshold diff --git a/tests/marltoolbox/utils/test_rollout.py b/tests/marltoolbox/utils/test_rollout.py index 29a58ac..bf1c7f9 100644 --- a/tests/marltoolbox/utils/test_rollout.py +++ b/tests/marltoolbox/utils/test_rollout.py @@ -1,17 +1,17 @@ import copy import os import tempfile -import time import numpy as np +import time from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR -from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config from marltoolbox.envs.matrix_sequential_social_dilemma import ( IteratedPrisonersDilemma, ) +from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config from marltoolbox.utils 
import log, miscellaneous from marltoolbox.utils import rollout @@ -19,7 +19,89 @@ EPI_LENGTH = 33 -class FakeEnvWtCstReward(IteratedPrisonersDilemma): +def test_rollout_actions_played_equal_actions_specified(): + policy_agent_mapping = lambda policy_id: policy_id + assert_actions_played_equal_actions_specified( + policy_agent_mapping, + rollout_length=20, + num_episodes=1, + actions_list=[0, 1] * 100, + ) + assert_actions_played_equal_actions_specified( + policy_agent_mapping, + rollout_length=40, + num_episodes=1, + actions_list=[1, 1] * 100, + ) + assert_actions_played_equal_actions_specified( + policy_agent_mapping, + rollout_length=77, + num_episodes=2, + actions_list=[0, 0] * 100, + ) + assert_actions_played_equal_actions_specified( + policy_agent_mapping, + rollout_length=77, + num_episodes=3, + actions_list=[0, 1] * 100, + ) + assert_actions_played_equal_actions_specified( + policy_agent_mapping, + rollout_length=6, + num_episodes=3, + actions_list=[1, 0] * 100, + ) + + +def assert_actions_played_equal_actions_specified( + policy_agent_mapping, rollout_length, num_episodes, actions_list +): + rollout_results, worker = _when_perform_rollouts_wt_given_actions( + actions_list, rollout_length, policy_agent_mapping, num_episodes + ) + + _assert_length_of_rollout(rollout_results, num_episodes, rollout_length) + + n_steps_in_last_epi, steps_in_last_epi = _compute_n_steps_in_last_epi( + rollout_results, rollout_length, num_episodes + ) + + all_steps = _unroll_all_steps(rollout_results) + + # Verify that the actions played are the actions we forced to play + _for_each_player_exec_fn( + worker, + _assert_played_the_actions_specified, + all_steps, + rollout_length, + num_episodes, + actions_list, + ) + _for_each_player_exec_fn( + worker, + _assert_played_the_actions_specified_during_last_epi_only, + all_steps, + n_steps_in_last_epi, + steps_in_last_epi, + actions_list, + ) + + +def _when_perform_rollouts_wt_given_actions( + actions_list, rollout_length, policy_agent_mapping, num_episodes +): + worker = _init_worker(actions_list=actions_list) + rollout_results = rollout.internal_rollout( + worker, + num_steps=rollout_length, + policy_agent_mapping=policy_agent_mapping, + reset_env_before=True, + num_episodes=num_episodes, + ) + return rollout_results, worker + + +class _FakeEnvWtCstReward(IteratedPrisonersDilemma): def step(self, actions: dict): observations, rewards, epi_is_done, info = super().step(actions) @@ -29,33 +111,38 @@ def step(self, actions: dict): return observations, rewards, epi_is_done, info -def make_FakePolicyWtDefinedActions(list_actions_to_play): +def _make_fake_policy_wt_defined_actions(list_actions_to_play): class FakePolicyWtDefinedActions(PGTorchPolicy): - def compute_actions(self, *args, **kwargs): + def _compute_action_helper(self, *args, **kwargs): action = list_actions_to_play.pop(0) return np.array([action]), [], {} + def _initialize_loss_from_dummy_batch( + self, + auto_remove_unneeded_view_reqs: bool = True, + stats_fn=None, + ) -> None: + pass + return FakePolicyWtDefinedActions -def init_worker(actions_list=None): +def _init_worker(actions_list=None): train_n_replicates = 1 debug = True - stop_iters = 200 - tf = False seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("testing") - rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf) - rllib_config["env"] = FakeEnvWtCstReward + rllib_config, stop_config = get_rllib_config(seeds, debug) + rllib_config["env"] = _FakeEnvWtCstReward 
rllib_config["env_config"]["max_steps"] = EPI_LENGTH rllib_config["seed"] = int(time.time()) if actions_list is not None: - for policy_id in FakeEnvWtCstReward({}).players_ids: + for policy_id in _FakeEnvWtCstReward({}).players_ids: policy_to_modify = list( rllib_config["multiagent"]["policies"][policy_id] ) - policy_to_modify[0] = make_FakePolicyWtDefinedActions( + policy_to_modify[0] = _make_fake_policy_wt_defined_actions( copy.deepcopy(actions_list) ) rllib_config["multiagent"]["policies"][ @@ -93,105 +180,133 @@ def default_logger_creator(config): return default_logger_creator -def test_rollout_constant_reward(): - policy_agent_mapping = lambda policy_id: policy_id +def _for_each_player_exec_fn(worker, fn, *arg, **kwargs): + for policy_id in worker.env.players_ids: + fn(policy_id, *arg, **kwargs) - def assert_(rollout_length, num_episodes): - worker = init_worker() - rollout_results = rollout.internal_rollout( - worker, - num_steps=rollout_length, - policy_agent_mapping=policy_agent_mapping, - reset_env_before=True, - num_episodes=num_episodes, - ) - assert ( - rollout_results._num_episodes == num_episodes - or rollout_results._total_steps == rollout_length - ) - steps_in_last_epi = rollout_results._current_rollout - if rollout_results._total_steps == rollout_length: - n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH - elif rollout_results._num_episodes == num_episodes: - n_steps_in_last_epi = EPI_LENGTH - - # Verify rewards - for policy_id in worker.env.players_ids: - rewards = [step[3][policy_id] for step in steps_in_last_epi] - assert sum(rewards) == n_steps_in_last_epi * CONSTANT_REWARD - assert len(rewards) == n_steps_in_last_epi - all_steps = [] - for epi_rollout in rollout_results._rollouts: - all_steps.extend(epi_rollout) - for policy_id in worker.env.players_ids: - rewards = [step[3][policy_id] for step in all_steps] - assert ( - sum(rewards) - == min(rollout_length, num_episodes * EPI_LENGTH) - * CONSTANT_REWARD - ) - assert len(rewards) == min( - rollout_length, num_episodes * EPI_LENGTH - ) +def _assert_played_the_actions_specified( + policy_id, all_steps, rollout_length, num_episodes, actions_list +): + actions_played = [step[1][policy_id] for step in all_steps] + assert len(actions_played) == min( + rollout_length, num_episodes * EPI_LENGTH + ) + for action_required, action_played in zip( + actions_list[: len(all_steps)], actions_played + ): + assert action_required == action_played - assert_(rollout_length=20, num_episodes=1) - assert_(rollout_length=40, num_episodes=1) - assert_(rollout_length=77, num_episodes=2) - assert_(rollout_length=77, num_episodes=3) - assert_(rollout_length=6, num_episodes=3) +def _assert_played_the_actions_specified_during_last_epi_only( + policy_id, all_steps, n_steps_in_last_epi, steps_in_last_epi, actions_list +): + actions_played = [step[1][policy_id] for step in steps_in_last_epi] + assert len(actions_played) == n_steps_in_last_epi + actions_required_during_last_epi = actions_list[: len(all_steps)][ + -n_steps_in_last_epi: + ] + for action_required, action_played in zip( + actions_required_during_last_epi, actions_played + ): + assert action_required == action_played -def test_rollout_specified_actions(): + +def _assert_length_of_rollout(rollout_results, num_episodes, rollout_length): + assert ( + rollout_results._num_episodes == num_episodes + or rollout_results._total_steps == rollout_length + ) + + +def _compute_n_steps_in_last_epi( + rollout_results, rollout_length, num_episodes +): + steps_in_last_epi = 
rollout_results._current_rollout + if rollout_results._total_steps == rollout_length: + n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH + elif rollout_results._num_episodes == num_episodes: + n_steps_in_last_epi = EPI_LENGTH + + assert n_steps_in_last_epi == len( + steps_in_last_epi + ), f"{n_steps_in_last_epi} == {len(steps_in_last_epi)}" + + return n_steps_in_last_epi, steps_in_last_epi + + +def _unroll_all_steps(rollout_results): + all_steps = [] + for epi_rollout in rollout_results._rollouts: + all_steps.extend(epi_rollout) + return all_steps + + +def test_rollout_rewards_received_equal_constant_reward(): policy_agent_mapping = lambda policy_id: policy_id + assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length=20, num_episodes=1 + ) + assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length=40, num_episodes=1 + ) + assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length=77, num_episodes=2 + ) + assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length=77, num_episodes=3 + ) + assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length=6, num_episodes=3 + ) - def assert_(rollout_length, num_episodes, actions_list): - worker = init_worker(actions_list=actions_list) - rollout_results = rollout.internal_rollout( - worker, - num_steps=rollout_length, - policy_agent_mapping=policy_agent_mapping, - reset_env_before=True, - num_episodes=num_episodes, - ) - assert ( - rollout_results._num_episodes == num_episodes - or rollout_results._total_steps == rollout_length - ) - steps_in_last_epi = rollout_results._current_rollout - if rollout_results._total_steps == rollout_length: - n_steps_in_last_epi = rollout_results._total_steps % EPI_LENGTH - elif rollout_results._num_episodes == num_episodes: - n_steps_in_last_epi = EPI_LENGTH - - # Verify actions - all_steps = [] - for epi_rollout in rollout_results._rollouts: - all_steps.extend(epi_rollout) - for policy_id in worker.env.players_ids: - actions_played = [step[1][policy_id] for step in all_steps] - assert len(actions_played) == min( - rollout_length, num_episodes * EPI_LENGTH - ) - print(actions_list[1 : 1 + len(all_steps)], actions_played) - for action_required, action_played in zip( - actions_list[: len(all_steps)], actions_played - ): - assert action_required == action_played - for policy_id in worker.env.players_ids: - actions_played = [step[1][policy_id] for step in steps_in_last_epi] - assert len(actions_played) == n_steps_in_last_epi - actions_required_during_last_epi = actions_list[: len(all_steps)][ - -n_steps_in_last_epi: - ] - for action_required, action_played in zip( - actions_required_during_last_epi, actions_played - ): - assert action_required == action_played - - assert_(rollout_length=20, num_episodes=1, actions_list=[0, 1] * 100) - assert_(rollout_length=40, num_episodes=1, actions_list=[1, 1] * 100) - assert_(rollout_length=77, num_episodes=2, actions_list=[0, 0] * 100) - assert_(rollout_length=77, num_episodes=3, actions_list=[0, 1] * 100) - assert_(rollout_length=6, num_episodes=3, actions_list=[1, 0] * 100) +def assert_rewards_received_are_rewards_specified( + policy_agent_mapping, rollout_length, num_episodes +): + rollout_results, worker = _when_perform_rollouts_wt_given_actions( + None, rollout_length, policy_agent_mapping, num_episodes + ) + + _assert_length_of_rollout(rollout_results, num_episodes, rollout_length) + + n_steps_in_last_epi, steps_in_last_epi = 
_compute_n_steps_in_last_epi( + rollout_results, rollout_length, num_episodes + ) + + all_steps = _unroll_all_steps(rollout_results) + + # Verify that the rewards received are the one we defined + _for_each_player_exec_fn( + worker, + _assert_rewards_in_last_epi_are_as_specified, + steps_in_last_epi, + n_steps_in_last_epi, + ) + + _for_each_player_exec_fn( + worker, + _assert_rewards_are_as_defined, + all_steps, + rollout_length, + num_episodes, + ) + + +def _assert_rewards_in_last_epi_are_as_specified( + policy_id, steps_in_last_epi, n_steps_in_last_epi +): + rewards = [step[3][policy_id] for step in steps_in_last_epi] + assert sum(rewards) == n_steps_in_last_epi * CONSTANT_REWARD + assert len(rewards) == n_steps_in_last_epi + + +def _assert_rewards_are_as_defined( + policy_id, all_steps, rollout_length, num_episodes +): + rewards = [step[3][policy_id] for step in all_steps] + assert ( + sum(rewards) + == min(rollout_length, num_episodes * EPI_LENGTH) * CONSTANT_REWARD + ) + assert len(rewards) == min(rollout_length, num_episodes * EPI_LENGTH) diff --git a/tests/marltoolbox/utils/test_same_cross_perf.py b/tests/marltoolbox/utils/test_same_cross_perf.py index 4997e8e..fca9f63 100644 --- a/tests/marltoolbox/utils/test_same_cross_perf.py +++ b/tests/marltoolbox/utils/test_same_cross_perf.py @@ -5,8 +5,7 @@ from ray.rllib.agents.pg import PGTrainer from marltoolbox.examples.rllib_api.pg_ipd import get_rllib_config -from marltoolbox.utils import log, miscellaneous, restore -from marltoolbox.utils import self_and_cross_perf +from marltoolbox.utils import log, miscellaneous, restore, cross_play from marltoolbox.utils.miscellaneous import get_random_seeds @@ -15,7 +14,7 @@ def _init_evaluator(): rllib_config, stop_config = get_rllib_config(seeds=get_random_seeds(1)) - evaluator = self_and_cross_perf.SelfAndCrossPlayEvaluator( + evaluator = cross_play.evaluator.SelfAndCrossPlayEvaluator( exp_name=exp_name, ) evaluator.define_the_experiment_to_run( @@ -34,10 +33,11 @@ def _train_pg_in_ipd(train_n_replicates): seeds = miscellaneous.get_random_seeds(train_n_replicates) exp_name, _ = log.log_in_current_day_dir("testing") + ray.shutdown() ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug) - rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, tf) - tune_analysis = tune.run( + rllib_config, stop_config = get_rllib_config(seeds, debug) + experiment_analysis = tune.run( PGTrainer, config=rllib_config, stop=stop_config, @@ -48,12 +48,12 @@ def _train_pg_in_ipd(train_n_replicates): mode="max", ) ray.shutdown() - return tune_analysis, seeds + return experiment_analysis, seeds -def _load_tune_analysis(evaluator, train_n_replicates, exp_name): - tune_analysis, seeds = _train_pg_in_ipd(train_n_replicates) - tune_results = {exp_name: tune_analysis} +def _load_experiment_analysis(evaluator, train_n_replicates, exp_name): + experiment_analysis, seeds = _train_pg_in_ipd(train_n_replicates) + tune_results = {exp_name: experiment_analysis} evaluator.preload_checkpoints_from_tune_results(tune_results) return seeds @@ -84,7 +84,9 @@ def test__extract_groups_of_checkpoints(): evaluator = _init_evaluator() def assert_(exp_name, train_n_replicates): - seeds = _load_tune_analysis(evaluator, train_n_replicates, exp_name) + seeds = _load_experiment_analysis( + evaluator, train_n_replicates, exp_name + ) assert len(evaluator.checkpoints) == train_n_replicates for idx, checkpoint in enumerate(evaluator.checkpoints): assert str(seeds[idx]) in checkpoint["path"] @@ -97,7 +99,7 @@ def 
assert_(exp_name, train_n_replicates): def test__get_opponents_per_checkpoints(): evaluator = _init_evaluator() exp_name, train_n_replicates = "", 3 - _load_tune_analysis(evaluator, train_n_replicates, exp_name) + _load_experiment_analysis(evaluator, train_n_replicates, exp_name) n_cross_play_per_checkpoint = train_n_replicates - 1 opponents_per_checkpoint = evaluator._get_opponents_per_checkpoints( n_cross_play_per_checkpoint @@ -112,7 +114,7 @@ def test__produce_config_variations(): evaluator = _init_evaluator() exp_name, train_n_replicates = "", 4 - _load_tune_analysis(evaluator, train_n_replicates, exp_name) + _load_experiment_analysis(evaluator, train_n_replicates, exp_name) def assert_(n_same_play_per_checkpoint, n_cross_play_per_checkpoint): opponents_per_checkpoint = evaluator._get_opponents_per_checkpoints( @@ -146,7 +148,7 @@ def assert_(n_same_play_per_checkpoint, n_cross_play_per_checkpoint): def test__prepare_one_master_config_dict(): evaluator = _init_evaluator() exp_name, train_n_replicates = "", 4 - _load_tune_analysis(evaluator, train_n_replicates, exp_name) + _load_experiment_analysis(evaluator, train_n_replicates, exp_name) def assert_(n_same_play_per_checkpoint, n_cross_play_per_checkpoint): ( @@ -170,7 +172,7 @@ def assert_(n_same_play_per_checkpoint, n_cross_play_per_checkpoint): def test__get_config_for_one_same_play(): evaluator = _init_evaluator() exp_name, train_n_replicates = "", 4 - _load_tune_analysis(evaluator, train_n_replicates, exp_name) + _load_experiment_analysis(evaluator, train_n_replicates, exp_name) def assert_(checkpoint_i): metadata, config_copy = evaluator._get_config_for_one_self_play( @@ -210,7 +212,7 @@ def assert_(checkpoint_i): def test__get_config_for_one_cross_play(): evaluator = _init_evaluator() exp_name, train_n_replicates = "", 4 - _load_tune_analysis(evaluator, train_n_replicates, exp_name) + _load_experiment_analysis(evaluator, train_n_replicates, exp_name) def assert_(checkpoint_i): n_cross_play_per_checkpoint = train_n_replicates - 1